-
-
Notifications
You must be signed in to change notification settings - Fork 15k
refactor fp8.py online quant weight loading to use layerwise reload utils #33814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -5,7 +5,6 @@ | |||
|
|
||||
| import torch | ||||
| from torch.nn import Module | ||||
| from torch.utils._python_dispatch import TorchDispatchMode | ||||
|
|
||||
| import vllm.envs as envs | ||||
| import vllm.model_executor.layers.fused_moe.modular_kernel as mk | ||||
|
|
@@ -223,26 +222,6 @@ def get_cache_scale(self, name: str) -> str | None: | |||
| return None | ||||
|
|
||||
|
|
||||
| class CopyNumelCounter(TorchDispatchMode): | ||||
| """ | ||||
| Tracks total number of elements modified with `copy_`. Useful for keeping | ||||
| track of weight loading where underlying weights can be arbitrarily | ||||
| transformed (such as with `narrow`) before calling copy. | ||||
| """ | ||||
|
|
||||
| def __init__(self): | ||||
| super().__init__() | ||||
| self.copied_numel = 0 | ||||
|
|
||||
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): | ||||
| if kwargs is None: | ||||
| kwargs = {} | ||||
| out = func(*args, **kwargs) | ||||
| if func == torch.ops.aten.copy_.default: | ||||
| self.copied_numel += args[0].numel() | ||||
| return out | ||||
|
|
||||
|
|
||||
| def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None: | ||||
| """Copies any attrs present in `old` but not in `new` to `new`""" | ||||
| new_attrs = set(dir(new)) | ||||
|
|
@@ -515,75 +494,26 @@ def create_weights( | |||
| layer.weight_block_size = None | ||||
|
|
||||
| # WEIGHT | ||||
| def patched_weight_loader(param, loaded_weight, *args, **kwargs): | ||||
| # track how many elements we have updated | ||||
| if not hasattr(layer, "_loaded_numel"): | ||||
| layer._loaded_numel = 0 | ||||
|
|
||||
| # when the first `loaded_weight` is about to be | ||||
| # loaded to `param`, materialize `param` just-in-time | ||||
| weight = ModelWeightParameter( | ||||
| data=torch.empty_like(layer.weight, device=layer._load_device), | ||||
| input_dim=1, | ||||
| output_dim=0, | ||||
| weight_loader=patched_weight_loader, | ||||
| ) | ||||
| _copy_missing_attrs(layer.weight, weight) | ||||
| layer.register_parameter("weight", weight) | ||||
| del layer._load_device | ||||
|
|
||||
| # refresh the reference to `param` to reflect just-in-time | ||||
| # materialization | ||||
| param = layer.weight | ||||
|
|
||||
| # load the current weight chunk | ||||
| copy_numel_counter = CopyNumelCounter() | ||||
| with copy_numel_counter: | ||||
| res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] | ||||
| layer._loaded_numel += copy_numel_counter.copied_numel | ||||
|
|
||||
| # if we have loaded all of the elements, call | ||||
| # process_weights_after_loading | ||||
| target_loaded_numel = layer.weight.numel() | ||||
| if layer._loaded_numel == target_loaded_numel: | ||||
| self.process_weights_after_loading(layer) | ||||
|
|
||||
| # Prevent the usual `process_weights_after_loading` call from doing | ||||
| # anything | ||||
| layer._already_called_process_weights_after_loading = True | ||||
|
|
||||
| # Note that we keep `layer._loaded_numel` around just in case | ||||
| # there is logic added to vllm in the future which calls a | ||||
| # weight loader twice - we do not want to re-initialize in | ||||
| # that case. | ||||
|
|
||||
| return res | ||||
|
|
||||
| weight = ModelWeightParameter( | ||||
| data=torch.empty( | ||||
| output_size_per_partition, | ||||
| input_size_per_partition, | ||||
| # materialized just-in-time in `patched_weight_loader` | ||||
| # materialized just-in-time with layerwise loading | ||||
| device="meta", | ||||
| dtype=params_dtype, | ||||
| ), | ||||
| input_dim=1, | ||||
| output_dim=0, | ||||
| weight_loader=patched_weight_loader, | ||||
| weight_loader=weight_loader, | ||||
| ) | ||||
| # stash the correct device for `patched_weight_loader` | ||||
| layer._load_device = torch.get_default_device() | ||||
| layer.register_parameter("weight", weight) | ||||
|
|
||||
| def process_weights_after_loading(self, layer: Module) -> None: | ||||
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||||
| return | ||||
|
|
||||
| # deferred initialization of randomly initialized weights for the | ||||
| # `--load_format dummy` feature | ||||
| if layer.weight.device == torch.device("meta"): | ||||
| weight = ModelWeightParameter( | ||||
| data=torch.empty_like(layer.weight, device=layer._load_device), | ||||
| data=torch.empty_like(layer.weight, device=torch.get_default_device()), | ||||
| input_dim=1, | ||||
| output_dim=0, | ||||
| weight_loader=layer.weight.weight_loader, | ||||
|
|
@@ -612,9 +542,6 @@ def process_weights_after_loading(self, layer: Module) -> None: | |||
| weight = qweight.t() | ||||
| replace_parameter(layer, "weight", weight.data) | ||||
|
|
||||
| # Prevent duplicate processing (e.g., during weight reload) | ||||
| layer._already_called_process_weights_after_loading = True | ||||
|
|
||||
|
|
||||
| class Fp8MoEMethod(FusedMoEMethodBase): | ||||
| """MoE method for FP8. | ||||
|
|
@@ -1012,86 +939,13 @@ def create_weights( | |||
| layer.orig_dtype = params_dtype | ||||
| layer.weight_block_size = None | ||||
|
|
||||
| # We are doing online quantization, patch the weight loaded | ||||
| # to call `process_weights_after_loading` in a streaming fashion | ||||
| # as soon as the last weight chunk is loaded. | ||||
| weight_loader = extra_weight_attrs["weight_loader"] | ||||
| # create a new holder to prevent modifying behavior of any other | ||||
| # objects which might depend on the old one | ||||
| new_extra_weight_attrs = extra_weight_attrs | ||||
|
|
||||
| def patched_weight_loader(param, loaded_weight, *args, **kwargs): | ||||
| # add a counter to track how many elements we have updated | ||||
| if not hasattr(layer, "_loaded_numel"): | ||||
| layer._loaded_numel = 0 | ||||
|
|
||||
| # save the ids of original w13 and w2 so that we can | ||||
| # distinguish which one `param` should map to further | ||||
| # down in this file | ||||
| layer._w13_weight_orig_id = id(layer.w13_weight) | ||||
| layer._w2_weight_orig_id = id(layer.w2_weight) | ||||
|
|
||||
| # when the first `loaded_weight` is about to be | ||||
| # loaded to `param`, materialize `param` just-in-time | ||||
|
|
||||
| w13_weight = torch.nn.Parameter( | ||||
| torch.empty_like(layer.w13_weight, device=layer._load_device), | ||||
| requires_grad=False, | ||||
| ) | ||||
| set_weight_attrs(w13_weight, extra_weight_attrs) | ||||
| _copy_missing_attrs(layer.w13_weight, w13_weight) | ||||
| layer.register_parameter("w13_weight", w13_weight) | ||||
|
|
||||
| w2_weight = torch.nn.Parameter( | ||||
| torch.empty_like(layer.w2_weight, device=layer._load_device), | ||||
| requires_grad=False, | ||||
| ) | ||||
| set_weight_attrs(w2_weight, extra_weight_attrs) | ||||
| _copy_missing_attrs(layer.w2_weight, w2_weight) | ||||
| layer.register_parameter("w2_weight", w2_weight) | ||||
| del layer._load_device | ||||
|
|
||||
| # refresh the reference to `param` to reflect just-in-time | ||||
| # materialization | ||||
| if id(param) == layer._w13_weight_orig_id: | ||||
| param = layer.w13_weight | ||||
| elif id(param) == layer._w2_weight_orig_id: | ||||
| param = layer.w2_weight | ||||
|
|
||||
| # load the current weight chunk | ||||
| copy_numel_counter = CopyNumelCounter() | ||||
| with copy_numel_counter: | ||||
| res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] | ||||
| layer._loaded_numel += copy_numel_counter.copied_numel | ||||
|
|
||||
| # if we have loaded all of the elements, call | ||||
| # process_weights_after_loading | ||||
| target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() | ||||
| if layer._loaded_numel == target_loaded_numel: | ||||
| self.process_weights_after_loading(layer) | ||||
|
|
||||
| # Prevent the usual `process_weights_after_loading` call | ||||
| # from doing anything | ||||
| layer._already_called_process_weights_after_loading = True | ||||
|
|
||||
| # Note that we keep `layer._loaded_numel`, | ||||
| # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` | ||||
| # around because if EP is on, weight loaders for non-local | ||||
| # experts will run but not actually copy any elements, and we | ||||
| # need to not re-initialize in that case. | ||||
|
|
||||
| return res | ||||
|
|
||||
| new_extra_weight_attrs["weight_loader"] = patched_weight_loader | ||||
| extra_weight_attrs = new_extra_weight_attrs | ||||
|
|
||||
| # WEIGHTS | ||||
| w13_weight = torch.nn.Parameter( | ||||
| torch.empty( | ||||
| num_experts, | ||||
| 2 * intermediate_size_per_partition, | ||||
| hidden_size, | ||||
| # materialized just-in-time in `patched_weight_loader` | ||||
| # materialized just-in-time with layerwise loading | ||||
| device="meta", | ||||
| dtype=params_dtype, | ||||
| ), | ||||
|
|
@@ -1105,7 +959,7 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |||
| num_experts, | ||||
| hidden_size, | ||||
| intermediate_size_per_partition, | ||||
| # materialized just-in-time in `patched_weight_loader` | ||||
| # materialized just-in-time with layerwise loading | ||||
| device="meta", | ||||
| dtype=params_dtype, | ||||
| ), | ||||
|
|
@@ -1118,52 +972,42 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |||
|
|
||||
| # BIASES (for models like GPT-OSS that have biased MoE) | ||||
| if self.moe.has_bias: | ||||
| # Use the original weight_loader (not patched) for biases | ||||
| orig_extra_weight_attrs = dict(extra_weight_attrs) | ||||
| orig_extra_weight_attrs["weight_loader"] = weight_loader | ||||
| w13_bias = torch.nn.Parameter( | ||||
| torch.zeros( | ||||
| num_experts, | ||||
| 2 * intermediate_size_per_partition, | ||||
| dtype=layer.orig_dtype, | ||||
| # materialized just-in-time with layerwise loading | ||||
| # Note: this is currently broken for gpt-oss because it | ||||
| # does not use weight loaders at all in the bf16 weights | ||||
| # path | ||||
| device="meta", | ||||
| ), | ||||
| requires_grad=False, | ||||
| ) | ||||
| layer.register_parameter("w13_bias", w13_bias) | ||||
| set_weight_attrs(w13_bias, orig_extra_weight_attrs) | ||||
| set_weight_attrs(w13_bias, extra_weight_attrs) | ||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Need to verify GPT-OSS 120B still works, as this changes the code added by #34906 and there is no CI coverage.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Following up on this, GPT-OSS bf16 is not expected to work with fp8.py online quant because:
I'm not exactly sure how #34906 worked given 1 and 2 above. Going to skip this for now, as gpt-oss + online quant seems low priority because the official weights are in mxfp4, and we can follow up if needed. For posterity, the easiest way to test this is using the 20B model from Unsloth, which goes through the same path as the 120B.
||||
| w2_bias = torch.nn.Parameter( | ||||
| torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype), | ||||
| requires_grad=False, | ||||
| # materialized just-in-time with layerwise loading | ||||
| # Note: this is currently broken for gpt-oss because it | ||||
| # does not use weight loaders at all in the bf16 weights | ||||
| # path | ||||
| device="meta", | ||||
| ) | ||||
| layer.register_parameter("w2_bias", w2_bias) | ||||
| set_weight_attrs(w2_bias, orig_extra_weight_attrs) | ||||
|
|
||||
| # WEIGHT_SCALES | ||||
| # Allocate 2 scales for w1 and w3 respectively. | ||||
| # They will be combined to a single scale after weight loading. | ||||
| w13_weight_scale = torch.nn.Parameter( | ||||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||||
| ) | ||||
| w2_weight_scale = torch.nn.Parameter( | ||||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||||
| ) | ||||
| layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||||
| layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||||
| set_weight_attrs(w13_weight_scale, extra_weight_attrs) | ||||
| set_weight_attrs(w2_weight_scale, extra_weight_attrs) | ||||
| set_weight_attrs(w2_bias, extra_weight_attrs) | ||||
|
|
||||
| layer.w13_input_scale = None | ||||
| layer.w2_input_scale = None | ||||
|
|
||||
| def process_weights_after_loading(self, layer: Module) -> None: | ||||
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||||
| return | ||||
|
|
||||
| # deferred initialization of randomly initialized weights for the | ||||
| # `--load_format dummy` feature | ||||
| if layer.w13_weight.device == torch.device("meta"): | ||||
| w13_weight = torch.nn.Parameter( | ||||
| torch.empty_like(layer.w13_weight, device=layer._load_device), | ||||
| torch.empty_like(layer.w13_weight, device=torch.get_default_device()), | ||||
| requires_grad=False, | ||||
| ) | ||||
| set_weight_attrs( | ||||
|
|
@@ -1174,7 +1018,7 @@ def process_weights_after_loading(self, layer: Module) -> None: | |||
| initialize_single_dummy_weight(layer.w13_weight) | ||||
| if layer.w2_weight.device == torch.device("meta"): | ||||
| w2_weight = torch.nn.Parameter( | ||||
| torch.empty_like(layer.w2_weight, device=layer._load_device), | ||||
| torch.empty_like(layer.w2_weight, device=torch.get_default_device()), | ||||
| requires_grad=False, | ||||
| ) | ||||
| set_weight_attrs( | ||||
|
|
@@ -1184,6 +1028,20 @@ def process_weights_after_loading(self, layer: Module) -> None: | |||
| layer.register_parameter("w2_weight", w2_weight) | ||||
| initialize_single_dummy_weight(layer.w2_weight) | ||||
|
|
||||
| # WEIGHT_SCALES | ||||
| # Allocate 2 scales for w1 and w3 respectively. | ||||
| # They will be combined to a single scale after weight loading. | ||||
| num_experts = layer.num_experts | ||||
| with layer.w13_weight.device: | ||||
| w13_weight_scale = torch.nn.Parameter( | ||||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||||
| ) | ||||
| w2_weight_scale = torch.nn.Parameter( | ||||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||||
| ) | ||||
| layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||||
| layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||||
|
|
||||
| # If checkpoint is fp16, quantize in place. | ||||
| fp8_dtype = current_platform.fp8_dtype() | ||||
| w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) | ||||
|
|
@@ -1210,9 +1068,6 @@ def process_weights_after_loading(self, layer: Module) -> None: | |||
| layer.w2_input_scale, | ||||
| ) | ||||
|
|
||||
| # Prevent duplicate processing (e.g., during weight reload) | ||||
| layer._already_called_process_weights_after_loading = True | ||||
|
|
||||
|
|
||||
| class Fp8KVCacheMethod(BaseKVCacheMethod): | ||||
| """ | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPT-OSS bf16 is broken whether biases are initialized on GPU or on meta; going with meta to be consistent with the layerwise loading infra.
If we want GPT-OSS to work with fp8.py, we should refactor gpt_oss.py to use weight loaders.