Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import TorchDispatchMode

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
Expand Down Expand Up @@ -363,6 +364,26 @@ def get_cache_scale(self, name: str) -> str | None:
return None


class CopyNumelCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`. Useful for keeping
    track of weight loading where underlying weights can be arbitrarily
    transformed (such as with `narrow`) before calling copy.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements written by aten.copy_ calls
        # observed while this dispatch mode is active.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Execute the op normally; this mode only observes, never alters.
        result = func(*args, **kwargs)
        # Only in-place copies contribute; the destination tensor is args[0].
        if func == torch.ops.aten.copy_.default:
            destination = args[0]
            self.copied_numel += destination.numel()
        return result


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Expand Down Expand Up @@ -469,13 +490,15 @@ def create_weights(
else:

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]

# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
Expand Down Expand Up @@ -1348,13 +1371,15 @@ def create_weights(
new_extra_weight_attrs = extra_weight_attrs

def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]

# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()

# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel

# if we have loaded all of the elements, call
# process_weights_after_loading
Expand Down