Merged
4 changes: 2 additions & 2 deletions vllm/_custom_ops.py
@@ -177,7 +177,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
@@ -186,7 +186,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)

vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)

return out

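A minimal, hypothetical call-site sketch of the renamed wrapper, assuming a CUDA build of vLLM with the CUTLASS kernels compiled in. The shapes, the FP8 dtype, and the column-major layout for `b` are illustrative assumptions, not requirements stated in this hunk:

```python
import torch
from vllm import _custom_ops as ops

# Illustrative inputs: both dims of b are multiples of 16 and out_dtype is
# float16, per the asserts in cutlass_scaled_mm_dq above. The FP8 dtype and
# column-major b are assumptions about the underlying kernel, not shown here.
a = torch.rand(16, 256, device="cuda").to(torch.float8_e4m3fn)
b = torch.rand(512, 256, device="cuda").to(torch.float8_e4m3fn).t()  # (256, 512)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

out = ops.cutlass_scaled_mm_dq(a, b, scale_a=scale_a, scale_b=scale_b,
                               out_dtype=torch.float16)
print(out.shape)  # torch.Size([16, 512])
```

The keyword spelling `scale_a=` / `scale_b=` is what this rename is for: call sites such as `Fp8LinearMethod.apply` below pass the scales by keyword, so the wrapper's parameter names must match.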
66 changes: 50 additions & 16 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -17,6 +17,24 @@
logger = init_logger(__name__)


def cutlass_fp8_supported() -> bool:
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
    major, minor = torch.version.cuda.split(".")
    version = int(major) * 10 + int(minor)

# CUTLASS FP8 kernels need at least
# CUDA 12.0 on SM90 systems (Hopper)
# CUDA 12.4 on SM89 systems (Lovelace)
gpu_is_supported = False
    if capability >= 90:
        gpu_is_supported = version >= 120
    elif capability >= 89:
        gpu_is_supported = version >= 124

return gpu_is_supported

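The two-digit encoding above packs both the compute capability and the CUDA version into comparable integers. A small illustrative sketch (the literal values below are examples, not taken from the PR):

```python
# Compute capability (major, minor) -> major * 10 + minor
#   (9, 0) -> 90  (SM90, Hopper)
#   (8, 9) -> 89  (SM89, Ada Lovelace)
cc = (8, 9)
encoded_cc = cc[0] * 10 + cc[1]              # 89

# CUDA version string "major.minor" -> major * 10 + minor
major, minor = "12.4".split(".")
encoded_cuda = int(major) * 10 + int(minor)  # 124

# An SM89 GPU on CUDA 12.4 therefore satisfies the Lovelace branch above.
```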

class Fp8Config(QuantizationConfig):
"""Config class for FP8."""

@@ -92,6 +110,7 @@ class Fp8LinearMethod(LinearMethodBase):

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()

def _create_scale_param(
self,
@@ -233,25 +252,40 @@ def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.act_scale,
batch_dim_padding=17)

# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)

if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)

# Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)

else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.act_scale,
batch_dim_padding=17)

# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)

return torch.narrow(output, 0, 0, x.shape[0])

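The comments in `apply` describe two behaviors worth seeing in isolation: `ops.scaled_fp8_quant` works in dynamic or static mode depending on whether an activation scale is passed, and the fallback `torch._scaled_mm` path pads the batch dimension and then trims it back with `torch.narrow`. A hedged sketch of just those pieces, assuming a CUDA build of vLLM with the FP8 quant op available; the tensor shapes and the scale value are made up:

```python
import torch
from vllm import _custom_ops as ops

x = torch.randn(3, 256, device="cuda", dtype=torch.float16)

# Dynamic quantization: no scale supplied, so the op computes x_scale from x.
qx, x_scale = ops.scaled_fp8_quant(x, None)

# Static quantization: a precomputed scalar scale (layer.act_scale) is reused.
act_scale = torch.tensor(0.02, device="cuda", dtype=torch.float32)  # hypothetical value
qx_static, x_scale_static = ops.scaled_fp8_quant(x, act_scale)

# Fallback path: pad the batch dim up to at least 17 rows, since torch._scaled_mm
# is more performant for batch > 16, then trim the padded rows off the output.
qx_padded, _ = ops.scaled_fp8_quant(x, None, batch_dim_padding=17)
print(qx_padded.shape[0])  # 17 (padded up from 3)
# ... output = torch._scaled_mm(qx_padded, ...) as in the else-branch above ...
# output then has 17 rows; narrow back to the real batch size:
# output = torch.narrow(output, 0, 0, x.shape[0])
```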