diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py index 14910191a6..d054d37090 100644 --- a/aiter/tuned_gemm.py +++ b/aiter/tuned_gemm.py @@ -27,7 +27,6 @@ from aiter import ( dtypes, gemm_a16w16_asm, - getHipblasltKernelName, hipb_create_extension, hipb_mm, logger, @@ -130,7 +129,10 @@ def gen_gemm_a16w16_fake_tensor( scale_c: Optional[Tensor] = None, ) -> Tensor: out = torch.empty( - A.view(-1, A.size(-1)).shape[0], B.shape[0], dtype=A.dtype, device=A.device + A.view(-1, A.size(-1)).shape[0], + B.shape[0], + dtype=otype or A.dtype, + device=A.device, ) return out.view(*A.shape[:-1], B.shape[0])