diff --git a/tests/full_tests/ci_e2e_discoverable_tests.sh b/tests/full_tests/ci_e2e_discoverable_tests.sh
index 693915405..c6a306ee3 100755
--- a/tests/full_tests/ci_e2e_discoverable_tests.sh
+++ b/tests/full_tests/ci_e2e_discoverable_tests.sh
@@ -473,9 +473,8 @@ launch_all_tests() {
     run_tp2_load_generate_test
     run_mla_moe_load_generate_test
     run_granite_inc_load_generate_test
-    # Failed after #32344
-    #run_deepseek_v2_inc_load_generate_test
-    #run_deepseek_v2_inc_dynamic_tp2_load_generate_test
+    run_deepseek_v2_inc_load_generate_test
+    run_deepseek_v2_inc_dynamic_tp2_load_generate_test
     run_qwen3_inc_dynamic_load_generate_test
     run_dsv2_blockfp8_static_scaling_fp8kv_load_generate_test
     run_qwen3_8b_fp8_attn_static_scaling_fp8kv_test
diff --git a/tests/unit_tests/kv_offload/test_offloading_connector.py b/tests/unit_tests/kv_offload/test_offloading_connector.py
index cdfef2e21..b3eb1bc0e 100644
--- a/tests/unit_tests/kv_offload/test_offloading_connector.py
+++ b/tests/unit_tests/kv_offload/test_offloading_connector.py
@@ -42,10 +42,10 @@
     TransferSpec,
 )
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
-from vllm.v1.request import Request
 
 from .utils import (
     EOS_TOKEN_ID,
+    create_request_compatible_with_signature,
     create_model_runner_output,
     create_scheduler,
     create_vllm_config,
@@ -215,14 +215,18 @@ def __init__(self, offloaded_block_size: int, gpu_block_size: int, num_gpu_block
 
     def new_request(self, token_ids: list[int]):
         self.req_id += 1
-        req = Request(
-            request_id=str(self.req_id),
-            prompt_token_ids=token_ids,
-            sampling_params=SamplingParams(max_tokens=1000),
-            pooling_params=None,
-            eos_token_id=EOS_TOKEN_ID,
-            block_hasher=self._block_hasher,
-        )
+        sampling_params = SamplingParams(max_tokens=1000)
+        sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
+
+        request_kwargs: dict[str, Any] = {
+            "request_id": str(self.req_id),
+            "prompt_token_ids": token_ids,
+            "sampling_params": sampling_params,
+            "pooling_params": None,
+            "block_hasher": self._block_hasher,
+        }
+
+        req = create_request_compatible_with_signature(**request_kwargs)
 
         self.scheduler.add_request(req)
diff --git a/tests/unit_tests/kv_offload/utils.py b/tests/unit_tests/kv_offload/utils.py
index 0afdb2503..7db061344 100644
--- a/tests/unit_tests/kv_offload/utils.py
+++ b/tests/unit_tests/kv_offload/utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import tempfile
+import inspect
 from collections import defaultdict
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -43,6 +44,12 @@
 EOS_TOKEN_ID = 50256
 
 
+def create_request_compatible_with_signature(**request_kwargs: Any) -> Request:
+    if "eos_token_id" in inspect.signature(Request).parameters:
+        request_kwargs["eos_token_id"] = EOS_TOKEN_ID
+    return Request(**request_kwargs)
+
+
 def assert_scheduler_empty(scheduler: Scheduler):
     """Confirm the scheduler is "empty" - i.e. no leaks."""
     # Scheduler Metadata.
@@ -197,20 +204,21 @@ def create_request(
     max_tokens = 1 if do_remote_decode else max_tokens
     sampling_params = SamplingParams(max_tokens=max_tokens)
+    sampling_params.update_from_generation_config({}, EOS_TOKEN_ID)
 
     common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
     suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
     prompt_token_ids = common_prefix + suffix
 
-    req = Request(
-        request_id=f"id-{request_id}",
-        prompt_token_ids=prompt_token_ids,
-        sampling_params=sampling_params,
-        pooling_params=None,
-        mm_features=None,
-        eos_token_id=EOS_TOKEN_ID,
-        block_hasher=get_request_block_hasher(block_size, hash_fn),
-    )
+    request_kwargs: dict[str, Any] = {
+        "request_id": f"id-{request_id}",
+        "prompt_token_ids": prompt_token_ids,
+        "sampling_params": sampling_params,
+        "pooling_params": None,
+        "mm_features": None,
+        "block_hasher": get_request_block_hasher(block_size, hash_fn),
+    }
+    req = create_request_compatible_with_signature(**request_kwargs)
     req.kv_transfer_params = kv_transfer_params
     return req
diff --git a/vllm_gaudi/extension/environment.py b/vllm_gaudi/extension/environment.py
index 5b1b544e2..21efcb45a 100644
--- a/vllm_gaudi/extension/environment.py
+++ b/vllm_gaudi/extension/environment.py
@@ -100,7 +100,7 @@ def VllmValue(name, env_var_type, depend=None):
     if depend is not None:
         return Value(name, env_var_type=env_var_type, dependencies=depend)
     global _VLLM_VALUES
-    return Value(name, lambda _: _VLLM_VALUES[name], env_var_type=env_var_type)
+    return Value(name, lambda _: _VLLM_VALUES.get(name), env_var_type=env_var_type)
 
 
 def get_environment():
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index d6cf35632..5b3b2a640 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -17,6 +17,7 @@
 import habana_frameworks.torch.utils.experimental as htexp
 import types
 from vllm.model_executor.layers.fused_moe import FusedMoeWeightScaleSupported
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization import get_quantization_config as vllm_get_quantization_config
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -33,6 +34,13 @@
 MAX_EXPERTS_PER_SLICE = int(os.environ.get("MAX_EXPERTS_PER_SLICE", -1))
 
 
+def _as_activation_str(activation):
+    """Normalize activation to string for HPU custom op."""
+    if isinstance(activation, MoEActivation):
+        return activation.value
+    return activation
+
+
 def get_inc_quant_method(layer):
     return layer
@@ -626,6 +634,7 @@ def __init__(self,
 
     def forward(self, hidden_states, expert_routing_table, router_weights, permuted_weights=True, activation="silu"):
         tokens_num, _ = hidden_states.shape
+        activation = _as_activation_str(activation)
         kwargs = self._get_extra_kwargs(tokens_num)
         # pre-processing for custom op inputs
         experts_range = range(self.num_experts)
@@ -936,23 +945,20 @@ def fp8_perchannel_linear_postprocess_weights(layer):
 
 
 def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
-    weight_scale_name = "weight_scale" if hasattr(layer, "weight_scale") else "weight_scale_inv"
-    weight_scale_inv = getattr(layer, weight_scale_name).data
-    weight_block_size = layer.weight_block_size if hasattr(
-        layer, 'weight_block_size') else layer.quant_config.weight_block_size
-    weight, orig_M, orig_N = pad_block_fp8_weight_naive(layer.weight.data, weight_scale_inv, weight_block_size)
+    weight, orig_M, orig_N = pad_block_fp8_weight_naive(layer.weight.data, layer.weight_scale_inv.data,
+                                                        layer.quant_config.weight_block_size)
     if force_channel_fp8:
         # convert to channel-wise fp8
         weight, weight_scale_inv = dynamic_quant(
             dequant_block_fp8_weight_naive(weight,
-                                           weight_scale_inv.data,
-                                           weight_block_size,
+                                           layer.weight_scale_inv.data,
+                                           layer.quant_config.weight_block_size,
                                            original_M=orig_M,
                                            original_N=orig_N,
                                            do_unpad=True))
         weight_scale_inv = weight_scale_inv.squeeze(-1)
         layer.weight.data.copy_(weight)
-        replace_parameter(layer, weight_scale_name, torch.nn.Parameter(weight_scale_inv, requires_grad=False))
+        layer.weight_scale_inv = torch.nn.Parameter(weight_scale_inv, requires_grad=False)
         htorch.core.mark_step()
         return layer
     else:
@@ -969,35 +975,30 @@ def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
 
 
 def fp8_block_moe_prepare_weights(layer, force_channel_fp8=False):
-    w13_weight_scale_name = "w13_weight_scale" if hasattr(layer, "w13_weight_scale") else "w13_weight_scale_inv"
-    w2_weight_scale_name = "w2_weight_scale" if hasattr(layer, "w2_weight_scale") else "w2_weight_scale_inv"
-    w13_weight_scale_param = getattr(layer, w13_weight_scale_name)
-    w2_weight_scale_param = getattr(layer, w2_weight_scale_name)
-    weight_block_size = layer.weight_block_size if hasattr(
-        layer, 'weight_block_size') else layer.quant_config.weight_block_size
-
     if force_channel_fp8:
         # convert to channel-wise fp8
         w13_weight, w13_weight_scale_inv = dynamic_quant(
-            dequant_block_fp8_weight_naive(layer.w13_weight.data, w13_weight_scale_param.data, weight_block_size))
+            dequant_block_fp8_weight_naive(layer.w13_weight.data, layer.w13_weight_scale_inv.data,
+                                           layer.quant_config.weight_block_size))
         w2_weight, w2_weight_scale_inv = dynamic_quant(
-            dequant_block_fp8_weight_naive(layer.w2_weight.data, w2_weight_scale_param.data, weight_block_size))
+            dequant_block_fp8_weight_naive(layer.w2_weight.data, layer.w2_weight_scale_inv.data,
+                                           layer.quant_config.weight_block_size))
         w13_weight_scale_inv, w2_weight_scale_inv \
             = w13_weight_scale_inv.squeeze(-1), w2_weight_scale_inv.squeeze(-1)
         layer.w13_weight.data.copy_(w13_weight)
         layer.w2_weight.data.copy_(w2_weight)
-        replace_parameter(layer, w13_weight_scale_name, torch.nn.Parameter(w13_weight_scale_inv, requires_grad=False))
-        replace_parameter(layer, w2_weight_scale_name, torch.nn.Parameter(w2_weight_scale_inv, requires_grad=False))
+        layer.w13_weight_scale_inv = torch.nn.Parameter(w13_weight_scale_inv, requires_grad=False)
+        layer.w2_weight_scale_inv = torch.nn.Parameter(w2_weight_scale_inv, requires_grad=False)
         return fp8_channel_moe_prepare_weights(layer)
 
     for index in range(layer.moe_op.num_experts):
         layer.moe_op.w13_list[index].set_weight(layer.w13_weight[index])
-        layer.moe_op.w13_list[index].set_scale_inv_fp8(w13_weight_scale_param[index])
-        layer.moe_op.w13_list[index].set_weight_block_size(weight_block_size)
+        layer.moe_op.w13_list[index].set_scale_inv_fp8(layer.w13_weight_scale_inv[index])
+        layer.moe_op.w13_list[index].set_weight_block_size(layer.quant_config.weight_block_size)
         layer.moe_op.w2_list[index].set_weight(layer.w2_weight[index])
-        layer.moe_op.w2_list[index].set_scale_inv_fp8(w2_weight_scale_param[index])
-        layer.moe_op.w2_list[index].set_weight_block_size(weight_block_size)
+        layer.moe_op.w2_list[index].set_scale_inv_fp8(layer.w2_weight_scale_inv[index])
+        layer.moe_op.w2_list[index].set_weight_block_size(layer.quant_config.weight_block_size)
         htorch.core.mark_step()
     return layer
@@ -1133,6 +1134,7 @@ def forward(
         activation="silu",
     ):
         tokens_num, _ = x.shape
+        activation = _as_activation_str(activation)
         kwargs = self._get_extra_kwargs(tokens_num)
         w13_list = []
         w2_list = []
@@ -1198,6 +1200,7 @@ def forward(
         activation="silu",
    ):
         tokens_num, _ = x.shape
+        activation = _as_activation_str(activation)
         kwargs = self._get_extra_kwargs(tokens_num)
         experts_range = range(self.num_experts)
         w13_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
@@ -1404,6 +1407,7 @@ def forward(
         permuted_weights=True,
         activation="silu",
     ):
+        activation = _as_activation_str(activation)
         w13_list = []
         w2_list = []
         for j in range(self.num_experts):
diff --git a/vllm_gaudi/models/deepseek_ocr.py b/vllm_gaudi/models/deepseek_ocr.py
index b9ce8c351..5ca75ea80 100644
--- a/vllm_gaudi/models/deepseek_ocr.py
+++ b/vllm_gaudi/models/deepseek_ocr.py
@@ -48,6 +48,7 @@ def get_dummy_mm_data(
         seq_len: int,
         mm_counts: Mapping[str, int],
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
+        mm_processor_kwargs: Mapping[str, object] | None = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py
index 337a23ef8..1bddfed51 100644
--- a/vllm_gaudi/ops/hpu_fp8.py
+++ b/vllm_gaudi/ops/hpu_fp8.py
@@ -153,7 +153,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
 
         if layer.moe_config.dp_size > 1 and self.use_dispatch_fn:
-            dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.is_sequence_parallel)
+            dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.moe_config.is_sequence_parallel)
         else:
             dispatch_fn = None
@@ -190,6 +190,7 @@ def apply_monolithic(
         router_logits: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
+        is_sequence_parallel = layer.moe_config.is_sequence_parallel
         input_shape = x.shape
         x = x.view(-1, x.shape[-1])
         if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None:
@@ -209,13 +210,13 @@ def apply_monolithic(
         dp_metadata = get_hpu_dp_metadata()
         if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn):
             hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None
-            x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel)
+            x = dispatch_tensor(x, hidden_states_across_dp, is_sequence_parallel)
             topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None
-            topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel)
+            topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, is_sequence_parallel)
             topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None
