-
-
Notifications
You must be signed in to change notification settings - Fork 15k
[wip] layerwise loading for fp8.py, take 2 #34020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -555,70 +555,24 @@ def create_weights( | |
| layer.weight_block_size = None | ||
|
|
||
| # WEIGHT | ||
| def patched_weight_loader(param, loaded_weight, *args, **kwargs): | ||
| # track how many elements we have updated | ||
| if not hasattr(layer, "_loaded_numel"): | ||
| layer._loaded_numel = 0 | ||
|
|
||
| # when the first `loaded_weight` is about to be | ||
| # loaded to `param`, materialize `param` just-in-time | ||
| weight = ModelWeightParameter( | ||
| data=torch.empty_like(layer.weight, device=layer._load_device), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=patched_weight_loader, | ||
| ) | ||
| _copy_missing_attrs(layer.weight, weight) | ||
| layer.register_parameter("weight", weight) | ||
| del layer._load_device | ||
|
|
||
| # refresh the reference to `param` to reflect just-in-time | ||
| # materialization | ||
| param = layer.weight | ||
|
|
||
| # load the current weight chunk | ||
| copy_numel_counter = CopyNumelCounter() | ||
| with copy_numel_counter: | ||
| res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] | ||
| layer._loaded_numel += copy_numel_counter.copied_numel | ||
|
|
||
| # if we have loaded all of the elements, call | ||
| # process_weights_after_loading | ||
| target_loaded_numel = layer.weight.numel() | ||
| if layer._loaded_numel == target_loaded_numel: | ||
| self.process_weights_after_loading(layer) | ||
|
|
||
| # Prevent the usual `process_weights_after_loading` call from doing | ||
| # anything | ||
| layer._already_called_process_weights_after_loading = True | ||
|
|
||
| # Note that we keep `layer._loaded_numel` around just in case | ||
| # there is logic added to vllm in the future which calls a | ||
| # weight loader twice - we do not want to re-initialize in | ||
| # that case. | ||
|
|
||
| return res | ||
|
|
||
| weight = ModelWeightParameter( | ||
| data=torch.empty( | ||
| output_size_per_partition, | ||
| input_size_per_partition, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| # TODO(before review): say where exactly this will be materialized | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
| input_dim=1, | ||
| output_dim=0, | ||
| weight_loader=patched_weight_loader, | ||
| weight_loader=weight_loader, | ||
| ) | ||
| # stash the correct device for `patched_weight_loader` | ||
| layer._load_device = torch.get_default_device() | ||
| layer.register_parameter("weight", weight) | ||
|
|
||
| def process_weights_after_loading(self, layer: Module) -> None: | ||
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||
| return | ||
|
|
||
| # deferred initialization of randomly initialized weights for the | ||
| # `--load_format dummy` feature | ||
| if layer.weight.device == torch.device("meta"): | ||
|
|
@@ -1074,86 +1028,13 @@ def create_weights( | |
| layer.orig_dtype = params_dtype | ||
| layer.weight_block_size = None | ||
|
|
||
| # We are doing online quantization, patch the weight loaded | ||
| # to call `process_weights_after_loading` in a streaming fashion | ||
| # as soon as the last weight chunk is loaded. | ||
| weight_loader = extra_weight_attrs["weight_loader"] | ||
| # create a new holder to prevent modifying behavior of any other | ||
| # objects which might depend on the old one | ||
| new_extra_weight_attrs = extra_weight_attrs | ||
|
|
||
| def patched_weight_loader(param, loaded_weight, *args, **kwargs): | ||
| # add a counter to track how many elements we have updated | ||
| if not hasattr(layer, "_loaded_numel"): | ||
| layer._loaded_numel = 0 | ||
|
|
||
| # save the ids of original w13 and w2 so that we can | ||
| # distinguish which one `param` should map to further | ||
| # down in this file | ||
| layer._w13_weight_orig_id = id(layer.w13_weight) | ||
| layer._w2_weight_orig_id = id(layer.w2_weight) | ||
|
|
||
| # when the first `loaded_weight` is about to be | ||
| # loaded to `param`, materialize `param` just-in-time | ||
|
|
||
| w13_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w13_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs(w13_weight, extra_weight_attrs) | ||
| _copy_missing_attrs(layer.w13_weight, w13_weight) | ||
| layer.register_parameter("w13_weight", w13_weight) | ||
|
|
||
| w2_weight = torch.nn.Parameter( | ||
| torch.empty_like(layer.w2_weight, device=layer._load_device), | ||
| requires_grad=False, | ||
| ) | ||
| set_weight_attrs(w2_weight, extra_weight_attrs) | ||
| _copy_missing_attrs(layer.w2_weight, w2_weight) | ||
| layer.register_parameter("w2_weight", w2_weight) | ||
| del layer._load_device | ||
|
|
||
| # refresh the reference to `param` to reflect just-in-time | ||
| # materialization | ||
| if id(param) == layer._w13_weight_orig_id: | ||
| param = layer.w13_weight | ||
| elif id(param) == layer._w2_weight_orig_id: | ||
| param = layer.w2_weight | ||
|
|
||
| # load the current weight chunk | ||
| copy_numel_counter = CopyNumelCounter() | ||
| with copy_numel_counter: | ||
| res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] | ||
| layer._loaded_numel += copy_numel_counter.copied_numel | ||
|
|
||
| # if we have loaded all of the elements, call | ||
| # process_weights_after_loading | ||
| target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() | ||
| if layer._loaded_numel == target_loaded_numel: | ||
| self.process_weights_after_loading(layer) | ||
|
|
||
| # Prevent the usual `process_weights_after_loading` call | ||
| # from doing anything | ||
| layer._already_called_process_weights_after_loading = True | ||
|
|
||
| # Note that we keep `layer._loaded_numel`, | ||
| # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` | ||
| # around because if EP is on, weight loaders for non-local | ||
| # experts will run but not actually copy any elements, and we | ||
| # need to not re-initialize in that case. | ||
|
|
||
| return res | ||
|
|
||
| new_extra_weight_attrs["weight_loader"] = patched_weight_loader | ||
| extra_weight_attrs = new_extra_weight_attrs | ||
|
|
||
| # WEIGHTS | ||
| w13_weight = torch.nn.Parameter( | ||
| torch.empty( | ||
| num_experts, | ||
| 2 * intermediate_size_per_partition, | ||
| hidden_size, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| # materialized just-in-time in | ||
| # TODO(before review) document where | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
|
|
@@ -1168,36 +1049,20 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): | |
| hidden_size, | ||
| intermediate_size_per_partition, | ||
| # materialized just-in-time in `patched_weight_loader` | ||
| # TODO(before review) document where | ||
| device="meta", | ||
| dtype=params_dtype, | ||
| ), | ||
| requires_grad=False, | ||
| ) | ||
| layer.register_parameter("w2_weight", w2_weight) | ||
| set_weight_attrs(w2_weight, extra_weight_attrs) | ||
| # stash the correct device for `patched_weight_loader` | ||
| layer._load_device = torch.get_default_device() | ||
|
|
||
| # WEIGHT_SCALES | ||
| # Allocate 2 scales for w1 and w3 respectively. | ||
| # They will be combined to a single scale after weight loading. | ||
| w13_weight_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||
| ) | ||
| w2_weight_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||
| ) | ||
| layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||
| layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||
| set_weight_attrs(w13_weight_scale, extra_weight_attrs) | ||
| set_weight_attrs(w2_weight_scale, extra_weight_attrs) | ||
|
|
||
| layer.w13_input_scale = None | ||
| layer.w2_input_scale = None | ||
|
|
||
| def process_weights_after_loading(self, layer: Module) -> None: | ||
| if getattr(layer, "_already_called_process_weights_after_loading", False): | ||
| return | ||
| # if getattr(layer, "_already_called_process_weights_after_loading", False): | ||
| # return | ||
|
|
||
| # deferred initialization of randomly initialized weights for the | ||
| # `--load_format dummy` feature | ||
|
|
@@ -1224,7 +1089,21 @@ def process_weights_after_loading(self, layer: Module) -> None: | |
| layer.register_parameter("w2_weight", w2_weight) | ||
| initialize_single_dummy_weight(layer.w2_weight) | ||
|
|
||
| # If checkpoint is fp16, quantize in place. | ||
| # WEIGHT_SCALES | ||
| # Allocate 2 scales for w1 and w3 respectively. | ||
| # They will be combined to a single scale later. | ||
| num_experts = layer.num_experts | ||
| with layer.w13_weight.device: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does it not make more sense to just call process_after_weight_loading within the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. sure |
||
| w13_weight_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do these need to be initialized with ones? Why not empty?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this is code movement from one place to another, so keeping as is to minimize risk since not technically related to this PR
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think scale dtype should theoretically be the model dtype, not necessarily float32, but it's been a while since I looked at this.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this is code movement from one place to another, so keeping as is to minimize risk since not technically related to this PR |
||
| ) | ||
| w2_weight_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), requires_grad=False | ||
| ) | ||
| layer.register_parameter("w13_weight_scale", w13_weight_scale) | ||
| layer.register_parameter("w2_weight_scale", w2_weight_scale) | ||
|
|
||
| # Quantize the loaded high precision checkpoint to fp8 | ||
| fp8_dtype = current_platform.fp8_dtype() | ||
| w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) | ||
| w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,10 @@ | |
| from vllm.config import ModelConfig, VllmConfig | ||
| from vllm.config.load import LoadConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.model_loader.reload.initial_load import ( | ||
| finalize_layerwise_initial_load, | ||
| initialize_layerwise_initial_load, | ||
| ) | ||
| from vllm.model_executor.model_loader.utils import ( | ||
| initialize_model, | ||
| process_weights_after_loading, | ||
|
|
@@ -56,20 +60,43 @@ def load_model( | |
| log_model_inspection(model) | ||
|
|
||
| logger.debug("Loading weights on %s ...", load_device) | ||
| # Quantization does not happen in `load_weights` but after it | ||
| self.load_weights(model, model_config) | ||
|
|
||
| # Log peak GPU memory after loading weights. This is needed | ||
| # to have test coverage on peak memory for online quantization. | ||
| if current_platform.is_cuda(): | ||
| peak_memory = torch.cuda.max_memory_allocated() | ||
| logger.debug_once( | ||
| "Peak GPU memory after loading weights: %s GiB", | ||
| format_gib(peak_memory), | ||
| scope="local", | ||
| ) | ||
|
|
||
| process_weights_after_loading(model, model_config, target_device) | ||
| is_online_quant = _is_online_quant(vllm_config, model_config) | ||
| if not is_online_quant: | ||
| # Regular path, `process_weights_after_loading` is called | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. probably don't need this many comments |
||
| # after all weights are loaded. | ||
|
|
||
| # Quantization does not happen in `load_weights` but after it | ||
| self.load_weights(model, model_config) | ||
| process_weights_after_loading(model, model_config, target_device) | ||
|
|
||
| else: | ||
| # Online quantization can take the layerwise loading path | ||
| # where `process_weights_after_loading` is done just-in-time | ||
| # after all of a layer's weights are loaded. | ||
|
|
||
| # set up weight loaders for layerwise loading with | ||
| # streaming post-processing | ||
| initialize_layerwise_initial_load(model, target_device) | ||
|
|
||
| # load the weights, layerwise loading infra will call | ||
| # each layer's `process_weights_after_loading` function | ||
| # as soon as every weight of that layer is loaded | ||
| self.load_weights(model, model_config) | ||
|
|
||
| # finalize layerwise reloading (call any post-processing | ||
| # that did not happen in real time) | ||
| finalize_layerwise_initial_load(model, model_config) | ||
|
|
||
| # Log peak GPU memory after loading weights. This is needed | ||
| # to have test coverage on peak memory for online quantization. | ||
| if current_platform.is_cuda(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be unindented by one? |
||
| peak_memory = torch.cuda.max_memory_allocated() | ||
| logger.debug_once( | ||
| "Peak GPU memory after loading weights: %s GiB", | ||
| format_gib(peak_memory), | ||
| scope="local", | ||
| ) | ||
|
|
||
| return model.eval() | ||
|
|
||
|
|
@@ -82,3 +109,14 @@ def log_model_inspection(model: nn.Module) -> None: | |
| from vllm.model_inspection import format_model_inspection | ||
|
|
||
| logger.info("vLLM model structure:\n%s", format_model_inspection(model)) | ||
|
|
||
|
|
||
| def _is_online_quant(vllm_config: VllmConfig, model_config: ModelConfig) -> bool: | ||
| quant_config = vllm_config.quant_config | ||
| return ( | ||
| # TODO(future): add other online quant paths here | ||
| model_config.quantization == "fp8" | ||
| and quant_config is not None | ||
| and hasattr(quant_config, "is_checkpoint_fp8_serialized") | ||
| and not quant_config.is_checkpoint_fp8_serialized | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment and TODO are outdated since `patched_weight_loader` has been removed in this refactoring. The materialization now happens in `make_online_initial_load_process_loader` within `vllm/model_executor/model_loader/reload/layerwise.py`. Please remove these lines.