Skip to content

Commit 4d1c63d

Browse files
committed
update
1 parent 87ee201 commit 4d1c63d

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

examples/cast/example_group_per_split_token_cast_to_fp8.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
161161
return x_fp8
162162

163163

164-
def main(M=8192, N=8192, BG=2, blk_m=8):
164+
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=[2048, 6144]):
165165
if dtype == "float":
166166
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
167167
elif dtype == "float16":
@@ -170,7 +170,7 @@ def main(M=8192, N=8192, BG=2, blk_m=8):
170170
x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
171171
else:
172172
raise ValueError(f"Unsupported dtype: {dtype}")
173-
batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
173+
batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32)
174174
M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
175175

176176
print("batch_sizes:", batch_sizes)

examples/cast/test_example_cast.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55

66
def test_example_group_per_split_token_cast_to_fp8():
7-
example_group_per_split_token_cast_to_fp8.main(M=8192, N=2048, BG=1, blk_m=4)
7+
example_group_per_split_token_cast_to_fp8.main(M=1024, N=1024, BG=2, blk_m=4, batch_sizes=[128, 896])
88

99

1010
def test_example_per_token_cast_to_fp8():

0 commit comments

Comments
 (0)