Skip to content

Commit 600ca1f

Browse files
committed
fix
1 parent 2afa91a commit 600ca1f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

examples/cast/example_group_per_split_token_cast_to_fp8.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
161161
return x_fp8
162162

163163

164-
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=[2048, 6144]):
164+
def main(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None):
165+
if batch_sizes is None:
166+
batch_sizes = [2048, 6144]
165167
if dtype == "float":
166168
x = torch.randn(M, N, device="cuda", dtype=torch.float32)
167169
elif dtype == "float16":

0 commit comments

Comments
 (0)