[wip] layerwise loading for fp8.py, take 2 #34020
vkuzo wants to merge 1 commit into vllm-project:main from
Conversation
Code Review
This pull request refactors the layer-wise loading mechanism for online quantization, centralizing the logic from fp8.py into a more general framework within vllm/model_executor/model_loader/reload/. This is a positive architectural change that improves modularity and maintainability.
However, the pull request is clearly a work-in-progress, as indicated by numerous TODO comments for documentation, configuration, and code placement. A significant limitation is the NotImplementedError raised for attention layers (Attention and MLAAttention), which will need to be addressed. All identified TODOs should be resolved before this PR is considered for merging.
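To make the mechanism under review concrete, the layer-by-layer flow can be sketched in plain Python. All names below are illustrative, not the actual API in `vllm/model_executor/model_loader/reload/`: materialize one layer's full-precision weight just-in-time, quantize it, and release the original before moving to the next layer, so peak memory stays near a single layer's footprint.

```python
def load_model_layerwise(layers, fetch_weight, quantize):
    """Sketch of layerwise online quantization: materialize one layer's
    full-precision weight, quantize it, and free the original before
    moving on to the next layer."""
    for name, layer in layers:
        full_precision = fetch_weight(name)    # materialize just-in-time
        layer["weight"] = quantize(full_precision)
        del full_precision                     # release before the next layer


# Toy usage: "quantize" by rounding to one decimal place.
layers = [("fc1", {}), ("fc2", {})]
weights = {"fc1": [0.123, 0.456], "fc2": [0.789]}
load_model_layerwise(
    layers,
    fetch_weight=weights.pop,    # pop() frees the checkpoint copy as we go
    quantize=lambda w: [round(x, 1) for x in w],
)
```

Using `weights.pop` as the fetcher mimics the memory behavior discussed in the review: the full-precision copy is dropped as soon as each layer is processed.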
# materialized just-in-time in `patched_weight_loader`
# TODO(before review): say where exactly this will be materialized
# TODO(before review): set this from config
is_online_quant = True
| """ | ||
| TODO write me | ||
| """ |
# TODO better place for this?
layer._load_device = target_device
# TODO(before review): move fp8's one to a common place and use that
class CopyCounter(TorchDispatchMode):
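The idea behind `CopyCounter` — counting copy operations that happen while weights load — can be illustrated with a pure-Python analogue of a `TorchDispatchMode`: a context manager that temporarily wraps a copy function and counts invocations. The class and method names here are made up for illustration; the real implementation intercepts `aten` ops via `__torch_dispatch__`.

```python
import functools


class CopyCounter:
    """Pure-Python analogue of a TorchDispatchMode intercepting copies:
    while active, wraps `owner.attr` and counts how many times it runs."""

    def __init__(self, owner, attr):
        self.owner, self.attr = owner, attr
        self.count = 0

    def __enter__(self):
        self._orig = getattr(self.owner, self.attr)

        @functools.wraps(self._orig)
        def counted(*args, **kwargs):
            self.count += 1
            return self._orig(*args, **kwargs)

        setattr(self.owner, self.attr, counted)
        return self

    def __exit__(self, *exc):
        # Restore the original function on exit.
        setattr(self.owner, self.attr, self._orig)


class Buffer:
    """Toy stand-in for a tensor with an in-place copy_ method."""

    def __init__(self):
        self.data = []

    def copy_(self, src):
        self.data = list(src)


buf = Buffer()
with CopyCounter(Buffer, "copy_") as counter:
    buf.copy_([1, 2])
    buf.copy_([3])
# counter.count is now 2; copies outside the context are not counted
```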
def finalize_layerwise_initial_load(model: torch.nn.Module, model_config: ModelConfig):
    """
    TODO
    """
def _layerwise_initial_load_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    TODO write me
    """
| bound_args.arguments["param"] = current_param | ||
|
|
||
| # Cache loaded weights, track loading progress | ||
| info.loaded_weights.append((param_name, bound_args)) |
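The `bound_args` pattern in this hunk — binding a weight loader's arguments when the checkpoint tensor arrives, so the call can be replayed later against a freshly materialized parameter — looks roughly like the following. The loader, parameter, and names are toy stand-ins, not vLLM's actual signatures.

```python
import inspect


def weight_loader(param, loaded_weight):
    # Toy loader: copy the checkpoint weight into the parameter.
    param["data"] = list(loaded_weight)


# Capture the call instead of executing it immediately.
sig = inspect.signature(weight_loader)
bound_args = sig.bind({"data": None}, [1.0, 2.0])
loaded_weights = [("w13_weight", bound_args)]

# Later, once the real parameter is materialized on the target device,
# swap it into the bound arguments and replay the cached call.
current_param = {"data": None}
for param_name, bound in loaded_weights:
    bound.arguments["param"] = current_param
    weight_loader(*bound.args, **bound.kwargs)
```

`inspect.BoundArguments.args`/`.kwargs` are recomputed from `.arguments`, which is what makes the late `param` swap work.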
Summary: TODO write me
Test Plan: TODO write me
Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale later.
num_experts = layer.num_experts
with layer.w13_weight.device:
Would it not make more sense to just call `process_after_weight_loading` within the `with target_device:` context?
@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
I don't really understand the reasoning behind duplicating this code. The logic is essentially the same, and could be slightly modified to skip the kernel_tensors handling by passing an `is_reloading` flag.
If `is_reloading == True`, skip all the logic related to replacing kernel tensors, i.e. the following lines:
- https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL96
- https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL217
- https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL235-L243
yep, ^ is this PR: https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfR73
let's sync offline to align here
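The flag-based deduplication suggested in this thread would look roughly like the sketch below. The function body and helper steps are hypothetical placeholders, just to show the branch structure, not vLLM's real post-load logic.

```python
def process_weights_after_loading(layer: dict, is_reloading: bool = False) -> dict:
    """Shared post-load processing; on a reload, the kernel-format
    tensors already exist, so the replacement step is skipped."""
    # Logic common to initial load and reload (placeholder step).
    layer["scales_fused"] = True
    if not is_reloading:
        # Initial load only: swap weights into the kernel's layout.
        layer["kernel_tensors"] = "replaced"
    return layer


initial = process_weights_after_loading({})
reloaded = process_weights_after_loading({}, is_reloading=True)
```

One function with a branch keeps the two paths from drifting apart, at the cost of a mode flag threading through the call stack — which is the trade-off being debated here.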
logger = init_logger(__name__)


# Global dict storing information used for layerwise loading
INITIAL_LOAD_LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
It really doesn't make sense to me to duplicate this variable. Why can't this be reused?
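Whether or not the variable gets deduplicated, the key property of this hunk's data structure is that a `WeakKeyDictionary` keyed by modules drops its per-layer bookkeeping automatically when a layer is garbage-collected, so the global never keeps a model alive. A minimal standalone demonstration (with a plain class standing in for `torch.nn.Module`):

```python
import gc
from weakref import WeakKeyDictionary


class Layer:
    """Stand-in for torch.nn.Module."""


# Global dict storing information used for layerwise loading.
INITIAL_LOAD_LAYERWISE_INFO: WeakKeyDictionary = WeakKeyDictionary()

layer = Layer()
INITIAL_LOAD_LAYERWISE_INFO[layer] = {"loaded_weights": []}
assert len(INITIAL_LOAD_LAYERWISE_INFO) == 1

del layer        # drop the only strong reference
gc.collect()
# The entry vanishes with the layer instead of leaking.
```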
process_weights_after_loading(model, model_config, target_device)
is_online_quant = _is_online_quant(vllm_config, model_config)
if not is_online_quant:
    # Regular path, `process_weights_after_loading` is called
probably don't need this many comments
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
Should this be unindented by one level?
num_experts = layer.num_experts
with layer.w13_weight.device:
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
Do these need to be initialized with ones? Why not empty?
This is code movement from one place to another, so I'm keeping it as-is to minimize risk, since it's not technically related to this PR.
num_experts = layer.num_experts
with layer.w13_weight.device:
    w13_weight_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
I think scale dtype should theoretically be the model dtype, not necessarily float32, but it's been a while since I looked at this.
This is code movement from one place to another, so I'm keeping it as-is to minimize risk, since it's not technically related to this PR.
closing in favor of #33814
Summary:
Refactor fp8.py's streaming weight loading to be more similar to QERL (#32133).
Because the logic we need during the initial load is significantly simpler than QERL's (we don't care about kernel format, CUDA graphs, etc.), I ended up rewriting the logic in a similar style to QERL with only minimal reuse. The alternative would be to branch in the high-level reloading functions (like #33814); after chatting with @kylesayrs, we decided to rewrite instead of reuse.
Readiness of current PR:
Test Plan:
TODO write me