diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index d2d9f9c816c7..7fd986f9ea5e 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -84,12 +84,13 @@ def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype): def init_tensor(dtype, shape): if dtype == torch.int8: - return torch.randint(0, 3, shape, device=device, dtype=dtype) + return torch.randint(0, 2, shape, device=device, dtype=dtype) elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: return torch.randn(shape, device=device, dtype=torch.float16).to(dtype) else: return torch.randn(shape, device=device, dtype=dtype) + torch.manual_seed(42) a = init_tensor(w_dtype, (M, K)) b = init_tensor(x_dtype, (K, N))