67 changes: 66 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -461,6 +461,30 @@ def create_weights(
                 output_size_per_partition, input_size_per_partition, weight_loader
             )
         else:
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # track how many elements we have updated
+                if not hasattr(layer, "_loaded_numel"):
+                    layer._loaded_numel = 0
+                layer._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                target_loaded_numel = layer.weight.numel()
+                if layer._loaded_numel == target_loaded_numel:
+                    self.process_weights_after_loading(layer)
+
+                    # delete the bookkeeping
+                    del layer._loaded_numel
+                    # prevent the usual `process_weights_after_loading` call
+                    # from doing anything
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
             # For non-serialized checkpoints, use original dtype
             weight = ModelWeightParameter(
                 data=torch.empty(
@@ -470,7 +494,7 @@ def create_weights(
                 ),
                 input_dim=1,
                 output_dim=0,
-                weight_loader=weight_loader,
+                weight_loader=patched_weight_loader,
             )
         layer.register_parameter("weight", weight)
 
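Reviewer note: the wrapper pattern in the hunk above can be hard to follow inside a diff, so here is a minimal, self-contained sketch of the same idea, runnable on CPU. All names here (`FakeLayer`, `base_loader`, `make_streaming_loader`, `finalize`) are illustrative stand-ins, not symbols from this PR.

```python
import torch


class FakeLayer:
    """Illustrative stand-in for an nn.Module with a sharded weight."""

    def __init__(self):
        self.weight = torch.empty(4, 4)


def base_loader(param, loaded_weight, shard_offset):
    # stand-in for vLLM's real weight_loader: copy one row-chunk into place
    param.narrow(0, shard_offset, loaded_weight.shape[0]).copy_(loaded_weight)


def make_streaming_loader(layer, loader, finalize):
    """Wrap `loader` so `finalize` fires once the whole weight has arrived."""

    def patched(param, loaded_weight, *args, **kwargs):
        res = loader(param, loaded_weight, *args, **kwargs)
        layer._loaded_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
        if layer._loaded_numel == layer.weight.numel():
            finalize(layer)  # runs as soon as the last chunk lands
            del layer._loaded_numel
            layer._already_called_process_weights_after_loading = True
        return res

    return patched


layer = FakeLayer()
patched = make_streaming_loader(layer, base_loader, lambda l: print("finalized"))
patched(layer.weight, torch.randn(2, 4), 0)  # 8 of 16 elements: nothing yet
patched(layer.weight, torch.randn(2, 4), 2)  # 16 of 16: prints "finalized"
```

The flag set on the last chunk is one half of a handshake: the guard added to `process_weights_after_loading` in the next hunk checks it, so the loader's usual post-load pass becomes a no-op for layers already finalized here.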
@@ -511,6 +535,9 @@ def create_weights(
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         size_k_first = True
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
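This guard is the other half of the handshake: the model loader still calls `process_weights_after_loading` for every layer once loading finishes, and the early return makes that second call harmless. A tiny sketch of the idempotence property (class names are illustrative, not PR code):

```python
class TinyMethod:
    def process_weights_after_loading(self, layer):
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        layer.calls = getattr(layer, "calls", 0) + 1


class TinyLayer:
    pass


layer = TinyLayer()
method = TinyMethod()
method.process_weights_after_loading(layer)  # first call does the work
# in the PR, the patched weight loader sets this flag, not the method itself
layer._already_called_process_weights_after_loading = True
method.process_weights_after_loading(layer)  # second call returns early
assert layer.calls == 1
```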
@@ -734,6 +761,41 @@ def create_weights(
                     f"weight quantization block_k = {block_k}."
                 )
 
+        # if we are doing online quantization, patch the weight loader to
+        # call `process_weights_after_loading` in a streaming fashion as
+        # soon as the last weight chunk is loaded
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            weight_loader = extra_weight_attrs["weight_loader"]
+            # create a new holder to prevent modifying behavior of any other
+            # objects which might depend on the old one
+            new_extra_weight_attrs = dict(extra_weight_attrs)
+
+            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
+                # load the current weight chunk
+                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
+
+                # add a counter to track how many elements we have updated
+                if not hasattr(layer, "_loaded_numel"):
+                    layer._loaded_numel = 0
+                layer._loaded_numel += loaded_weight.numel()
+
+                # if we have loaded all of the elements, call
+                # process_weights_after_loading
+                target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
+                if layer._loaded_numel == target_loaded_numel:
+                    self.process_weights_after_loading(layer)
+
+                    # delete the bookkeeping
+                    del layer._loaded_numel
+                    # prevent the usual `process_weights_after_loading` call
+                    # from doing anything
+                    layer._already_called_process_weights_after_loading = True
+
+                return res
+
+            new_extra_weight_attrs["weight_loader"] = patched_weight_loader
+            extra_weight_attrs = new_extra_weight_attrs
+
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
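Same pattern for the fused MoE path, except completion is measured against two parameters, `w13_weight` and `w2_weight`, whose expert tensors arrive interleaved. A small arithmetic sketch with made-up shapes (2 experts, hidden size 4, intermediate size 3) showing the count only reaches the target on the final chunk:

```python
import torch

# made-up shapes: 2 experts, hidden=4, intermediate=3
w13 = torch.empty(2, 2 * 3, 4)  # stand-in for layer.w13_weight (gate+up fused)
w2 = torch.empty(2, 4, 3)       # stand-in for layer.w2_weight (down proj)
target = w13.numel() + w2.numel()  # 48 + 24 = 72

loaded = 0
# checkpoints stream one (expert, projection) tensor at a time
for chunk in (w13[0], w2[0], w13[1], w2[1]):
    loaded += chunk.numel()
    print(loaded, loaded == target)
# 24 False / 36 False / 60 False / 72 True -> finalize fires on the last chunk
```

The element count lines up with the target only if every chunk is loaded exactly once, which is the assumption this bookkeeping relies on.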
@@ -835,6 +897,9 @@ def create_weights(
         self.rocm_aiter_moe_enabled = False
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
+
         # Lazy import to avoid importing triton too early.
 
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()