diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 659029fd37f7..36d16960ec57 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale @@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a=scale_a, scale_b=scale_b, bias=bias) + return output + + +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), + dtype=out_dtype) + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) +direct_register_custom_op( + op_name="rocm_per_tensor_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_w8a8_scaled_mm_impl, + mutates_args=[], + fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, + dispatch_key=current_platform.dispatch_key, +) + + def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,