
[QeRL] Compose online quantization with quantized reloading#38032

Merged
vllm-bot merged 8 commits into vllm-project:main from
neuralmagic:kylesayrs/online-quant-integration
Mar 27, 2026

Conversation

@kylesayrs
Contributor

@kylesayrs kylesayrs commented Mar 24, 2026

Purpose

  • Support online quantized reloading
  • Reuse layerwise reloading logic for online quantization
    • Only convert fp8 for now, support mxfp8 in a follow-up

Lifecycle

Online quantization follows a similar lifecycle to reloading:

| Step | Quantized Reload | Online Quantization |
|---|---|---|
| `record_metadata_for_reloading` | Record tensor metadata so that layers can be restored on the meta device | Called but not used |
| `restore_layer_on_meta` | Restore layer to model format at start of reload | Not called |
| `initialize_online_processing` | Wrap weight loaders with `online_process_loader`, which buffers weights until all layer weights are ready | Called by layers with online quantization |
| `_layerwise_process` | Process weights once all weights are loaded | Called by layers with online quantization to quantize weights |
| Copy into kernel tensors | Copy processed weights into original tensor locations to affect compiled CUDA graphs, etc. | Skipped |
| `finalize_layerwise_reload` | Catch any layers which did not load all weights (due to attention or padding) | Called after loading; supports dummy format |
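The "buffer until all layer weights are ready, then process once" pattern that `online_process_loader` follows can be sketched in a few lines. This is an illustrative sketch, not vLLM's actual implementation; all names here are hypothetical.

```python
# Hypothetical sketch of the buffering-loader pattern: each incoming
# weight is buffered rather than written immediately, and the layerwise
# processing step runs exactly once, when the last expected weight lands.

class OnlineProcessLoader:
    def __init__(self, expected_weights, process_fn):
        self.expected = set(expected_weights)  # weight names the layer needs
        self.buffer = {}                       # name -> loaded value
        self.process_fn = process_fn           # e.g. a quantize-to-fp8 step
        self.processed = False

    def load(self, name, value):
        # Buffer the incoming weight instead of copying it into the layer.
        self.buffer[name] = value
        # Once every expected weight has arrived, process the whole layer.
        if not self.processed and self.expected <= self.buffer.keys():
            self.process_fn(self.buffer)
            self.processed = True


def quantize_layer(weights):
    # Stand-in for the layerwise quantization step.
    for k in weights:
        weights[k] = ("fp8", weights[k])


loader = OnlineProcessLoader({"w", "w_scale"}, quantize_layer)
loader.load("w", 1.0)
assert not loader.processed          # still waiting on w_scale
loader.load("w_scale", 2.0)
assert loader.processed              # all weights arrived -> quantized
```

The real loader operates on tensors and weight-loader callables, but the control flow mirrors the `initialize_online_processing` → `_layerwise_process` rows in the table above.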

Changes

  • Break out initialize_online_processing for use by online quantization methods
    • Online quantization methods instantiate the full precision weight on the meta device, then call initialize_online_processing on the layer to add layerwise wrappers
    • By reusing the layerwise online_process_loader, we can ensure that weight loaders are never wrapped twice during online quantized reloading
  • Skip kernel copying and replacement step when online quantizing (not reloading)
  • Call finalize_layerwise_process in BaseModelLoader.load_weights. In the future, this function could fully replace process_weights_after_loading
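The "never wrapped twice" guarantee mentioned above can be enforced with an idempotent wrapper. A minimal sketch, with hypothetical names (the real code reuses the layerwise `online_process_loader` rather than a sentinel like this):

```python
# Hypothetical sketch: mark wrapped loaders with a sentinel attribute so
# a second wrap during online quantized reloading is a no-op.

def wrap_loader_once(loader):
    if getattr(loader, "_online_wrapped", False):
        return loader  # already wrapped; reuse as-is

    def wrapped(*args, **kwargs):
        # ... buffering / layerwise processing would happen here ...
        return loader(*args, **kwargs)

    wrapped._online_wrapped = True
    return wrapped


def base_loader(x):
    return x


w1 = wrap_loader_once(base_loader)
w2 = wrap_loader_once(w1)   # second wrap returns the same wrapper
assert w1 is w2
```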

Edge case handling

  • Some models have loaded parameters which are not created by the quantization method, for example e_score_correction_bias. Ideally, these parameters would be initialized on the meta device just like the weights. However, quant_method has no control over these parameters, so we end up loading them twice, leading to a small performance loss

Testing

  • Regression tested test_online_quantization and test_reload_weights
  • Test online quantized reloading in test_online_quantize_reload

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the online FP8 quantization loading mechanism by extracting initialization logic into a new initialize_online_processing function and adjusting how process_weights_after_loading is handled within the model loading pipeline. The LayerReloadingInfo dataclass is updated to make kernel_tensors optional, and related code is adjusted to handle this change. A critical issue was identified where torch.empty_like with a meta device tensor could lead to a NotImplementedError during weight initialization, requiring explicit device specification. Additionally, a commented-out line in base_loader.py should be removed for clarity.
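The `empty_like` issue flagged here can be illustrated directly: cloning the shape of a meta-device parameter keeps the result on the meta device unless a real device is passed explicitly. A small sketch of the behavior (using `cpu` in place of an accelerator):

```python
import torch

# A parameter instantiated on the meta device (shape only, no storage).
meta_param = torch.empty(4, device="meta")

# empty_like inherits the meta device, so the result has no real storage
# and data-dependent ops on it can fail during weight initialization.
same_device = torch.empty_like(meta_param)
assert same_device.device.type == "meta"

# Passing device= explicitly materializes the tensor with real storage.
real = torch.empty_like(meta_param, device="cpu")
assert real.device.type == "cpu"
real.fill_(0.0)  # works: the tensor is backed by real memory
```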

@kylesayrs
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the online quantization loading mechanism by introducing initialize_online_processing and finalize_layerwise_process functions. These changes streamline how weights are loaded and processed for online quantization, particularly for FP8 and MXFP8 methods, and include updates to handle meta devices and weight loaders. A new test test_online_quantize_reload has been added to cover these functionalities. However, there are two critical issues identified: an incorrect method call in finalize_layerwise_process where layer.process_weights_after_loading is called with an incorrect argument and on the wrong object, and the removal of a device type check in CopyCounter and an assertion in get_numel_loaded, which could lead to incorrect tracking of loaded elements and premature weight processing.

@kylesayrs kylesayrs force-pushed the kylesayrs/online-quant-integration branch 2 times, most recently from 1564b60 to 15db502 Compare March 25, 2026 03:55
@kylesayrs kylesayrs marked this pull request as ready for review March 25, 2026 03:59
@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed quantization labels Mar 25, 2026
```python
    tensor_parallel_size=tp_size,
    enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
    enable_prefix_caching=False,
) as llm:
```
Contributor


Can we test behavior after the model is loaded and before the first reload happens?

Contributor Author


One of the tested models, inference-optimization/DeepSeek-V3-debug-empty, is empty, so it won't produce sane results.

Rather than add a kluge, I'd rather rely on the specific online quantization tests, i.e. tests/quantization/test_fp8.py::test_online_quantization. Do you think this is enough coverage?

Contributor


Could we just skip this for DeepSeek-V3-debug-empty?

I do think we should test for sane output after online quant and before the first reloading call; test_fp8.py does not test anything related to layerwise reloading.


```diff
 # kernel format (device)
-kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
+kernel_tensors: LayerTensors | None = None
```
Contributor


Nit: can we make the comment explain when this is None vs. specified?

Contributor

@vkuzo vkuzo left a comment


thank you! my comments are optional nits, lg if tests pass

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@kylesayrs kylesayrs force-pushed the kylesayrs/online-quant-integration branch from 4dd2c5b to acf0959 Compare March 25, 2026 21:30
@kylesayrs kylesayrs changed the title Compose online quantization with quantized reloading [QeRL] Compose online quantization with quantized reloading Mar 25, 2026
```python
    logger.info("vLLM model structure:\n%s", format_model_inspection(model))


def _has_online_quant(model: nn.Module):
```
Contributor Author


This kluge will be removed by neuralmagic#153

Member

@mgoin mgoin left a comment


Changes look nice to me, and I'm fairly certain this can't affect the non-online path so I feel good about the spurious failures.

@vllm-bot vllm-bot merged commit 648edcf into vllm-project:main Mar 27, 2026
58 of 63 checks passed
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
```python
        vllm_config=vllm_config, model_config=model_config, prefix=prefix
    )

    with set_default_torch_dtype(model_config.dtype), target_device:
```
Collaborator


I feel this may not be safe, and I suspect it causes two Language Models Tests (Extra Standard) cases to fail.
See #38426

Contributor Author


@jikunshang I agree, although I'm not 100% sure why. I'm still trying to figure out why, and how to have both.


Labels

quantization · ready (ONLY add when PR is ready to merge/full CI is needed)
