[CI] Fix online FP8 quantization materializing tensors on CPU #38456
haosdent wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request modifies the model loading logic to support explicit device targeting during tensor materialization. By updating materialize_meta_tensor and materialize_layer to accept a device parameter, the code now ensures tensors are created on the correct platform device. The review feedback identifies the use of a private PyTorch API for type hinting and recommends using public types to ensure long-term stability.
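The review's suggestion can be illustrated with a small sketch (the specific private alias flagged in the review is not quoted here; `torch.device` and `str` are the public types commonly used instead):

```python
from typing import Union

import torch

# Public, stable type hint for a device argument, avoiding private
# PyTorch type aliases.
DeviceLike = Union[torch.device, str]

def to_device_arg(device: DeviceLike) -> torch.device:
    # Normalize either form to a torch.device using only public API.
    return torch.device(device)

print(to_device_arg("cpu"))  # prints "cpu"
```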
@jikunshang Can you take a look when you are available? Many thanks!
Ah, I didn't notice that #38442 could address the CI failure mentioned in this PR, but I think #38442 is indeed better since it addresses the issue systematically.
After vllm-project#38426 narrowed the `with target_device:` context to only wrap `initialize_model()`, code that relied on the ambient device context during `load_weights()` started creating tensors on CPU instead of GPU. This fixes three locations:

1. `materialize_meta_tensor()` / `materialize_layer()` — accept an explicit `device` parameter instead of relying on the ambient `torch.device` context.
2. `DummyModelLoader.load_weights()` — passes `device=current_platform.device_type` when materializing meta tensors.
3. `Fp8OnlineMoEMethod.process_weights_after_loading()` — the `torch.ones` calls for `w13_scale` / `w2_scale` now specify `device=layer.w13_weight.device` so the scale tensors land on the same GPU as the already-materialized weights.

Signed-off-by: haosdent <haosdent@gmail.com>
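The fix patterns above can be sketched as follows. The function name `materialize_meta_tensor` matches the PR; its exact signature, and the tensor shapes used here, are illustrative assumptions, and the sketch uses `cpu` so it runs anywhere (the PR passes the platform device, e.g. `cuda`):

```python
import torch

def materialize_meta_tensor(meta_tensor: torch.Tensor,
                            device: torch.device) -> torch.Tensor:
    # Allocate real storage for a meta tensor on an explicit device,
    # instead of relying on an ambient `with torch.device(...)` context
    # that no longer wraps load_weights().
    return torch.empty_strided(
        meta_tensor.size(),
        meta_tensor.stride(),
        dtype=meta_tensor.dtype,
        device=device,
    )

# Meta tensors carry shape/dtype/strides but no data.
meta = torch.empty(4, 8, device="meta")
out = materialize_meta_tensor(meta, torch.device("cpu"))
print(out.device.type, tuple(out.shape))  # cpu (4, 8)

# The scale-tensor half of the fix: create scales on the same device as
# the already-materialized weight (names mirror the PR; shapes are made up).
w13_weight = torch.randn(2, 16, 32)  # stands in for layer.w13_weight
w13_scale = torch.ones(w13_weight.shape[0], dtype=torch.float32,
                       device=w13_weight.device)
assert w13_scale.device == w13_weight.device
```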
Thanks for the quick fix! #38442 is approved now. Let's wait for the CI result; hope it can be merged soon.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Address CI failures:
After #38426 moved `load_weights()` outside the `with target_device:` context (to fix OOM), online FP8 quantization broke in two ways:

1. `materialize_meta_tensor()` uses `torch.empty_strided()` without a `device=` arg, relying on the ambient device context. Without it, tensors are created on CPU and `process_weights_after_loading` fails with `NotImplementedError: Could not run '_C::dynamic_scaled_fp8_quant' with arguments from the 'CPU' backend.`
2. `Fp8OnlineMoEMethod.process_weights_after_loading()` creates scale tensors with `torch.ones(...)` without `device=`, causing them to land on CPU and crash the Triton fused MoE kernel with `ValueError: Pointer argument (at 5) cannot be accessed from Triton (cpu tensor?)`.

Test Plan
Test Result
All 8 tests pass: