4 changes: 3 additions & 1 deletion test/integration/test_integration.py
@@ -1877,7 +1877,9 @@ def forward(self, x):
config = Float8DynamicActivationFloat8WeightConfig()
quantize_(model, config)

ep = torch.export.export(model, (inp,))
# Need to export with strict=True
# https://github.com/pytorch/pytorch/issues/167007
ep = torch.export.export(model, (inp,), strict=True)
print(ep)
FileCheck().check_count(
"torch.ops.torchao.choose_scale_float8.default", 1, exactly=True
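# A minimal sketch of the flow this test exercises (ToyModel below is
# illustrative, not the test's actual model): quantize with the float8
# dynamic config, then export with strict=True, which the linked issue
# indicates is currently required for float8-quantized models.
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 256, bias=False)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().to(torch.bfloat16).cuda()
quantize_(model, Float8DynamicActivationFloat8WeightConfig())
inp = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
ep = torch.export.export(model, (inp,), strict=True)  # strict=True required for now
print(ep)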
164 changes: 148 additions & 16 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,6 +18,7 @@
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Granularity,
PerBlock,
PerRow,
PerTensor,
@@ -42,6 +43,8 @@
class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features, bias):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)
self.linear2 = torch.nn.Linear(out_features, in_features, bias=bias)

@@ -50,6 +53,21 @@ def forward(self, x):
x = self.linear2(x)
return x

def check_weight_scaling(self, granularity: Granularity):
qs1 = self.linear1.weight.scale
qs2 = self.linear2.weight.scale
N, K = (self.out_features, self.in_features)
if granularity == PerTensor():
assert qs1.shape == (1, 1)
assert qs2.shape == (1, 1)
elif granularity == PerRow():
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)


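# Usage sketch for check_weight_scaling above; the `granularity` kwarg on the
# config is an assumption for illustration. linear1.weight has shape
# (N, K) == (128, 256), so PerRow() gives one scale per output row.
_m = ToyLinearModel(256, 128, bias=False).to(torch.bfloat16).cuda()
quantize_(_m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
assert _m.linear1.weight.scale.shape == (128, 1)  # (N, 1)
_m.check_weight_scaling(PerRow())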
class ToyConvModel(torch.nn.Module):
def __init__(
@@ -73,6 +91,37 @@ def forward(self, x):
return self.conv(x)


class ToyLoRAModel(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
lora_rank: int,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.linear = torch.nn.Linear(in_features, out_features, bias=False)
self.lora_A = torch.nn.Parameter(torch.randn(in_features, lora_rank))
self.lora_B = torch.nn.Parameter(torch.randn(lora_rank, out_features))

def forward(self, x):
matmul_out = torch.matmul(x, self.linear.weight.t())
lora_out = x @ self.lora_A @ self.lora_B
return matmul_out + lora_out

def check_weight_scaling(self, granularity: Granularity):
qs = self.linear.weight.scale
N, K = (self.out_features, self.in_features)
if granularity == PerTensor():
assert qs.shape == (1, 1)
elif granularity == PerRow():
assert qs.shape == (N, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs.shape == (N // 128, K // 128)

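# Why ToyLoRAModel matters for this PR: its forward uses
# torch.matmul(x, weight.t()) instead of F.linear, so the quantized weight
# must support aten.t.default and the matmul overrides added in
# float8_tensor.py below. A sketch, assuming the config's defaults:
_lora = ToyLoRAModel(in_features=256, out_features=128, lora_rank=8).to(torch.bfloat16).cuda()
quantize_(_lora, Float8DynamicActivationFloat8WeightConfig())  # quantizes only nn.Linear
_out = _lora(torch.randn(32, 256, dtype=torch.bfloat16, device="cuda"))
assert _out.shape == (32, 128)  # LoRA branch (plain Parameters) stays high precision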

# TODO: move tests from test_affine_quantized_float.py here after we've migrated all implementations
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -112,10 +161,74 @@ def test_fp8_linear_variants(
dtype: torch.dtype,
mode: str,
compile: bool,
granularity,
granularity: Granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
bias: bool,
):
_, N, K = sizes
self._test_fp8_matmul_model(
dtype,
mode,
compile,
granularity,
kernel_preference,
sizes,
bias,
ToyLinearModel(K, N, bias),
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
)
# Inputs are (M, ...), K, N
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
],
)
def test_fp8_matmul_lora_variants(
self,
dtype: torch.dtype,
mode: str,
compile: bool,
granularity: Granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
):
_, N, K = sizes
self._test_fp8_matmul_model(
dtype,
mode,
compile,
granularity,
kernel_preference,
sizes,
bias=False,
model=ToyLoRAModel(K, N, lora_rank=8),
)

def _test_fp8_matmul_model(
self,
dtype: torch.dtype,
mode: str,
compile: bool,
granularity: Granularity,
kernel_preference: KernelPreference,
sizes: Tuple,
bias: bool,
model: torch.nn.Module,
):
if isinstance(granularity, PerTensor):
if kernel_preference is KernelPreference.FBGEMM:
@@ -172,9 +285,7 @@ def test_fp8_linear_variants(
with error_context:
M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

# Create a linear layer with bfloat16 dtype
model = ToyLinearModel(K, N, bias).eval().to(dtype).to("cuda")
model = model.eval().to(dtype).to("cuda")

quantized_model = copy.deepcopy(model)

@@ -190,18 +301,7 @@ def test_fp8_linear_variants(
quantize_(quantized_model, config)

# ensure weight scaling is what we expect
qs1 = quantized_model.linear1.weight.scale
qs2 = quantized_model.linear2.weight.scale
if granularity == PerTensor():
assert qs1.shape == (1, 1)
assert qs2.shape == (1, 1)
elif granularity == PerRow():
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)
quantized_model.check_weight_scaling(granularity)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -801,6 +901,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):

self.assertEqual(sliced_dequantized, sliced_original)

def test_to_dtype_layout(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
y_fp8 = torch.ops.aten.to.dtype_layout(
x_fp8, dtype=x_fp8.dtype, layout=x_fp8.layout, device="cpu"
)
self.assertEqual(y_fp8.dtype, x_fp8.dtype)
self.assertEqual(y_fp8.layout, x_fp8.layout)
self.assertEqual(y_fp8.device, torch.device("cpu"))

def test_has_compatible_shallow_copy_type(self):
x1 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x2 = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x3 = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x1_fp8 = Float8Tensor.from_hp(x1)
x2_fp8 = Float8Tensor.from_hp(x2)
x3_fp8 = Float8Tensor.from_hp(x3)
self.assertFalse(torch._has_compatible_shallow_copy_type(x1, x2_fp8))
self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x2))
self.assertTrue(torch._has_compatible_shallow_copy_type(x1_fp8, x2_fp8))
# Wrong shape
self.assertFalse(torch._has_compatible_shallow_copy_type(x1_fp8, x3_fp8))

def test_transpose(self):
x = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
x_fp8 = Float8Tensor.from_hp(x)
x_fp8_t = x_fp8.t()
torch.testing.assert_close(x_fp8_t.qdata, x_fp8.qdata.t(), atol=0, rtol=0)
torch.testing.assert_close(x_fp8_t.scale, x_fp8.scale.t(), atol=0, rtol=0)
self.assertEqual(x_fp8.block_size, (1, 512), atol=0, rtol=0)
self.assertEqual(x_fp8_t.block_size, (512, 1), atol=0, rtol=0)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

76 changes: 66 additions & 10 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -254,21 +254,59 @@ def from_hp(
implements_torch_function = Float8Tensor.implements_torch_function


@implements([aten.linear.default])
@implements_torch_function([torch.nn.functional.linear])
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
return _float8_mm_impl(input_tensor, weight_tensor.t(), bias)


@implements(aten.matmul.default)
@implements_torch_function(torch.matmul)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = args[0], args[1]
return _float8_mm_impl(input_tensor, weight_tensor)


@implements(aten.mm.default)
@implements_torch_function(torch.mm)
def _(func, types, args, kwargs):
input_tensor, weight_tensor = args[0], args[1]
return _float8_mm_impl(input_tensor, weight_tensor)

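# Orientation convention for the overrides above: F.linear(x, w) computes
# x @ w.t(), so the linear override pre-transposes the weight before calling
# the shared impl, while matmul/mm pass the weight through as-is. A sketch,
# assuming weight-only quantization via from_hp defaults:
_a = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
_w = Float8Tensor.from_hp(torch.randn(128, 256, dtype=torch.bfloat16, device="cuda"))
_y1 = torch.nn.functional.linear(_a, _w)  # impl sees _w.t(), shape (256, 128)
_y2 = torch.matmul(_a, _w.t())            # same orientation through the matmul path
torch.testing.assert_close(_y1, _y2)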

@implements(aten.addmm_.default)
def _(func, types, args, kwargs):
bias_tensor, input_tensor, weight_tensor = (
args[0],
args[1],
args[2],
)
assert kwargs.get("alpha", 1) == 1, "only alpha=1 is supported"
assert kwargs.get("beta", 1) == 1, "only beta=1 is supported"
out = _float8_mm_impl(input_tensor, weight_tensor)
return bias_tensor.add_(out)

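# The in-place addmm_ override above reduces to bias.add_(x @ w) once
# alpha == beta == 1. Illustrative equivalence, with shapes assumed:
_bias = torch.zeros(32, 128, dtype=torch.bfloat16, device="cuda")
_inp = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
_wq = Float8Tensor.from_hp(torch.randn(256, 128, dtype=torch.bfloat16, device="cuda"))
_bias.addmm_(_inp, _wq)  # same as _bias.add_(torch.mm(_inp, _wq)) under this override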

def _float8_mm_impl(
input_tensor: Float8Tensor,
weight_tensor: Float8Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert isinstance(weight_tensor, Float8Tensor), (
f"Expected only the weight to be a Float8Tensor here, got {type(input_tensor)} {type(weight_tensor)}"
)

act_quant_kwargs = weight_tensor.act_quant_kwargs
# quantize activation, if `act_quant_kwargs` is specified
if act_quant_kwargs is not None:
assert not isinstance(input_tensor, TorchAOBaseTensor), (
"input tensor was already quantized"
)
input_tensor = _choose_quant_func_and_quantize_tensor(
input_tensor, act_quant_kwargs
)
@@ -300,6 +338,7 @@ def _(func, types, args, kwargs):
mm_config = weight_tensor.mm_config
assert mm_config is not None
assert not _is_128_128_scaled(weight_tensor), "unimplemented"
weight_tensor = weight_tensor.t()

out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -334,28 +373,25 @@ def _(func, types, args, kwargs):
assert kernel_choice == "torch"
scaled_mm_config = weight_tensor.mm_config
assert scaled_mm_config is not None
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])

# Extract tensor data and scales
inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
w_data = weight_tensor.qdata
input_scale = input_tensor.scale
w_scale = weight_tensor.scale

# Handle rowwise scaling
if _is_rowwise_scaled(weight_tensor):
assert _is_rowwise_scaled(input_tensor), (
"input_tensor must be rowwise scaled"
)
w_scale = w_scale.transpose(-1, -2)
elif _is_128_128_scaled(weight_tensor):
assert _is_1_128_scaled(input_tensor), (
"input_tensor must be 1x128 scaled"
)
w_scale = w_scale.transpose(-1, -2)

input_scale = preprocess_scale(input_scale, input_tensor.shape)
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)

if _is_128_128_scaled(weight_tensor):
# TODO(future PR): add testing for torch._scaled_mm with
@@ -389,9 +425,11 @@ def _(func, types, args, kwargs):
)
# when the input is not a `Float8Tensor`, we expect it to be unquantized,
# so this is float8 weight-only quantization
return torch.nn.functional.linear(
input_tensor, weight_tensor.dequantize(), bias
)
out = torch.matmul(input_tensor, weight_tensor.dequantize())
if bias is not None:
return out + bias
else:
return out


@implements_torch_function(torch.bmm)
@@ -709,6 +747,7 @@ def _(func, types, args, kwargs):
assert original_shape[-1] == size[-1], (
f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
)
# TODO: this seems wrong, we should merge the first two dimensions instead
qdata = self.qdata.reshape(*size)
scale = self.scale.reshape(*size)
block_size = self.block_size.copy()
@@ -817,6 +856,23 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new)


@implements(aten.t.default)
def _(func, types, args, kwargs):
assert len(args) == 1
self = args[0]
assert len(self.block_size) == 2
new_tensor = self.__class__(
self.qdata.t(),
self.scale.t(),
(self.block_size[1], self.block_size[0]),
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new_tensor)

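# Transpose sketch: t() swaps qdata and scale and reverses block_size
# without requantizing, so a per-row (1, 512)-blocked tensor becomes
# (512, 1)-blocked (see test_transpose in the test file above).
_src = Float8Tensor.from_hp(torch.randn(128, 512, dtype=torch.bfloat16, device="cuda"))
_dst = _src.t()
assert tuple(_dst.block_size) == (512, 1)
assert _dst.qdata.shape == (512, 128) and _dst.scale.shape == (1, 128)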

Float8Tensor.__module__ = "torchao.quantization"

# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`