[FIX_FOR_VLLM_CUSTOM=ff1f83b056aedcf3e2d978d267011b2b79c08aca] Hourly fixes – batch no. 3 #1053

Merged

iboiko-habana merged 23 commits into vllm-project:main from pawel-olejniczak:dev/polejnix/fix_batch_3 on Mar 6, 2026

Conversation

pawel-olejniczak (Contributor) commented Feb 26, 2026

This PR contains part of the fixes from #903.
Fixed issues:
AttributeError: 'FusedMoE' object has no attribute 'forward_impl'
AttributeError: 'PatchedMixtralMoE' object has no attribute 'is_internal_router'
RuntimeError: Overloaded torch operator invoked from Python failed to match any schema
TypeError: HpuPlatform.get_attn_backend_cls() got an unexpected keyword argument 'num_heads'
TypeError: Request.__init__() got an unexpected keyword argument 'eos_token_id'
KeyError: 'model_type'
AttributeError: 'FusedMoE' object has no attribute 'dp_size'. Did you mean: 'ep_size'?
AttributeError: 'SharedFusedMoE' object has no attribute 'use_dp_chunking'
AttributeError: 'SharedFusedMoE' object has no attribute 'use_pplx_kernels'
AttributeError: 'SharedFusedMoE' object has no attribute 'dp_size'. Did you mean: 'ep_size'?
TypeError: HpuDeepseekOCRDummyInputsBuilder.get_dummy_mm_data() got an unexpected keyword argument 'mm_processor_kwargs'

Signed-off-by: Paweł Olejniczak <polejniczakx@habana.ai>
Copilot AI (Contributor) left a comment
Pull request overview

This PR contains fixes to align the vLLM-Gaudi codebase with upstream vLLM API changes, specifically focusing on MoE (Mixture of Experts) module updates, request signature changes, and attention backend enhancements. The changes are part of batch no. 3 from PR #903 and address multiple upstream PRs related to MultiModalKwargsItem and other architectural updates.

Changes:

  • Updated the MoE parallel configuration access pattern to use layer.moe_parallel_config instead of direct layer attributes (see the sketch after this list)
  • Refactored MoE forward methods to delegate to runner API and added activation normalization for HPU custom ops
  • Updated test utilities to handle Request constructor signature changes with backward compatibility
  • Simplified FP8 weight processing by standardizing on weight_scale_inv attribute naming
  • Enhanced MoE gate synchronization logic after INC conversion with kernel flag syncing
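
For illustration, a minimal sketch of the access-pattern change in the first bullet, assuming the upstream FusedMoE layer now exposes dp_size and is_sequence_parallel through a moe_parallel_config object (only those attribute names come from this PR's summary; the helper and surrounding code are illustrative, not the actual vllm_gaudi implementation):

    # Hypothetical helper; only moe_parallel_config / dp_size /
    # is_sequence_parallel are taken from the PR summary.
    def moe_uses_data_parallel(layer) -> bool:
        cfg = layer.moe_parallel_config  # new: nested parallel config
        # old (no longer available directly on the layer):
        #   layer.dp_size, layer.is_sequence_parallel
        return cfg.dp_size > 1 and not cfg.is_sequence_parallel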

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.

Summary per file:
vllm_gaudi/v1/worker/hpu_model_runner.py: Refactored _sync_shared_moe_gates to add kernel flag synchronization and force the external router path for INC-wrapped MoE modules
vllm_gaudi/platform.py: Added an optional num_heads parameter to get_attn_backend_cls for upstream API compatibility (see the sketch after this list)
vllm_gaudi/ops/hpu_fused_moe.py: Updated MoE parallel config access, added activation normalization, and delegated forward logic to runner.forward
vllm_gaudi/ops/hpu_fp8.py: Migrated dp_size and is_sequence_parallel access to moe_parallel_config
vllm_gaudi/extension/ops.py: Added _as_activation_str helper, standardized on the weight_scale_inv attribute, and simplified FP8 weight processing
vllm_gaudi/extension/environment.py: Changed to use .get() for safer dictionary access in VllmValue
tests/unit_tests/ops/test_hpu_fused_moe.py: Updated mock context and changed from forward_impl to forward_native
tests/unit_tests/ops/test_hpu_compressed_tensors.py: Updated mock context and changed from forward_impl to forward_native
tests/unit_tests/kv_offload/utils.py: Added create_request_compatible_with_signature helper for Request constructor compatibility
tests/unit_tests/kv_offload/test_offloading_connector.py: Updated to use the new Request creation pattern with sampling params
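
For the vllm_gaudi/platform.py row above, a minimal sketch of a backwards-compatible signature extension; apart from num_heads, every parameter and the return value below are placeholders, not the real HpuPlatform API:

    # Sketch only: accept the new upstream keyword without requiring it.
    def get_attn_backend_cls(selected_backend=None, num_heads=None, **kwargs):
        # Older upstream callers omit num_heads; newer ones pass it. Declaring
        # it as an optional keyword keeps both call sites working.
        del selected_backend, num_heads, kwargs  # unused in this sketch
        return "hpu_attention_backend"  # placeholder for the backend class path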


Comment thread on vllm_gaudi/ops/hpu_fused_moe.py (outdated)
@pawel-olejniczak pawel-olejniczak changed the title [FIX_FOR_VLLM_CUSTOM=ff1f83b056aedcf3e2d978d267011b2b79c08aca] Hourly fixes – batch no. 3 [FIX_FOR_VLLM_CUSTOM=83b47f67b1dfad505606070ae4d9f83e50ad4ebd] Hourly fixes – batch no. 3 Mar 2, 2026
@pawel-olejniczak pawel-olejniczak force-pushed the dev/polejnix/fix_batch_3 branch from 5cff2c7 to 5876268 Compare March 2, 2026 13:02
This reverts commit 5876268.

Signed-off-by: Paweł Olejniczak <polejniczakx@habana.ai>
@pawel-olejniczak pawel-olejniczak changed the title [FIX_FOR_VLLM_CUSTOM=83b47f67b1dfad505606070ae4d9f83e50ad4ebd] Hourly fixes – batch no. 3 [FIX_FOR_VLLM_CUSTOM=ff1f83b056aedcf3e2d978d267011b2b79c08aca] Hourly fixes – batch no. 3 Mar 3, 2026
github-actions bot commented Mar 3, 2026

🚧 CI Blocked

The main CI workflow was not started for the following reason:

Your branch is behind the base branch. Please merge or rebase to get the latest changes.

@pawel-olejniczak pawel-olejniczak force-pushed the dev/polejnix/fix_batch_3 branch from b48ab35 to 55bb534 Compare March 5, 2026 08:03
This reverts commit a4009d6.

Signed-off-by: Paweł Olejniczak <polejniczakx@habana.ai>
Copilot AI (Contributor) left a comment
Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.



Comment on lines 947 to +949

 def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
-    weight_scale_name = "weight_scale" if hasattr(layer, "weight_scale") else "weight_scale_inv"
-    weight_scale_inv = getattr(layer, weight_scale_name).data
-    weight_block_size = layer.weight_block_size if hasattr(
-        layer, 'weight_block_size') else layer.quant_config.weight_block_size
-    weight, orig_M, orig_N = pad_block_fp8_weight_naive(layer.weight.data, weight_scale_inv, weight_block_size)
+    weight, orig_M, orig_N = pad_block_fp8_weight_naive(layer.weight.data, layer.weight_scale_inv.data,
+                                                        layer.quant_config.weight_block_size)
Copilot AI commented Mar 5, 2026
fp8_block_linear_postprocess_weights now unconditionally reads layer.weight_scale_inv, but block FP8 compressed-tensors layers register weight_scale (see vllm_gaudi/ops/hpu_compressed_tensors.py where BlockQuantScaleParameter is registered as weight_scale). This will raise AttributeError for the compressed-tensors BLOCK path. Consider keeping backwards/variant compatibility by selecting the available scale attribute (e.g., weight_scale_inv if present else weight_scale) and using that consistently for padding/dequant + parameter replacement.
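
A hedged sketch of the fallback this comment suggests; the helper name is illustrative, and the surrounding function is only hinted at in comments, not reproduced:

    # Pick whichever block-scale attribute the layer actually registered.
    def _select_block_scale_name(layer) -> str:
        return "weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"

    # usage inside fp8_block_linear_postprocess_weights (sketch):
    #   scale_name = _select_block_scale_name(layer)
    #   weight_scale_inv = getattr(layer, scale_name).data
    #   ...
    #   replace_parameter(layer, scale_name,
    #                     torch.nn.Parameter(new_scale, requires_grad=False))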

Comment on lines 977 to +1001

 def fp8_block_moe_prepare_weights(layer, force_channel_fp8=False):
-    w13_weight_scale_name = "w13_weight_scale" if hasattr(layer, "w13_weight_scale") else "w13_weight_scale_inv"
-    w2_weight_scale_name = "w2_weight_scale" if hasattr(layer, "w2_weight_scale") else "w2_weight_scale_inv"
-    w13_weight_scale_param = getattr(layer, w13_weight_scale_name)
-    w2_weight_scale_param = getattr(layer, w2_weight_scale_name)
-    weight_block_size = layer.weight_block_size if hasattr(
-        layer, 'weight_block_size') else layer.quant_config.weight_block_size
-
     if force_channel_fp8:
         # convert to channel-wise fp8
         w13_weight, w13_weight_scale_inv = dynamic_quant(
-            dequant_block_fp8_weight_naive(layer.w13_weight.data, w13_weight_scale_param.data, weight_block_size))
+            dequant_block_fp8_weight_naive(layer.w13_weight.data, layer.w13_weight_scale_inv.data,
+                                           layer.quant_config.weight_block_size))
         w2_weight, w2_weight_scale_inv = dynamic_quant(
-            dequant_block_fp8_weight_naive(layer.w2_weight.data, w2_weight_scale_param.data, weight_block_size))
+            dequant_block_fp8_weight_naive(layer.w2_weight.data, layer.w2_weight_scale_inv.data,
+                                           layer.quant_config.weight_block_size))
         w13_weight_scale_inv, w2_weight_scale_inv \
             = w13_weight_scale_inv.squeeze(-1), w2_weight_scale_inv.squeeze(-1)
         layer.w13_weight.data.copy_(w13_weight)
         layer.w2_weight.data.copy_(w2_weight)
-        replace_parameter(layer, w13_weight_scale_name, torch.nn.Parameter(w13_weight_scale_inv, requires_grad=False))
-        replace_parameter(layer, w2_weight_scale_name, torch.nn.Parameter(w2_weight_scale_inv, requires_grad=False))
+        layer.w13_weight_scale_inv = torch.nn.Parameter(w13_weight_scale_inv, requires_grad=False)
+        layer.w2_weight_scale_inv = torch.nn.Parameter(w2_weight_scale_inv, requires_grad=False)
         return fp8_channel_moe_prepare_weights(layer)

     for index in range(layer.moe_op.num_experts):
         layer.moe_op.w13_list[index].set_weight(layer.w13_weight[index])
-        layer.moe_op.w13_list[index].set_scale_inv_fp8(w13_weight_scale_param[index])
-        layer.moe_op.w13_list[index].set_weight_block_size(weight_block_size)
+        layer.moe_op.w13_list[index].set_scale_inv_fp8(layer.w13_weight_scale_inv[index])
+        layer.moe_op.w13_list[index].set_weight_block_size(layer.quant_config.weight_block_size)

         layer.moe_op.w2_list[index].set_weight(layer.w2_weight[index])
-        layer.moe_op.w2_list[index].set_scale_inv_fp8(w2_weight_scale_param[index])
-        layer.moe_op.w2_list[index].set_weight_block_size(weight_block_size)
+        layer.moe_op.w2_list[index].set_scale_inv_fp8(layer.w2_weight_scale_inv[index])
+        layer.moe_op.w2_list[index].set_weight_block_size(layer.quant_config.weight_block_size)
Copilot AI commented Mar 5, 2026

fp8_block_moe_prepare_weights now assumes layer.w13_weight_scale_inv / layer.w2_weight_scale_inv exist, but the compressed-tensors FP8 MoE path uses w13_weight_scale / w2_weight_scale (see HPUCompressedTensorsW8A8Fp8MoEMethod.process_weights_after_loading calling this helper after setting w*_weight_scale). This will break BLOCK compressed-tensors MoE with an AttributeError. Please add handling for both attribute names (or normalize/alias them before using them here).
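
A hedged sketch of the normalization this comment suggests, aliasing the compressed-tensors scale names to the *_inv names before they are used; the function name and its placement are illustrative only:

    # Alias w13_weight_scale / w2_weight_scale to the *_inv names expected here.
    def _normalize_moe_scale_names(layer) -> None:
        for base in ("w13_weight_scale", "w2_weight_scale"):
            inv_name = base + "_inv"
            if not hasattr(layer, inv_name) and hasattr(layer, base):
                setattr(layer, inv_name, getattr(layer, base))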



def create_request_compatible_with_signature(**request_kwargs: Any) -> Request:
    if "eos_token_id" in inspect.signature(Request).parameters:
Copilot AI commented Mar 5, 2026
create_request_compatible_with_signature overwrites any caller-provided eos_token_id when the parameter exists in Request's signature. To avoid surprising behavior in future tests, only set eos_token_id if it is supported and not already present in request_kwargs.

Suggested change:
-    if "eos_token_id" in inspect.signature(Request).parameters:
+    if ("eos_token_id" in inspect.signature(Request).parameters
+            and "eos_token_id" not in request_kwargs):

github-actions bot commented Mar 5, 2026

✅ CI Passed

All checks passed successfully against the following vllm commit:
ff1f83b056aedcf3e2d978d267011b2b79c08aca

@iboiko-habana iboiko-habana merged commit 16b15c7 into vllm-project:main Mar 6, 2026
131 of 143 checks passed
SKRohit pushed a commit to SKRohit/vllm-gaudi that referenced this pull request Mar 12, 2026
… fixes – batch no. 3 (vllm-project#1053)
shepark pushed a commit to libinta/vllm-gaudi that referenced this pull request Mar 19, 2026
… fixes – batch no. 3 (vllm-project#1053)
afierka-intel pushed a commit that referenced this pull request Apr 14, 2026

Rename FP8 blockwise compressed tensors scales to match HPU ops. Fixes a regression in https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 due to #1220 and #1053.

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Copilot AI pushed a commit that referenced this pull request Apr 14, 2026
mgawarkiewicz-intel pushed a commit that referenced this pull request Apr 20, 2026
…or v0.19.0 (#1374)

Renames FP8 blockwise compressed tensors scales to match HPU ops. Fixes a regression in https://huggingface.co/mistralai/Mistral-Large-3-675B-Instruct-2512 due to #1220 and #1053.

Signed-off-by: Kavulya, Soila P <soila.p.kavulya@intel.com>
Signed-off-by: Soila Kavulya <soila.p.kavulya@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>