diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5adcd09b01f9..8ef5d59f8a6d 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -5,7 +5,6 @@

 import torch
 from torch.nn import Module
-from torch.utils._python_dispatch import TorchDispatchMode

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -275,26 +274,6 @@ def get_cache_scale(self, name: str) -> str | None:
         return None


-class CopyNumelCounter(TorchDispatchMode):
-    """
-    Tracks total number of elements modified with `copy_`. Useful for keeping
-    track of weight loading where underlying weights can be arbitrarily
-    transformed (such as with `narrow`) before calling copy.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.copied_numel = 0
-
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        out = func(*args, **kwargs)
-        if func == torch.ops.aten.copy_.default:
-            self.copied_numel += args[0].numel()
-        return out
-
-
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
@@ -577,31 +556,6 @@ def create_weights(
             layer.weight_block_size = None

         # WEIGHT
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Delete the bookkeeping
-                del layer._loaded_numel
-                # Prevent the usual `process_weights_after_loading` call from doing
-                # anything
-                layer._already_called_process_weights_after_loading = True
-
-            return res
-
         weight = ModelWeightParameter(
             data=torch.empty(
                 output_size_per_partition,
@@ -610,14 +564,11 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
             ),
             input_dim=1,
             output_dim=0,
-            weight_loader=patched_weight_loader,
+            weight_loader=weight_loader,
         )
         layer.register_parameter("weight", weight)

     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # TODO(future): support block_quant in online quant path
         assert not self.block_quant

@@ -853,9 +804,6 @@ def _setup_kernel(
         )

     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # Allow for accessing weights and scales in standard way.
         w13 = layer.w13_weight
         w2 = layer.w2_weight
@@ -1132,42 +1080,6 @@ def create_weights(
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None

-        # We are doing online quantization, patch the weight loaded
-        # to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded.
-        weight_loader = extra_weight_attrs["weight_loader"]
-        # create a new holder to prevent modifying behavior of any other
-        # objects which might depend on the old one
-        new_extra_weight_attrs = extra_weight_attrs
-
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # add a counter to track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Delete the bookkeeping
-                del layer._loaded_numel
-                # Prevent the usual `process_weights_after_loading` call
-                # from doing anything
-                layer._already_called_process_weights_after_loading = True
-
-            return res
-
-        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
-        extra_weight_attrs = new_extra_weight_attrs
-
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -1211,9 +1123,6 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
         # If checkpoint is fp16, quantize in place.
         fp8_dtype = current_platform.fp8_dtype()
         w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
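For reference, a minimal standalone sketch of the copy-counting idea behind the `CopyNumelCounter` removed above: a `TorchDispatchMode` intercepts `aten.copy_` and accumulates the number of destination elements written, even when the parameter is first transformed (e.g. with `narrow`) before the copy. The `param` and `shard` tensors here are purely illustrative, not part of vLLM.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class CopyNumelCounter(TorchDispatchMode):
    """Counts elements written via `aten.copy_` while the mode is active."""

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if func == torch.ops.aten.copy_.default:
            # args[0] is the destination tensor (possibly a view of the param)
            self.copied_numel += args[0].numel()
        return out


# Illustrative stand-ins for a sharded parameter and one loaded weight chunk.
param = torch.empty(4, 8)
shard = torch.randn(4, 4)

counter = CopyNumelCounter()
with counter:
    # copy_ through a narrowed view is still counted against the destination
    param.narrow(1, 0, 4).copy_(shard)

print(counter.copied_numel)  # 16
```

This is how the removed `patched_weight_loader` decided when the last chunk had arrived: once the accumulated `copied_numel` reached the parameter's total `numel()`, it called `process_weights_after_loading` immediately instead of waiting for the usual post-load pass.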