Spelling fixes #662

Merged (2 commits) on Aug 13, 2024
52 changes: 26 additions & 26 deletions test/dtypes/test_nf4.py
@@ -157,10 +157,10 @@ def test_nf4_bnb_linear(self, dtype: torch.dtype):
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_state_dicts(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
-        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        input_tensor = torch.rand(64, device='cuda', dtype=dtype)
+        base_mod = self.TestMod(input_tensor, 32, 2)

-        dummy_dict = {"param": inpt_tensor}
+        dummy_dict = {"param": input_tensor}
         base_mod.load_state_dict(dummy_dict)

         assert base_mod.param.block_size == 32
@@ -170,12 +170,12 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(64, device='cuda', dtype=dtype)
-        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        input_tensor = torch.rand(64, device='cuda', dtype=dtype)
+        base_mod = self.TestMod(input_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)

-        other_mod = self.TestMod(inpt_tensor, 32, 2)
+        other_mod = self.TestMod(input_tensor, 32, 2)
         other_mod.load_state_dict(torch.load(saved_state_dict))
         assert other_mod.param.block_size == 32
         assert other_mod.param.scaler_block_size == 2
@@ -184,50 +184,50 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
         """Tests loading to and from different module state dicts"""
-        inpt_tensor = torch.rand(128, device='cuda', dtype=dtype)
-        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        input_tensor = torch.rand(128, device='cuda', dtype=dtype)
+        base_mod = self.TestMod(input_tensor, 32, 2)
         state_dict = base_mod.state_dict()
         saved_state_dict = self.save_state_dict_to_buffer(state_dict)

-        other_mod = self.TestMod(inpt_tensor, 64, 1)
+        other_mod = self.TestMod(input_tensor, 64, 1)
         other_mod.load_state_dict(torch.load(saved_state_dict))
         assert other_mod.param.block_size == 64
         assert other_mod.param.scaler_block_size == 1

     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_copy(self, dtype: torch.dtype):
-        inpt_tensor = torch.rand(128, device='cpu')
-        inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
-        nf4_to_dtype = inpt_tensor_nf4.to(dtype)
-        torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
+        input_tensor = torch.rand(128, device='cpu')
+        input_tensor_nf4 = to_nf4(input_tensor, 32, 2)
+        nf4_to_dtype = input_tensor_nf4.to(dtype)
+        torch.testing.assert_allclose(input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

         if torch.cuda.is_available():
-            inpt_tensor = torch.rand(128, device='cuda')
-            inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
-            nf4_to_dtype = inpt_tensor_nf4.to(dtype)
-            torch.testing.assert_allclose(inpt_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)
+            input_tensor = torch.rand(128, device='cuda')
+            input_tensor_nf4 = to_nf4(input_tensor, 32, 2)
+            nf4_to_dtype = input_tensor_nf4.to(dtype)
+            torch.testing.assert_allclose(input_tensor, nf4_to_dtype, atol=0.13, rtol=0.13)

     @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
     def test_to_copy_device(self):
-        inpt_tensor = torch.rand(128, device='cpu')
-        t = to_nf4(inpt_tensor, 32, 2)
+        input_tensor = torch.rand(128, device='cpu')
+        t = to_nf4(input_tensor, 32, 2)
         assert t.device == torch.device('cpu')
         z = t.cuda()
         assert z.device.type == "cuda"  # Because the device could be cuda:0
         x = z.cpu()
         assert x.device == torch.device('cpu')

-        inpt_tensor = torch.rand(128, device='cuda')
-        t = to_nf4(inpt_tensor, 32, 2)
+        input_tensor = torch.rand(128, device='cuda')
+        t = to_nf4(input_tensor, 32, 2)
         assert t.device.type == "cuda"

     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_to_dtype(self, dtype: torch.dtype):
-        inpt_tensor = torch.rand(128, dtype=dtype)
-        inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
-        assert type(inpt_tensor_nf4) != torch.Tensor
-        assert type(inpt_tensor_nf4.to(dtype)) == torch.Tensor
-        assert inpt_tensor_nf4.to(dtype).dtype == dtype
+        input_tensor = torch.rand(128, dtype=dtype)
+        input_tensor_nf4 = to_nf4(input_tensor, 32, 2)
+        assert type(input_tensor_nf4) != torch.Tensor
+        assert type(input_tensor_nf4.to(dtype)) == torch.Tensor
+        assert input_tensor_nf4.to(dtype).dtype == dtype

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
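For context on the API these tests exercise, here is a minimal round-trip sketch. It assumes `to_nf4` is importable from `torchao.dtypes` with the same positional arguments used in the tests (`tensor, block_size, scaler_block_size`); it is illustrative, not part of this PR.

```python
# Minimal round-trip sketch of the NF4 API exercised by the tests above.
import torch
from torchao.dtypes import to_nf4  # assumed import path, mirroring the test module

weight = torch.rand(128, dtype=torch.bfloat16)   # high-precision source tensor
weight_nf4 = to_nf4(weight, 32, 2)               # block_size=32, scaler_block_size=2
restored = weight_nf4.to(torch.bfloat16)         # dequantize back to a plain Tensor

# NF4 is lossy; the tests above use a similarly loose tolerance.
torch.testing.assert_close(weight, restored, atol=0.13, rtol=0.13)
```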
88 changes: 44 additions & 44 deletions torchao/dtypes/nf4tensor.py
@@ -387,22 +387,22 @@ class SubclassTensorArgs:
     requires_grad: bool


-def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
+def get_block_absmax(input_tensor: torch.Tensor, block_size: int) -> torch.Tensor:
     """Iterate through a flattened tensor getting the absmax scalers for each block

     Args:
-        inpt_tensor: Input tensor to get scalers for
+        input_tensor: Input tensor to get scalers for
         block_size: Block size for the scanning window
     Returns:
         torch.Tensor: Tensor of scalers for each block
     """
-    assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+    assert input_tensor.dim() == 1, "Input tensor must be flattened"
     assert (
-        inpt_tensor.numel() % block_size
-    ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
+        input_tensor.numel() % block_size
+    ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"

-    n_blocks = inpt_tensor.numel() // block_size
-    blocks = inpt_tensor.view(n_blocks, block_size)
+    n_blocks = input_tensor.numel() // block_size
+    blocks = input_tensor.view(n_blocks, block_size)
     block_scalers = blocks.abs().max(dim=1).values
     return block_scalers

@@ -478,18 +478,18 @@ def __init__(
     @torch.no_grad()
     def from_tensor(
         cls,
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         block_size: int,
         scaler_block_size: int,
     ):
-        assert inpt_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {inpt_tensor.dim()}"
+        assert input_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {input_tensor.dim()}"
         assert (
-            inpt_tensor.numel() % block_size == 0
-        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
-        assert inpt_tensor.is_contiguous, "Input tensor must be contiguous!"
+            input_tensor.numel() % block_size == 0
+        ), f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {block_size}"
+        assert input_tensor.is_contiguous, "Input tensor must be contiguous!"
         # I think I want do this
-        # assert not inpt_tensor.requires_grad, "Input tensor must not require grad"
-        device = inpt_tensor.device
+        # assert not input_tensor.requires_grad, "Input tensor must not require grad"
+        device = input_tensor.device
         # Cache the tensor on the class def
         nf4 = torch.tensor(
             [
@@ -511,27 +511,27 @@ def from_tensor(
                 1.0000,
             ],
             device=device,
-            dtype=inpt_tensor.dtype,
+            dtype=input_tensor.dtype,
         )
-        n_blocks = inpt_tensor.numel() // block_size
+        n_blocks = input_tensor.numel() // block_size
         # Double quantization
         (
             quantized_scalers,
             quantization_factor,
             scaler_mean,
         ) = cls.double_quantize_scalers(
-            inpt_tensor.flatten(), block_size, scaler_block_size
+            input_tensor.flatten(), block_size, scaler_block_size
         )
         quantized_data = cls.convert_to_norm_float_weight(
-            inpt_tensor, n_blocks, block_size, nf4
+            input_tensor, n_blocks, block_size, nf4
         )
         tensor_meta = SubclassTensorArgs(
-            inpt_tensor.size(),
-            inpt_tensor.stride(),
-            inpt_tensor.storage_offset(),
-            inpt_tensor.dtype,
-            inpt_tensor.device,
-            inpt_tensor.requires_grad,
+            input_tensor.size(),
+            input_tensor.stride(),
+            input_tensor.storage_offset(),
+            input_tensor.dtype,
+            input_tensor.device,
+            input_tensor.requires_grad,
         )
         return cls(
             tensor_meta,
@@ -547,7 +547,7 @@ def from_tensor(

     @staticmethod
     def double_quantize_scalers(
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         block_size: int,
         scaler_block_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -557,22 +557,22 @@ def double_quantize_scalers(
         And then we calculate the absmax quantization factors for each block again. We then quantize the scalers to int8.

         Args:
-            inpt_tensor: Input tensor to convert to QLoRA format, typically a weight tensor
+            input_tensor: Input tensor to convert to QLoRA format, typically a weight tensor

         Returns:
             torch.Tensor: Tensor of per_block quantization factors stored in int8 format
                 size: (n_blocks)
             torch.Tensor: Tensor of per_scaler_block quantization factors stored in int16 format
                 size: (n_scaler_blocks)
         """
-        assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+        assert input_tensor.dim() == 1, "Input tensor must be flattened"
         assert (
-            inpt_tensor.numel() % scaler_block_size
-        ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
+            input_tensor.numel() % scaler_block_size
+        ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"

         # First round of quantization
-        # Produces: A tensor of size (n_blocks) of inpt_tensor.dtype
-        scalers_1 = get_block_absmax(inpt_tensor, block_size)
+        # Produces: A tensor of size (n_blocks) of input_tensor.dtype
+        scalers_1 = get_block_absmax(input_tensor, block_size)
         scalers_1_mean = scalers_1.mean()
         scalers_1 = scalers_1 - scalers_1_mean
         # Second round of quantization
@@ -607,52 +607,52 @@ def double_quantize_scalers(

     def dequantize_scalers(
         self,
-        inpt_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
         quantization_factor: torch.Tensor,
         scaler_block_size: int,
     ) -> torch.Tensor:
         """Used to unpack the double quantized scalers

         Args;
-            inpt_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format
+            input_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format
             quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype
                 size: (n_scaler_blocks)
             scaler_block_size: Scaler block size to use for double quantization.

         """
-        assert inpt_tensor.dim() == 1, "Input tensor must be flattened"
+        assert input_tensor.dim() == 1, "Input tensor must be flattened"
         assert (
-            inpt_tensor.numel() % scaler_block_size
-        ) == 0, f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {scaler_block_size}"
-        n_scaler_blocks = inpt_tensor.numel() // scaler_block_size
-        inpt_tensor = inpt_tensor.view(n_scaler_blocks, scaler_block_size)
-        dequantized = (inpt_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
+            input_tensor.numel() % scaler_block_size
+        ) == 0, f"Input tensor must be divisible by block size, got {input_tensor.numel()} and {scaler_block_size}"
+        n_scaler_blocks = input_tensor.numel() // scaler_block_size
+        input_tensor = input_tensor.view(n_scaler_blocks, scaler_block_size)
+        dequantized = (input_tensor / quantization_factor.unsqueeze(-1)).flatten().to(
             self.dtype
         ) + self.scaler_mean
         return dequantized

     @staticmethod
     def convert_to_norm_float_weight(
-        inpt_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
+        input_tensor: torch.Tensor, n_blocks: int, block_size: int, nf4: torch.Tensor
     ) -> torch.Tensor:
         """Convert a tensor to the normalized float weight format"""
-        flattened_tensor = inpt_tensor.flatten()
+        flattened_tensor = input_tensor.flatten()
         # Since we are using uint8 we will encode 2 entries per byte
-        numel = inpt_tensor.numel()
+        numel = input_tensor.numel()
         assert (
             numel % 2 == 0
         ), "Number of elements must be even just to not have to think about the end"
         # Reshape the flattened tensor into blocks of size self.block_size
         blocks = flattened_tensor.view(n_blocks, block_size)

         # Scale the blocks
-        scalers = get_block_absmax(inpt_tensor.flatten(), block_size)
+        scalers = get_block_absmax(input_tensor.flatten(), block_size)
         scales = scalers.unsqueeze(-1).expand(n_blocks, block_size)
         scaled_blocks = blocks / scales

         # Returns a flattened tensor with each element quantized to nf4 index
         # See Note: Quantize in Chunks
-        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=inpt_tensor.device)
+        quantized_blocks = torch.empty(numel, dtype=torch.uint8, device=input_tensor.device)
         flattened = scaled_blocks.flatten()
         for chunk_num in range(math.ceil(numel / CHUNK_SIZE)):
             start = chunk_num * CHUNK_SIZE
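The functions touched above implement block-wise absmax scaling with "double quantization" of the scalers. The following standalone sketch is illustrative only: the names are local to the example (not torchao APIs), and it simplifies the int8 rounding that `double_quantize_scalers` actually performs.

```python
# Illustrative sketch of block-wise absmax scaling plus double quantization.
import torch

def block_absmax(flat: torch.Tensor, block_size: int) -> torch.Tensor:
    """One absolute-maximum scaler per contiguous block of a flattened tensor."""
    assert flat.dim() == 1 and flat.numel() % block_size == 0
    return flat.view(-1, block_size).abs().max(dim=1).values

x = torch.randn(1024)
scalers = block_absmax(x, block_size=64)   # 16 scalers, one per 64-element block

# Double quantization: treat the scalers themselves as data. Center them on
# their mean, then quantize them block-wise to int8, so only the int8 scalers,
# one quantization factor per scaler block, and the mean need to be stored.
scaler_block_size = 4
scaler_mean = scalers.mean()
centered = scalers - scaler_mean
factors = block_absmax(centered, scaler_block_size) / 127   # per-scaler-block step size
quantized_scalers = (
    (centered.view(-1, scaler_block_size) / factors.unsqueeze(-1))
    .round()
    .clamp(-127, 127)
    .to(torch.int8)
)

# Dequantizing reverses the steps: int8 -> float, multiply by the step, add the mean.
dequant_scalers = quantized_scalers.float() * factors.unsqueeze(-1) + scaler_mean
```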
12 changes: 6 additions & 6 deletions torchao/float8/inference.py
@@ -174,15 +174,15 @@ def from_float(


 def cast_to_float8_e4m3_inference(
-    inpt_tensor: torch.Tensor,
+    input_tensor: torch.Tensor,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     static_quantization_scale: Optional[torch.Tensor] = None,
 ) -> Float8Tensor:
     """Casts an input tensor to the Float8 (e4m3fn*)

     Args:
-        inpt_tensor: The input tensor to be cast.
+        input_tensor: The input tensor to be cast.
         linear_mm_config: Configuration settings for the matrix multiplication
         reduce_amax: Whether to reduce the amax (absolute maximum) among the local distributed group.
         static_quantization_scale: Optional tensor specifying the scale for activation. Default is None.
@@ -193,15 +193,15 @@ def cast_to_float8_e4m3_inference(
     Note:
         If the input tensor is already in Float8 format, it is returned as is without re-casting.
     """
-    if tensor_already_casted_to_fp8(inpt_tensor):
-        return inpt_tensor
+    if tensor_already_casted_to_fp8(input_tensor):
+        return input_tensor
     scale = (
         static_quantization_scale
         if static_quantization_scale is not None
-        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+        else tensor_to_scale(input_tensor, e4m3_dtype, reduce_amax)
     )
     return hp_tensor_and_scale_to_float8(
-        inpt_tensor,
+        input_tensor,
         scale,
         e4m3_dtype,
         linear_mm_config,
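The renamed `cast_to_float8_e4m3_inference` derives a dynamic scale from the tensor's absolute maximum when no static scale is supplied. Below is a hedged, plain-PyTorch sketch of that idea using only core dtypes; it does not use the torchao `Float8Tensor`/`tensor_to_scale` helpers and assumes a PyTorch build with `torch.float8_e4m3fn` support.

```python
# Hedged sketch of dynamic absmax scaling for float8 (e4m3fn) casting.
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3fn value

def absmax_scale(t: torch.Tensor) -> torch.Tensor:
    """Scale that maps the tensor's absmax onto the edge of the e4m3fn range."""
    amax = t.abs().max().to(torch.float32).clamp(min=1e-12)
    return E4M3_MAX / amax

x = torch.randn(64, 64, dtype=torch.bfloat16)
scale = absmax_scale(x)
x_fp8 = (x.to(torch.float32) * scale).to(torch.float8_e4m3fn)  # quantize
x_hp = x_fp8.to(torch.float32) / scale                         # dequantized reference
```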