93 changes: 1 addition & 92 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -5,7 +5,6 @@
 
 import torch
 from torch.nn import Module
-from torch.utils._python_dispatch import TorchDispatchMode
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -275,26 +274,6 @@ def get_cache_scale(self, name: str) -> str | None:
         return None
 
 
-class CopyNumelCounter(TorchDispatchMode):
-    """
-    Tracks total number of elements modified with `copy_`. Useful for keeping
-    track of weight loading where underlying weights can be arbitrarily
-    transformed (such as with `narrow`) before calling copy.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.copied_numel = 0
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        out = func(*args, **kwargs)
-        if func == torch.ops.aten.copy_.default:
-            self.copied_numel += args[0].numel()
-        return out
-
-
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
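For reference, a minimal usage sketch of the CopyNumelCounter removed above; the tensor names and shapes are illustrative, not from the PR. The point of dispatching on aten.copy_ is that the write is counted even when the destination was first transformed, e.g. with narrow():

    import torch

    # Assumes the CopyNumelCounter class shown in the removed lines above.
    param = torch.zeros(4, 8)
    shard = torch.ones(4, 4)

    counter = CopyNumelCounter()
    with counter:
        # copy_ into a narrow()'d view still dispatches aten.copy_.default,
        # so all 4 * 4 = 16 written elements are counted.
        param.narrow(1, 0, 4).copy_(shard)

    assert counter.copied_numel == 16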
@@ -577,31 +556,6 @@ def create_weights(
         layer.weight_block_size = None
 
         # WEIGHT
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Delete the bookkeeping
-                del layer._loaded_numel
-                # Prevent the usual `process_weights_after_loading` call from doing
-                # anything
-                layer._already_called_process_weights_after_loading = True
-
-            return res
-
         weight = ModelWeightParameter(
             data=torch.empty(
                 output_size_per_partition,
@@ -610,14 +564,11 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
             ),
             input_dim=1,
             output_dim=0,
-            weight_loader=patched_weight_loader,
+            weight_loader=weight_loader,
         )
         layer.register_parameter("weight", weight)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # TODO(future): support block_quant in online quant path
         assert not self.block_quant
 
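The deletions above drop the streaming post-load hook from the linear path. Condensed into a reusable helper, the removed logic amounts to roughly the sketch below; the name make_streaming_loader and the target_numel/on_complete parameters are illustrative, not vLLM API:

    def make_streaming_loader(layer, weight_loader, target_numel, on_complete):
        """Wrap `weight_loader` so `on_complete(layer)` fires once
        `target_numel(layer)` elements have been written via copy_."""

        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Running per-layer count of elements written so far.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Count the elements copied while loading this shard/chunk.
            counter = CopyNumelCounter()
            with counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)
            layer._loaded_numel += counter.copied_numel

            # Last chunk seen: run the hook once and drop the bookkeeping.
            if layer._loaded_numel == target_numel(layer):
                on_complete(layer)
                del layer._loaded_numel
            return res

        return patched_weight_loader

For the linear path the completion target is simply layer.weight.numel(); the removed code also set a flag so that the later, regular process_weights_after_loading call became a no-op.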
@@ -853,9 +804,6 @@ def _setup_kernel(
         )
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # Allow for accessing weights and scales in standard way.
         w13 = layer.w13_weight
         w2 = layer.w2_weight
@@ -1132,42 +1080,6 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
-        # We are doing online quantization, patch the weight loaded
-        # to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded.
-        weight_loader = extra_weight_attrs["weight_loader"]
-        # create a new holder to prevent modifying behavior of any other
-        # objects which might depend on the old one
-        new_extra_weight_attrs = extra_weight_attrs
-
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # add a counter to track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Delete the bookkeeping
-                del layer._loaded_numel
-                # Prevent the usual `process_weights_after_loading` call
-                # from doing anything
-                layer._already_called_process_weights_after_loading = True
-
-            return res
-
-        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
-        extra_weight_attrs = new_extra_weight_attrs
-
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
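The MoE block deleted above installed the same wrapper by swapping the loader held in extra_weight_attrs before the expert weights were created. In terms of the illustrative helper sketched earlier, the hookup is roughly:

    # Sketch only: for the MoE layer, "fully loaded" means both expert
    # weight tensors have been written in full.
    extra_weight_attrs["weight_loader"] = make_streaming_loader(
        layer,
        extra_weight_attrs["weight_loader"],
        target_numel=lambda l: l.w13_weight.numel() + l.w2_weight.numel(),
        on_complete=self.process_weights_after_loading,
    )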
@@ -1211,9 +1123,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)