-
-
Notifications
You must be signed in to change notification settings - Fork 15.7k
fix memory for online fp8 quantization with streaming weight load #31914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,6 +86,7 @@ | |
| cutlass_fp8_supported, | ||
| normalize_e4m3fn_to_e4m3fnuz, | ||
| ) | ||
| from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight | ||
| from vllm.model_executor.parameter import ( | ||
| BlockQuantScaleParameter, | ||
| ModelWeightParameter, | ||
|
|
@@ -293,6 +294,16 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): | |
| return out | ||
|
|
||
|
|
||
| def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None: | ||
| """Copies any attrs present in `old` but not in `new` to `new`""" | ||
| new_attrs = set(dir(new)) | ||
| attrs_to_set = {} | ||
| for attr in dir(old): | ||
| if attr not in new_attrs: | ||
| attrs_to_set[attr] = getattr(old, attr) | ||
| set_weight_attrs(new, attrs_to_set) | ||
|
|
||
|
|
||
| class Fp8LinearMethod(LinearMethodBase): | ||
| """Linear method for FP8. | ||
| Supports loading FP8 checkpoints with static weight scale and | ||
|
|
@@ -578,6 +589,22 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| if not hasattr(layer, "_loaded_numel"): | ||
| layer._loaded_numel = 0 | ||
|
|
||
| # when the first `loaded_weight` is about to be | ||
| # loaded to `param`, materialize `param` just-in-time | ||
| weight = ModelWeightParameter( | ||
| data=torch.empty_like(layer.weight, device=layer._load_device), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=patched_weight_loader, | ||
| ) | ||
| _copy_missing_attrs(layer.weight, weight) | ||
| layer.register_parameter("weight", weight) | ||
| del layer._load_device | ||
|
|
||
| # refresh the reference to `param` to reflect just-in-time | ||
| # materialization | ||
| param = layer.weight | ||
|
|
||
| # load the current weight chunk | ||
| copy_numel_counter = CopyNumelCounter() | ||
| with copy_numel_counter: | ||
|
|
@@ -590,30 +617,50 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| if layer._loaded_numel == target_loaded_numel: | ||
| self.process_weights_after_loading(layer) | ||
|
|
||
| # Delete the bookkeeping | ||
| del layer._loaded_numel | ||
| # Prevent the usual `process_weights_after_loading` call from doing | ||
| # anything | ||
| layer._already_called_process_weights_after_loading = True | ||
|
|
||
| # Note that we keep `layer._loaded_numel` around just in case | ||
| # there is logic added to vllm in the future which calls a | ||
| # weight loader twice - we do not want to re-initialize in | ||
| # that case. | ||
|
|
||
| return res | ||
|
|
||
| weight = ModelWeightParameter( | ||
| data=torch.empty( | ||
| output_size_per_partition, | ||
| input_size_per_partition, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=patched_weight_loader, | ||
| ) | ||
| # stash the correct device for `patched_weight_loader` | ||
| layer._load_device = torch.get_default_device() | ||
| layer.register_parameter("weight", weight) | ||
|
|
||
| def process_weights_after_loading(self, layer: Module) -> None: | ||
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||
| return | ||
|
|
||
| # deferred initialization of randomly initialized weights for the | ||
| # `--load_format dummy` feature | ||
| if layer.weight.device == torch.device("meta"): | ||
| weight = ModelWeightParameter( | ||
| data=torch.empty_like(layer.weight, device=layer._load_device), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=layer.weight.weight_loader, | ||
| ) | ||
| _copy_missing_attrs(layer.weight, weight) | ||
| layer.register_parameter("weight", weight) | ||
| initialize_single_dummy_weight(layer.weight) | ||
|
|
||
| # TODO(future): support block_quant in online quant path | ||
| assert not self.block_quant | ||
|
|
||
|
|
@@ -1069,6 +1116,39 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| if not hasattr(layer, "_loaded_numel"): | ||
| layer._loaded_numel = 0 | ||
|
|
||
| # save the ids of original w13 and w2 so that we can | ||
| # distinguish which one `param` should map to further | ||
| # down in this file | ||
| layer._w13_weight_orig_id = id(layer.w13_weight) | ||
| layer._w2_weight_orig_id = id(layer.w2_weight) | ||
|
|
||
| # when the first `loaded_weight` is about to be | ||
| # loaded to `param`, materialize `param` just-in-time | ||
|
|
||
| w13_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w13_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs(w13_weight, extra_weight_attrs) | ||
| _copy_missing_attrs(layer.w13_weight, w13_weight) | ||
| layer.register_parameter("w13_weight", w13_weight) | ||
|
|
||
| w2_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w2_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs(w2_weight, extra_weight_attrs) | ||
| _copy_missing_attrs(layer.w2_weight, w2_weight) | ||
| layer.register_parameter("w2_weight", w2_weight) | ||
| del layer._load_device | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I found that the `del` here will cause an error in the DP + EP case:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @yma11 thanks! I will take a look directly after I fix the logging issue in CI. Just in case I can't reproduce it right away, it would be great if you could share your repro command.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Never mind, I can reproduce it — looking into it now.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I fixed the issue by making sure we do not incorrectly reinitialize weights when EP is on; please let me know if there are any further issues. |
||
|
|
||
| # refresh the reference to `param` to reflect just-in-time | ||
| # materialization | ||
| if id(param) == layer._w13_weight_orig_id: | ||
| param = layer.w13_weight | ||
| elif id(param) == layer._w2_weight_orig_id: | ||
| param = layer.w2_weight | ||
|
|
||
| # load the current weight chunk | ||
| copy_numel_counter = CopyNumelCounter() | ||
| with copy_numel_counter: | ||
|
|
@@ -1081,12 +1161,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| if layer._loaded_numel == target_loaded_numel: | ||
| self.process_weights_after_loading(layer) | ||
|
|
||
| # Delete the bookkeeping | ||
| del layer._loaded_numel | ||
| # Prevent the usual `process_weights_after_loading` call | ||
| # from doing anything | ||
| layer._already_called_process_weights_after_loading = True | ||
|
|
||
| # Note that we keep `layer._loaded_numel`, | ||
| # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` | ||
| # around because if EP is on, weight loaders for non-local | ||
| # experts will run but not actually copy any elements, and we | ||
| # need to not re-initialize in that case. | ||
|
|
||
| return res | ||
|
|
||
| new_extra_weight_attrs["weight_loader"] = patched_weight_loader | ||
|
|
@@ -1098,6 +1182,8 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| num_experts, | ||
| 2 * intermediate_size_per_partition, | ||
| hidden_size, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
| requires_grad=False, | ||
|
|
@@ -1110,12 +1196,16 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| num_experts, | ||
| hidden_size, | ||
| intermediate_size_per_partition, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
| requires_grad=False, | ||
| ) | ||
| layer.register_parameter("w2_weight", w2_weight) | ||
| set_weight_attrs(w2_weight, extra_weight_attrs) | ||
| # stash the correct device for `patched_weight_loader` | ||
| layer._load_device = torch.get_default_device() | ||
|
|
||
| # WEIGHT_SCALES | ||
| # Allocate 2 scales for w1 and w3 respectively. | ||
|
|
@@ -1138,6 +1228,31 @@ def process_weights_after_loading(self, layer: Module) -> None: | |
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||
| return | ||
|
|
||
| # deferred initialization of randomly initialized weights for the | ||
| # `--load_format dummy` feature | ||
| if layer.w13_weight.device == torch.device("meta"): | ||
| w13_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w13_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs( | ||
| w13_weight, {"weight_loader": layer.w13_weight.weight_loader} | ||
| ) | ||
| _copy_missing_attrs(layer.w13_weight, w13_weight) | ||
| layer.register_parameter("w13_weight", w13_weight) | ||
| initialize_single_dummy_weight(layer.w13_weight) | ||
| if layer.w2_weight.device == torch.device("meta"): | ||
| w2_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w2_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs( | ||
| w2_weight, {"weight_loader": layer.w2_weight.weight_loader} | ||
| ) | ||
| _copy_missing_attrs(layer.w2_weight, w2_weight) | ||
| layer.register_parameter("w2_weight", w2_weight) | ||
| initialize_single_dummy_weight(layer.w2_weight) | ||
|
|
||
| # If checkpoint is fp16, quantize in place. | ||
| fp8_dtype = current_platform.fp8_dtype() | ||
| w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,8 @@ | |
| initialize_model, | ||
| process_weights_after_loading, | ||
| ) | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.mem_utils import format_gib | ||
| from vllm.utils.torch_utils import set_default_torch_dtype | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
@@ -56,6 +58,17 @@ def load_model( | |
| logger.debug("Loading weights on %s ...", load_device) | ||
| # Quantization does not happen in `load_weights` but after it | ||
| self.load_weights(model, model_config) | ||
|
|
||
| # Log peak GPU memory after loading weights. This is needed | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note that the actual peak as logged here is not visible when just measuring peak memory after the LLM object is initialized — it seems we need extra logging. I'm open to suggestions on where to put this if there is a better place.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We shouldn't add this log by default... could you make it a debug_once log and just set the logging level within the test?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense — fixed! |
||
| # to have test coverage on peak memory for online quantization. | ||
| if current_platform.is_cuda(): | ||
| peak_memory = torch.cuda.max_memory_allocated() | ||
| logger.debug_once( | ||
| "Peak GPU memory after loading weights: %s GiB", | ||
| format_gib(peak_memory), | ||
| scope="local", | ||
| ) | ||
|
|
||
| process_weights_after_loading(model, model_config, target_device) | ||
|
|
||
| return model.eval() | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.