
Commit 460e6db

rename splitk code to not mention float8, try 2
Differential Revision: D59977582 Pull Request resolved: pytorch#529
1 parent 66b5213 commit 460e6db

File tree: 3 files changed, +2 -2 lines changed


test/dtypes/test_fp8.py renamed to test/prototype/test_splitk.py

+2 -2
@@ -10,7 +10,7 @@
 from torchao.utils import TORCH_VERSION_AFTER_2_4
 
 try:
-    from torchao.prototype.fp8 import gemm_split_k, to_float8
+    from torchao.prototype.splitk import gemm_split_k, to_float8
     triton_available = True
 except ImportError:
     triton_available = False
@@ -38,7 +38,7 @@ def test_gemm_split_k(self):
         x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
         w_fp8, w_inv_s = to_float8(w, dtype=qdtype)
 
-        y_torch, _ = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
+        y_torch = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)
         y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())
         y_fp16 = torch.nn.functional.linear(x, w)
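For orientation, here is a minimal sketch of how the renamed import and the updated torch._scaled_mm call from the diff fit together. The tensor shapes, device, and dtype choices (torch.float16 activations, torch.float8_e4m3fn quantization) are illustrative assumptions, not taken from this commit.

    # Sketch only; shapes, device, and dtypes are assumptions for illustration.
    import torch
    from torchao.prototype.splitk import gemm_split_k, to_float8  # renamed from torchao.prototype.fp8

    dtype = torch.float16          # compute/output dtype (assumed)
    qdtype = torch.float8_e4m3fn   # FP8 quantization dtype (assumed)

    x = torch.randn(64, 64, dtype=dtype, device="cuda")
    w = torch.randn(64, 64, dtype=dtype, device="cuda")

    # Quantize activations and weights to FP8, returning the FP8 tensor and its inverse scale.
    x_fp8, x_inv_s = to_float8(x, dtype=qdtype)
    w_fp8, w_inv_s = to_float8(w, dtype=qdtype)

    # Matches the updated line in the diff: the result is used directly, with no tuple unpacking.
    y_torch = torch._scaled_mm(x_fp8, w_fp8.t(), out_dtype=dtype, scale_a=x_inv_s, scale_b=w_inv_s)

    # Triton split-K GEMM from the renamed prototype module.
    y_triton = gemm_split_k(x_fp8, w_fp8.t(), scale_a=x_inv_s.item(), scale_b=w_inv_s.item())

    # fp16 reference for comparison.
    y_fp16 = torch.nn.functional.linear(x, w)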
