Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Aug 5, 2024
1 parent ed83ae2 commit c6adfcb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test/prototype/test_quant_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
fp6_llm_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
Expand Down Expand Up @@ -98,6 +99,20 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_fp6_llm_quantize(self):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, device=device)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fp6_llm_weight_only())

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestQuantLlmLinearWeight)

Expand Down
2 changes: 1 addition & 1 deletion torchao/prototype/quant_llm/quant_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,4 +445,4 @@ def apply_quant_llm(weight: Tensor) -> Tensor:


def fp6_llm_weight_only():
return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
return quant_llm_fpx_weight_only(3, 2)

0 comments on commit c6adfcb

Please sign in to comment.