Merged
4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
@@ -152,7 +152,7 @@ def test_invalid_granularity(self):
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
match="Unsupported granularity types",
):
Float8DynamicActivationFloat8WeightConfig(
granularity=(PerTensor(), PerRow())
@@ -165,7 +165,7 @@ def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
with pytest.raises(ValueError, match="Unsupported granularity types"):
Float8DynamicActivationFloat8WeightConfig(
granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
)
67 changes: 53 additions & 14 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,6 +18,7 @@
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
PerBlock,
PerRow,
PerTensor,
quantize_,
@@ -38,10 +39,10 @@


class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
def __init__(self, in_features, out_features, bias):
Contributor:

nit: bias feels a bit confusing as a parameter name (since it can be a flag vs. a Tensor), even though it is the official name in nn.Linear; maybe use has_bias, as the other tests do?

Contributor Author:

sure, let me do that in a future PR

super().__init__()
self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

def forward(self, x):
x = self.linear1(x)
@@ -64,7 +65,10 @@ def setUp(self):
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"granularity",
[PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
)
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,9 +78,11 @@ def setUp(self):
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
((32, 128), 256, 512),
],
)
@common_utils.parametrize("bias", [False, True])
@torch.no_grad()
def test_fp8_linear_variants(
self,
dtype: torch.dtype,
@@ -85,14 +91,33 @@
granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
bias: bool,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
return unittest.skip(
"per tensor with fbgemm kernel preferece does not work yet"
)
if isinstance(granularity, PerTensor):
if kernel_preference is KernelPreference.FBGEMM:
return unittest.skip(
"per tensor with fbgemm kernel preference does not work yet"
)
elif mode == "weight-only":
return unittest.skip("unimplemented")

elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
if dtype is torch.float32:
return unittest.skip("unimplemented")
elif mode == "weight-only":
return unittest.skip("unimplemented")
elif kernel_preference is KernelPreference.FBGEMM:
return unittest.skip("unimplemented")

if bias is True:
sizes_to_keep = ((128,), 256, 128)
if (
sizes != sizes_to_keep
or kernel_preference is not KernelPreference.TORCH
):
return unittest.skip(
"cut down on number of options to save test time"
)

error_message = None
if isinstance(granularity, PerRow):
@@ -122,7 +147,7 @@ def test_fp8_linear_variants(
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)

@@ -137,6 +162,20 @@

quantize_(quantized_model, config)

# ensure weight scaling is what we expect
qs1 = quantized_model.linear1.weight.scale
qs2 = quantized_model.linear2.weight.scale
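# linear1 holds an (N, K) weight and linear2 a (K, N) weight (see ToyLinearModel above),
# so the expected scale shapes below are transposed relative to each other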
if granularity == PerTensor():
assert qs1.shape == (1, 1)
assert qs2.shape == (1, 1)
elif granularity == PerRow():
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

@@ -231,7 +270,7 @@ def test_kernel_preference_numerical_equivalence(self, granularity, sizes):
dtype = torch.bfloat16
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
model = ToyLinearModel(K, N, bias=False).eval().to(dtype).to("cuda")

# reference kernel preference and results
# we are using KernelPreference.TORCH as the reference
46 changes: 46 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -14,9 +14,11 @@
MappingType,
ZeroPointDomain,
_choose_qparams_affine_tinygemm,
_choose_scale_float8,
_fake_quantize_affine,
_fake_quantize_affine_cachemask,
_maybe_expand_scale_to_tensor_shape,
_quantize_affine_float8,
choose_qparams_affine,
dequantize_affine,
quantize_affine,
@@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs):
return output1


# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100
# with scale modified to be the inverse of the version in PT core
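# (the PT core helper presumably computes scale = dtype_max / amax and multiplies by it;
# here scale = amax / dtype_max, quantization divides by the scale, and
# dequantization multiplies by it)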
def _tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = amax / torch.finfo(float8_dtype).max
x = x.div(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale


# Legacy tinygemm ops
def _get_groupwise_affine_qparams(
w,
@@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self):
self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))

def test_float8_blockwise_scaling(self):
M, K = 512, 1024
hp_tensor = torch.randn(M, K, dtype=torch.float)
# make the scales from some of the blocks obviously different
hp_tensor[0:128, 0:128] *= 3.0
hp_tensor[0:128, 128:256] *= 7.0
hp_tensor[128:256, 0:128] *= 2.0
hp_tensor[128:256, 128:256] *= 100.0

block_size = (128, 128)
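# with (128, 128) blocks on a (512, 1024) tensor, the expected scale shape is
# (512 // 128, 1024 // 128) == (4, 8), one scale value per block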

scale = _choose_scale_float8(
hp_tensor,
float8_dtype=torch.float8_e4m3fn,
block_size=block_size,
hp_value_lb=None,
hp_value_ub=None,
)
data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)

ref_data, ref_scale = _tensor_to_scale_block(
hp_tensor, torch.float8_e4m3fn, 128, 128
)

torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
75 changes: 59 additions & 16 deletions torchao/float8/inference.py
@@ -7,13 +7,15 @@
Defines an nn module designed to be used during inference
"""

import math
from typing import List, NamedTuple, Optional, Tuple, Union

import torch

from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerBlock,
PerRow,
PerTensor,
)
@@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
)


def _is_1_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 1x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
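# e.g. block_size values of (1, 128) or (1, 1, 128) both count as 1x128 scaling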
return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128


def _is_128_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 128x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
return len(b) == 2 and b[0] == 128 and b[1] == 128


def _granularity_is_a_1_128_w_128_128(
g: Union[
FP8Granularity,
Tuple[FP8Granularity, FP8Granularity],
list[FP8Granularity],
],
) -> bool:
return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))


def _normalize_granularity(
granularity: Optional[
Union[
@@ -211,22 +243,23 @@ def _normalize_granularity(
elif isinstance(granularity, (PerTensor, PerRow)):
processed_granularity = (granularity, granularity)
elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
if not (
isinstance(granularity[0], (PerTensor, PerRow))
and isinstance(granularity[1], (PerTensor, PerRow))
):
raise ValueError(
f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
granularity[1], PerTensor
)
is_per_row = isinstance(granularity[0], PerRow) and isinstance(
granularity[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)

if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
raise ValueError(f"Unsupported granularity types: {granularity}.")
if not isinstance(granularity[0], type(granularity[1])):
raise ValueError(
f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
f"Different granularities for activation and weight are not supported: {granularity}."
)
processed_granularity = tuple(granularity)
else:
raise ValueError(
f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
)
raise ValueError(f"Invalid granularity specification: {granularity}.")
return processed_granularity


@@ -243,12 +276,22 @@ def _check_hardware_support(
AssertionError: If hardware doesn't support the requested granularity
ValueError: If invalid granularity type is provided
"""
for _granularity in granularities:
if not isinstance(_granularity, (PerTensor, PerRow)):
raise ValueError(
f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
granularities[1], PerTensor
)
is_per_row = isinstance(granularities[0], PerRow) and isinstance(
granularities[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

if is_per_tensor or is_per_row:
assert is_sm_at_least_89() or is_MI300(), (
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
)
elif is_a_1_128_w_128_128:
# TODO(future PR): look into AMD support
assert is_sm_at_least_89(), (
"Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
)
else:
raise ValueError(f"Invalid granularities {granularities}.")