Merged
vllm/model_executor/layers/quantization/utils/w8a8_utils.py (40 changes: 33 additions & 7 deletions)
@@ -12,6 +12,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -152,13 +153,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
    return output.view(*output_shape)


def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                      weight: torch.Tensor,
                                      out_dtype: torch.dtype,
                                      scale_a: torch.Tensor,
                                      scale_b: torch.Tensor, bias: torch.Tensor,
                                      input_2d: torch.Tensor,
                                      output_shape: list) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
        input_2d: torch.Tensor) -> torch.Tensor:
    from vllm.platforms.rocm import on_mi3xx
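    # Skinny-GEMM fast path: taken only for a single-row input (M == 1)
    # whose inner dimension is a multiple of 16, on an MI3xx GPU.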
    if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
    ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@@ -171,10 +169,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias)
    return output


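# Fake (meta) implementation: returns an empty tensor with the op's output
# shape and dtype so torch.compile can trace the op without running a kernel.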
def rocm_per_tensor_w8a8_scaled_mm_fake(
        qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
        scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
        input_2d: torch.Tensor) -> torch.Tensor:
    return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
                            dtype=out_dtype)


def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                      weight: torch.Tensor,
                                      out_dtype: torch.dtype,
                                      scale_a: torch.Tensor,
                                      scale_b: torch.Tensor, bias: torch.Tensor,
                                      input_2d: torch.Tensor,
                                      output_shape: list) -> torch.Tensor:
    output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
        qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
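    # Narrow to the logical row count before reshaping, in case the kernel
    # wrote more rows than the 2D input had.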
    return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)


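# Register the impl under torch.ops.vllm so compiled graphs treat it as an
# opaque custom op and use the fake impl for shape inference.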
direct_register_custom_op(
    op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
    op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
    mutates_args=[],
    fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
    dispatch_key=current_platform.dispatch_key,
)


def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                       weight: torch.Tensor,
                                       out_dtype: torch.dtype,
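The diff follows a register-then-wrap pattern: a plain impl function, a shape-only fake, registration via `direct_register_custom_op`, and a thin wrapper that dispatches through `torch.ops.vllm`. Below is a minimal sketch of the same pattern applied to a toy op; `my_toy_scale` and its functions are hypothetical names for illustration, and only `direct_register_custom_op` with the keyword arguments shown above is taken from vLLM itself:

```python
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def my_toy_scale_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Eager implementation: runs the real computation.
    return x * scale


def my_toy_scale_fake(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only describes output shape/dtype; no kernel runs.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="my_toy_scale",
    op_func=my_toy_scale_impl,
    mutates_args=[],
    fake_impl=my_toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,
)


def my_toy_scale(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Thin wrapper: callers go through the registered op so torch.compile
    # captures it as a single opaque node.
    return torch.ops.vllm.my_toy_scale(x, scale)
```

Keeping the `torch.narrow(...).view(...)` epilogue in the Python wrapper rather than in the registered op mirrors the diff's design: the traced op has a simple, statically describable output, and the shape fix-up stays in eager code.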