@@ -9,8 +9,8 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -93,6 +93,8 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
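Each scheme's create_weights now starts with a call to the new helper instead of relying on a tensor allocated at module import time. As a rough sketch of the pattern (the names mirror the w8a8_utils.py hunks further down in this diff), the helper lazily allocates the dummy scale on first use and is a no-op afterwards:

    import torch

    TORCH_DEVICE_IDENTITY = None

    def maybe_create_device_identity():
        # Allocate the dummy ones(1) scale for torch._scaled_mm only when a
        # quantized layer is actually being built, not at import time.
        global TORCH_DEVICE_IDENTITY
        if TORCH_DEVICE_IDENTITY is None:
            TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

Because the helper is idempotent, every create_weights implementation can call it unconditionally at negligible cost.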
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -17,7 +17,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
+    apply_fp8_linear, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -84,6 +85,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
         weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -24,8 +24,8 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
-    requantize_with_max_scale)
+    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -162,6 +162,8 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
 
14 changes: 8 additions & 6 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -9,7 +9,7 @@
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 # The condition to determine if it is on a platform that supports
 # torch._scaled_mm rowwise feature.
@@ -113,6 +113,13 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -215,11 +222,6 @@ def apply_fp8_linear(
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
-        # Making sure the dummy tensor is on the same device as the weight
-        global TORCH_DEVICE_IDENTITY
-        if TORCH_DEVICE_IDENTITY.device != weight.device:
-            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
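For context on why the dummy scale exists at all: the comment in w8a8_utils.py notes that the input scales of torch._scaled_mm stopped being optional in PyTorch 2.5, so the scaled_mm fallback path still has to hand the op a 1-element fp32 tensor even when it only has per-channel weight scales. The snippet below is an illustration of that call shape only, assuming a CUDA device with FP8 support and the torch 2.5 calling convention; torch._scaled_mm is a private API, so the exact signature may differ across versions:

    import torch

    M, K, N = 16, 32, 64                                               # dims must be multiples of 16
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)       # row-major activations
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()   # column-major weights

    # The scales are mandatory; a path with no per-tensor input scale passes a
    # dummy ones(1) tensor here -- the role played by TORCH_DEVICE_IDENTITY.
    identity = torch.ones(1, dtype=torch.float32, device="cuda")
    out = torch._scaled_mm(a, b, scale_a=identity, scale_b=identity,
                           out_dtype=torch.bfloat16)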