Commit

lint
mori360 committed Feb 21, 2025
1 parent 9ad9ed0 commit 183d58a
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torchao/dtypes/nf4tensor.py
@@ -2,7 +2,7 @@
 import math
 import sys
 from dataclasses import dataclass, replace
-from enum import auto, Enum
+from enum import Enum, auto
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -553,8 +553,8 @@ def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
     assert (
-        input_tensor.numel() % block_size
-    ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        (input_tensor.numel() % block_size) == 0
+    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
 
     n_blocks = input_tensor.numel() // block_size
     blocks = input_tensor.view(n_blocks, block_size)
@@ -724,8 +724,8 @@ def double_quantize_scalers(
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
     assert (
-        input_tensor.numel() % scaler_block_size
-    ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        (input_tensor.numel() % scaler_block_size) == 0
+    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
 
     # First round of quantization
     # Produces: A tensor of size (n_blocks) of input_tensor.dtype
@@ -778,8 +778,8 @@ def dequantize_scalers(
     """
     assert input_tensor.dim() == 1, "Input tensor must be flattened"
     assert (
-        input_tensor.numel() % scaler_block_size
-    ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        (input_tensor.numel() % scaler_block_size) == 0
+    ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
     n_scaler_blocks = input_tensor.numel() // scaler_block_size
     input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
     dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
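Note that the assert changes in all three hunks are purely cosmetic: the condition `(input_tensor.numel() % block_size) == 0` (or the `scaler_block_size` variant) and the error message are unchanged; only the placement of the parentheses chosen by the formatter differs. For reference, here is a minimal, self-contained sketch of the blocking pattern these asserts guard. The per-block abs-max reduction at the end is an assumption on my part; the hunk above only shows the divisibility check and the view into blocks, not torchao's actual reduction.

import torch

def block_absmax_sketch(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
    # Flatten, then enforce the divisibility invariant checked in the diff above.
    input_tensor = input_tensor.flatten()
    assert (input_tensor.numel() % block_size) == 0, (
        f"Input tensor must be divisible by block size, "
        f"got {input_tensor.numel()} and {block_size}"
    )
    # Split the flattened tensor into (n_blocks, block_size), as in the hunk.
    n_blocks = input_tensor.numel() // block_size
    blocks = input_tensor.view(n_blocks, block_size)
    # Assumed reduction: per-block absolute maximum (the real get_block_absmax may differ).
    return blocks.abs().max(dim=1).values

absmax = block_absmax_sketch(torch.randn(1024), block_size=64)  # shape: (16,)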