-            topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel)
+            topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, is_sequence_parallel)
 
         topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
         topk_weights = topk_weights.view(-1, topk_weights.shape[-1])
diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py
index a47451ad8..bc99cb5c5 100755
--- a/vllm_gaudi/ops/hpu_fused_moe.py
+++ b/vllm_gaudi/ops/hpu_fused_moe.py
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from enum import Enum
 from functools import partial
 from typing import Union
@@ -30,6 +31,10 @@
 from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata
 
 
+def _normalize_moe_activation(activation):
+    return activation.value if isinstance(activation, Enum) else activation
+
+
 @UnquantizedFusedMoEMethod.register_oot
 class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
     """MoE method without quantization."""
@@ -62,7 +67,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
 
         if layer.moe_config.dp_size > 1 and self.use_dispatch_fn:
-            dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.is_sequence_parallel)
+            dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.moe_config.is_sequence_parallel)
         else:
             dispatch_fn = None
@@ -123,7 +128,7 @@ def apply_monolithic(
             topk_ids,
             topk_weights,
             permuted_weights=True,
-            activation=layer.activation,
+            activation=_normalize_moe_activation(layer.activation),
         )
         if layer.moe_config.dp_size > 1:
             return output.view(*(output.size(0), *input_shape[1:]))
@@ -177,7 +182,7 @@ def forward_oot(
                 topk_ids.to(torch.int64),
                 topk_weights.to(x.dtype),
                 permuted_weights=True,
-                activation=layer.activation,
+                activation=_normalize_moe_activation(layer.activation),
             ).view(*input_shape)
 
         output = layer.moe_op(
@@ -185,7 +190,7 @@ def forward_oot(
             topk_ids,
             topk_weights,
             permuted_weights=True,
-            activation=layer.activation,
+            activation=_normalize_moe_activation(layer.activation),
         )
         if layer.moe_config.dp_size > 1:
             return output.view(*(output.size(0), *input_shape[1:]))
diff --git a/vllm_gaudi/platform.py b/vllm_gaudi/platform.py
index 359b65723..169f23c93 100644
--- a/vllm_gaudi/platform.py
+++ b/vllm_gaudi/platform.py
@@ -45,6 +45,7 @@ def get_attn_backend_cls(
         cls,
         selected_backend: "AttentionBackendEnum",
         attn_selector_config: "AttentionSelectorConfig",
+        num_heads: Optional[int] = None,
     ) -> str:
         if attn_selector_config.use_sparse:
             raise NotImplementedError("Sparse Attention is not supported on HPU.")
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 6388f9840..e8b3beaab 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -4436,18 +4436,24 @@ def _remove_duplicate_submodules(self):
                 self._detached_moe_gates.add(id(experts))
 
     def _sync_shared_moe_gates(self):
-        """Re-sync SharedFusedMoE._gate after INC conversion.
+        """Apply SharedFusedMoE post-INC synchronization and compatibility.
 
-        After INC converts/patches the model, the block-level gate
-        (e.g. mlp.gate) is properly patched. This method restores
-        the SharedFusedMoE._gate reference so that the overlapped
-        execution path inside FusedMoE.forward_impl() also uses the
-        patched gate.
-
-        Only experts whose _gate was explicitly detached by
-        _remove_duplicate_submodules are restored; experts whose
-        _gate was originally None are left unchanged.
+        Synchronizes per-layer MoE state after INC conversion, including
+        router handling and compatibility flags expected by INC wrappers.
+        Detached gate tracking is used only as a cleanup aid.
         """
+
+        def _sync_moe_kernel_flags(module: torch.nn.Module):
+            moe_config = getattr(module, "moe_config", None)
+            for name in (
+                    "use_pplx_kernels",
+                    "use_deepep_ht_kernels",
+                    "use_deepep_ll_kernels",
+                    "use_mori_kernels",
+                    "use_fi_all2allv_kernels",
+            ):
+                setattr(module, name, bool(getattr(moe_config, name, False)))
+
         model = self.get_model()
         if not hasattr(model, "model"):
             return
@@ -4457,9 +4463,28 @@ def _sync_shared_moe_gates(self):
                 continue
             block_gate = getattr(mlp, 'gate', None)
             experts = getattr(mlp, 'experts', None)
-            if (block_gate is not None and experts is not None and id(experts) in self._detached_moe_gates):
-                experts._gate = block_gate
-                self._detached_moe_gates.remove(id(experts))
+            if block_gate is not None and experts is not None:
+                _sync_moe_kernel_flags(experts)
+                orig_mod = getattr(experts, "orig_mod", None)
+                if orig_mod is not None:
+                    _sync_moe_kernel_flags(orig_mod)
+
+                # Force external router path: the model's forward checks
+                # experts.is_internal_router to decide the gate path.
+                if isinstance(experts, FusedMoE):
+                    # is_internal_router is a read-only property backed
+                    # by _gate; setting _gate=None makes it return False.
+                    experts._gate = None
+                else:
+                    # INC wrappers (e.g. PatchedMixtralMoE) don't inherit
+                    # the property — set a plain attribute instead.
+                    experts.is_internal_router = False
+                runner = getattr(experts, "runner", None)
+                if runner is not None and hasattr(runner, "gate"):
+                    runner.gate = None
+
+                if id(experts) in self._detached_moe_gates:
+                    self._detached_moe_gates.remove(id(experts))
 
     def _inc_preprocess(self):
         _apply_inc_patch()