Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
Expand All @@ -186,7 +186,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)

vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales)
vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
Comment thread
tlrmchlsmth marked this conversation as resolved.

return out

Expand Down
66 changes: 50 additions & 16 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,29 +229,63 @@ def process_weights_after_loading(self, layer: Module) -> None:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")

def cutlass_fp8_supported(self) -> bool:
    """Return whether the CUTLASS FP8 kernels can be used on this system.

    CUTLASS FP8 kernels need at least
      CUDA 12.0 on SM90 systems (Hopper)
      CUDA 12.4 on SM89 systems (Lovelace)

    Returns:
        True if the current device's compute capability and the CUDA
        version reported by ``torch.version.cuda`` meet the minimums
        above; False otherwise, including when ``torch.version.cuda``
        is None.
    """
    cuda_version_str = torch.version.cuda
    if cuda_version_str is None:
        # No CUDA toolkit version is reported, so the CUDA-only CUTLASS
        # kernels cannot be used.
        return False

    # torch.version.cuda is a string such as "12.4". Parse it into an
    # integer tuple so comparisons are numeric rather than string-based.
    version_parts = cuda_version_str.split(".")
    cuda_major = int(version_parts[0])
    cuda_minor = int(version_parts[1]) if len(version_parts) > 1 else 0
    cuda_version = (cuda_major, cuda_minor)

    # (major, minor) compute capability, e.g. (9, 0) for SM90.
    capability = tuple(torch.cuda.get_device_capability())

    if capability >= (9, 0):
        return cuda_version >= (12, 0)
    if capability >= (8, 9):
        return cuda_version >= (12, 4)
    return False


def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Quantize ``x`` to FP8 and run the fused GEMM + dequantize.

    The CUTLASS kernel is used when no bias is given and
    ``self.cutlass_fp8_supported()`` is true; otherwise the call falls
    back to ``torch._scaled_mm``.

    Args:
        layer: Module providing ``weight``, ``weight_scale`` and
            ``act_scale``.
        x: Input activations; the output uses ``x.dtype``.
        bias: Optional bias, only handled by the ``torch._scaled_mm``
            path.

    Returns:
        The GEMM output narrowed along dim 0 to ``x.shape[0]`` rows.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    #   If dynamic, layer.act_scale is None and x_scale computed from x.
    #   If static, layer.act_scale is scalar and x_scale set to act_scale.

    if bias is None and self.cutlass_fp8_supported():
        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)

        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm_dq(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scale,
        )

    else:
        qinput, x_scale = ops.scaled_fp8_quant(x,
                                               layer.act_scale,
                                               batch_dim_padding=17)

        # Fused GEMM_DQ -- note we padded the input above because
        # torch._scaled_mm is more performant for matrices with
        # batch dimension > 16. Note that this could change
        # in the future.
        output, _ = torch._scaled_mm(
            qinput,
            layer.weight,
            out_dtype=x.dtype,
            scale_a=x_scale,
            scale_b=layer.weight_scale,
            bias=bias,
        )

    # Drop any rows added by batch_dim_padding so the result has the
    # same number of rows as x.
    return torch.narrow(output, 0, 0, x.shape[0])

Expand Down