
[wip] layerwise loading for fp8.py, take 2 #34020

Closed
vkuzo wants to merge 1 commit into vllm-project:main from vkuzo:20260206_layerwise_v2

Conversation

@vkuzo
Contributor

@vkuzo vkuzo commented Feb 6, 2026

Summary:

Refactor fp8.py's streaming weight loading to be more similar to QERL (#32133).

Because the logic we need during the initial load is significantly simpler than QERL's (we don't care about kernel format, CUDA graphs, etc.), I ended up rewriting the logic in a similar style to QERL with only minimal reuse. The alternative would be to branch in the high-level reloading functions (like #33814); after chatting with @kylesayrs, we decided to rewrite instead of reuse.
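As a rough illustration of the streaming pattern this PR implements (a hedged sketch; the function names below are hypothetical stand-ins, not the actual vLLM API), each layer is quantized as soon as its weights arrive, so the full high-precision checkpoint never resides in memory at once:

```python
# Hypothetical sketch of layerwise streaming load. `load_weights` and
# `quantize_layer` stand in for the real vLLM loader/quantization hooks.
def layerwise_initial_load(layers, load_weights, quantize_layer):
    for layer in layers:
        weights = load_weights(layer)   # stream only this layer's weights
        quantize_layer(layer, weights)  # e.g. bf16 -> fp8, in place
        del weights                     # drop the high-precision copy early
```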

Readiness of current PR:

  • dense - works
  • MoE single node - works
  • MoE with TP and EP multi node - needs testing
  • load_format dummy - broken, needs a fix
  • API design and polish - 80% (some cleanups still to do)

Test Plan:

// dense model single node
VLLM_LOGGING_LEVEL=DEBUG VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8

// moe model single node
VLLM_LOGGING_LEVEL=DEBUG VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 examples/offline_inference/basic/generate.py --model ibm-granite/granite-3.0-1b-a400m-base --enforce-eager --dtype=bfloat16 --max_model_len=2048 --quantization=fp8

// test cases (currently load_format dummy fails, need to investigate + fix)
pytest tests/quantization/test_fp8.py -s -k

TODO write me


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the layer-wise loading mechanism for online quantization, centralizing the logic from fp8.py into a more general framework within vllm/model_executor/model_loader/reload/. This is a positive architectural change that improves modularity and maintainability.

However, the pull request is clearly a work-in-progress, as indicated by numerous TODO comments for documentation, configuration, and code placement. A significant limitation is the NotImplementedError for handling attention layers (Attention and MLAAttention), which will need to be addressed. All the identified TODOs should be resolved before this PR is considered for merging.

Comment on lines 562 to +563
# materialized just-in-time in `patched_weight_loader`
# TODO(before review): say where exactly this will be materialized
Contributor


high

This comment and TODO are outdated since patched_weight_loader has been removed in this refactoring. The materialization now happens in make_online_initial_load_process_loader within vllm/model_executor/model_loader/reload/layerwise.py. Please remove these lines.

Comment on lines +62 to +63
# TODO(before review): set this from config
is_online_quant = True
Contributor


high

The is_online_quant flag is currently hardcoded. As the TODO suggests, this should be driven by a configuration parameter to allow for dynamic control over the quantization path.
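As one possible shape for that change (a hedged sketch; the config class and field names here are hypothetical stand-ins, not vLLM's actual types), the flag could be derived from the quantization config instead of being hardcoded:

```python
from dataclasses import dataclass


@dataclass
class QuantConfig:
    """Hypothetical stand-in for the real quantization config object."""
    method: str = "fp8"
    checkpoint_is_quantized: bool = False


def is_online_quant(cfg: QuantConfig) -> bool:
    # Online quantization: we quantize at load time because the
    # checkpoint itself is stored in high precision.
    return cfg.method == "fp8" and not cfg.checkpoint_is_quantized
```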

Comment on lines +116 to +118
"""
TODO write me
"""
Contributor


high

The docstring for initialize_layerwise_initial_load is a placeholder. Please add a comprehensive docstring that explains the function's purpose, arguments, and behavior to improve code clarity and maintainability.

Comment on lines +126 to +127
# TODO better place for this?
layer._load_device = target_device
Contributor


high

The TODO comment indicates uncertainty about the placement of layer._load_device. Attaching attributes directly to the layer can sometimes be fragile. Consider if this could be passed more explicitly through the loader context or another mechanism for a cleaner design.

Comment on lines +195 to +196
# TODO(before review): move fp8's one to a common place and use that
class CopyCounter(TorchDispatchMode):
Contributor


high

As the TODO suggests, the CopyCounter class is a general utility. Moving it to a common utility file would improve code organization and reusability.
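A minimal version of such a utility (a sketch built on PyTorch's `TorchDispatchMode`; the actual fp8.py implementation may differ in what it counts) could look like:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class CopyCounter(TorchDispatchMode):
    """Counts aten.copy_ calls issued while the mode is active."""

    def __init__(self):
        super().__init__()
        self.count = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.copy_.default:
            self.count += 1
        # Redispatch below this mode so the op still executes normally.
        return func(*args, **(kwargs or {}))
```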

Comment on lines +330 to +333
def finalize_layerwise_initial_load(model: torch.nn.Module, model_config: ModelConfig):
"""
TODO
"""
Contributor


high

The docstring for finalize_layerwise_initial_load is a placeholder. Please provide a complete docstring explaining its functionality, especially how it handles partially loaded layers and special cases like attention layers.

Comment on lines +406 to +409
def _layerwise_initial_load_process(layer: torch.nn.Module, info: LayerReloadingInfo):
"""
TODO write me
"""
Contributor


high

The docstring for _layerwise_initial_load_process is a placeholder. Please add a proper docstring to explain what this function does.

bound_args.arguments["param"] = current_param

# Cache loaded weights, track loading progress
info.loaded_weights.append((param_name, bound_args))
Contributor Author


need to delete this

@vkuzo vkuzo force-pushed the 20260206_layerwise_v2 branch 3 times, most recently from d618c48 to 9af16a1 on February 9, 2026 13:22
Summary:

TODO write me

Test Plan:

TODO write me

Signed-off-by: Vasiliy Kuznetsov <vasiliy@meta.com>
@vkuzo vkuzo force-pushed the 20260206_layerwise_v2 branch from 9af16a1 to 7679f4d on February 9, 2026 14:37
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale later.
num_experts = layer.num_experts
with layer.w13_weight.device:
Contributor


Does it not make more sense to just call process_after_weight_loading within the with target_device context?

Contributor Author


sure
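For reference (a minimal sketch, not vLLM code): PyTorch device objects act as context managers that set the default device for factory calls, which is what makes wrapping the whole `process_weights_after_loading` call in a single `with target_device:` block viable, instead of scattering per-tensor `with layer.w13_weight.device:` blocks:

```python
import torch

# Inside `with torch.device(dev):`, factory functions such as torch.ones
# allocate on `dev` by default, so per-call `device=` arguments become
# unnecessary. "cpu" is used here just to keep the sketch portable.
with torch.device("cpu"):
    scale = torch.ones(8, dtype=torch.float32)

assert scale.device.type == "cpu"
```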

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
Contributor


I don't really understand the reasoning behind duplicating this code. It seems like this logic is essentially the same, and could be slightly modified to skip the kernel_tensors handling by passing an is_reloading flag?

If is_reloading == True, skip all the logic related to replacing kernel tensors, i.e. the following lines:

  1. https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL96
  2. https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL217
  3. https://github.com/vllm-project/vllm/pull/33814/changes#diff-7fbdc7b012d399cf0aabe6611f2a2e79d5047d3f2e19a11e35867f024c1cdcdfL235-L243

Contributor Author


logger = init_logger(__name__)

# Global dict storing information used for layerwise loading
INITIAL_LOAD_LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
Contributor


It really doesn't make sense to me to duplicate this variable. Why can't this be reused?
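For context on the pattern itself (a generic sketch, not the vLLM code under review): keying per-layer loading state on the module via a `WeakKeyDictionary` means entries are dropped automatically once a layer is garbage collected, with no explicit cleanup pass:

```python
import gc
from weakref import WeakKeyDictionary


class Layer:
    """Stand-in for torch.nn.Module, just to keep the sketch self-contained."""


# Global dict mapping layer -> loading info, held only by weak references.
layerwise_info: WeakKeyDictionary = WeakKeyDictionary()

layer = Layer()
layerwise_info[layer] = {"loaded_weights": []}
assert layer in layerwise_info

del layer      # once the layer itself is gone...
gc.collect()
assert len(layerwise_info) == 0  # ...its entry disappears too
```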

process_weights_after_loading(model, model_config, target_device)
is_online_quant = _is_online_quant(vllm_config, model_config)
if not is_online_quant:
# Regular path, `process_weights_after_loading` is called
Contributor


probably don't need this many comments


# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
Contributor


This should be unindented by one?

num_experts = layer.num_experts
with layer.w13_weight.device:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
Contributor


Do these need to be initialized with ones? Why not empty?

Contributor Author


this is code movement from one place to another, so I'm keeping it as-is to minimize risk, since it's not technically related to this PR

num_experts = layer.num_experts
with layer.w13_weight.device:
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
Contributor

@kylesayrs kylesayrs Feb 9, 2026


I think scale dtype should theoretically be the model dtype, not necessarily float32, but it's been a while since I looked at this.

Contributor Author


this is code movement from one place to another, so I'm keeping it as-is to minimize risk, since it's not technically related to this PR

@vkuzo
Contributor Author

vkuzo commented Mar 3, 2026

closing in favor of #33814

@vkuzo vkuzo closed this Mar 3, 2026
