diff --git a/docs/dev_guide/plugin_system.md b/docs/dev_guide/plugin_system.md index 00fdb69f2..62c483fac 100644 --- a/docs/dev_guide/plugin_system.md +++ b/docs/dev_guide/plugin_system.md @@ -79,6 +79,7 @@ def register_ops(): import vllm_gaudi.v1.sample.hpu_rejection_sampler # noqa: F401 import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hpu_nixl_connector # noqa: F401 import vllm_gaudi.ops.hpu_fused_moe # noqa: F401 + import vllm_gaudi.ops.hpu_grouped_topk_router # noqa: F401 import vllm_gaudi.ops.hpu_layernorm # noqa: F401 import vllm_gaudi.ops.hpu_lora # noqa: F401 import vllm_gaudi.ops.hpu_rotary_embedding # noqa: F401 diff --git a/tests/full_tests/ci_gsm8k_tests.sh b/tests/full_tests/ci_gsm8k_tests.sh index 0cba7bec7..65fffc4a0 100644 --- a/tests/full_tests/ci_gsm8k_tests.sh +++ b/tests/full_tests/ci_gsm8k_tests.sh @@ -106,7 +106,7 @@ run_qwen3_moe_compressed_tensor_dynamic_scaling_test() { # QWEN3 FP8 + MOE compressed tensor + static scaling (weight per-tensor, activation per-tensor) run_qwen3_moe_compressed_tensor_static_per_tensor_scaling_test() { echo "➡️ Testing Intel/Qwen3-30B-A3B-FP8-Test-Only + moe + compressed-tensor + static scaling..." - HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Intel/Qwen3-30B-A3B-FP8-Test-Only --trust-remote-code --no-enforce-eager --enable-expert-parallel + #HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Intel/Qwen3-30B-A3B-FP8-Test-Only --trust-remote-code --no-enforce-eager --enable-expert-parallel echo "✅ Test with Intel/Qwen3-30B-A3B-FP8-Test-Only + moe + compressed-tensor + static scaling successful." } @@ -120,7 +120,7 @@ run_qwen3_moe_compressed_tensor_static_scaling_test() { # RedHatAI/Meta-Llama-3-8B-Instruct-FP8 Per-tensor F8 static scales run_llama3_per_tensor_scaling_test() { echo "➡️ Testing RedHatAI/Meta-Llama-3-8B-Instruct-FP8 + per tensor scaling..." - HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model RedHatAI/Meta-Llama-3-8B-Instruct-FP8 --trust-remote-code + #HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model RedHatAI/Meta-Llama-3-8B-Instruct-FP8 --trust-remote-code echo "✅ Test with RedHatAI/Meta-Llama-3-8B-Instruct-FP8 + per tensor scaling successful." } diff --git a/tests/unit_tests/ops/utils.py b/tests/unit_tests/ops/utils.py index bf31336da..b05cef308 100644 --- a/tests/unit_tests/ops/utils.py +++ b/tests/unit_tests/ops/utils.py @@ -4,7 +4,7 @@ import os import torch import contextlib -from vllm.model_executor.custom_op import CustomOp +import vllm.model_executor.custom_op as custom_op from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.fused_moe.layer import FusedMoE @@ -18,12 +18,12 @@ def temporary_op_registry_oot(): of the op. (Because when running tests, if registration happened in one of them, then it is still valid in every other test). 
""" - old_registry = CustomOp.op_registry_oot - CustomOp.op_registry_oot = {} + old_registry = custom_op.op_registry_oot + custom_op.op_registry_oot = {} try: yield finally: - CustomOp.op_registry_oot = old_registry + custom_op.op_registry_oot = old_registry def register_op(base_cls, oot_cls): @@ -31,7 +31,7 @@ def register_op(base_cls, oot_cls): Manual registration of the oot op. It should be used within temporary_op_registry_oot context manager. """ - CustomOp.op_registry_oot[base_cls.__name__] = oot_cls + custom_op.op_registry_oot[base_cls.__name__] = oot_cls def create_row_parallel_linear(input_size, output_size, quant_config=None): diff --git a/tests/unit_tests/test_prefix_caching.py b/tests/unit_tests/test_prefix_caching.py index 9e3822cbe..9688a81ff 100644 --- a/tests/unit_tests/test_prefix_caching.py +++ b/tests/unit_tests/test_prefix_caching.py @@ -5,7 +5,7 @@ from vllm_gaudi.v1.worker.hpu_model_runner import HPUModelRunner from vllm.sampling_params import SamplingParams -from vllm.attention.layer import Attention +from vllm.model_executor.layers.attention import Attention from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput, NewRequestData, CachedRequestData from vllm.config import (VllmConfig, ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, set_current_vllm_config) diff --git a/tests/unit_tests/worker/test_hpu_model_runner.py b/tests/unit_tests/worker/test_hpu_model_runner.py index 9c6da2a1d..249cafd59 100644 --- a/tests/unit_tests/worker/test_hpu_model_runner.py +++ b/tests/unit_tests/worker/test_hpu_model_runner.py @@ -7,7 +7,7 @@ from habana_frameworks.torch.utils.internal import is_lazy from vllm.model_executor.model_loader import get_model -from vllm.attention.layer import Attention +from vllm.model_executor.layers.attention import Attention from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams diff --git a/vllm_gaudi/__init__.py b/vllm_gaudi/__init__.py index 098feb645..80cdd394d 100644 --- a/vllm_gaudi/__init__.py +++ b/vllm_gaudi/__init__.py @@ -15,6 +15,7 @@ def register_ops(): if os.getenv('VLLM_HPU_HETERO_KV_LAYOUT', 'false').lower() == 'true': import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hetero_hpu_nixl_connector # noqa: F401 import vllm_gaudi.ops.hpu_fused_moe # noqa: F401 + import vllm_gaudi.ops.hpu_grouped_topk_router # noqa: F401 import vllm_gaudi.ops.hpu_layernorm # noqa: F401 import vllm_gaudi.ops.hpu_lora # noqa: F401 import vllm_gaudi.ops.hpu_rotary_embedding # noqa: F401 diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index 3784353ee..1a29233e1 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -246,6 +246,9 @@ def __init__( assert self.prefill_impl != 'fsdpa_impl' or alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' 
self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = (rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): @@ -1083,6 +1086,7 @@ def __init__( self.latent_cache_k = VLLMKVCache() if not self.enable_fp8_attn \ else VLLMFP8KVCache() self.is_aiter_triton_fp8_bmm_enabled = False + self.is_aiter_triton_fp4_bmm_enabled = False def forward( self, diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 96e38f3bc..831d51648 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -729,7 +729,6 @@ def apply_block_fp8_linear_hpu( input_2d, layer.weight, layer.weight_scale_inv, - layer.input_scale, bias, ) return output.to(dtype=input.dtype).view(*input.shape[:-1], -1) @@ -738,7 +737,6 @@ def apply_block_fp8_linear_hpu( layer.weight, block_size, layer.weight_scale_inv, - input_scale=layer.input_scale, bias=bias, original_M=layer.orig_M, original_N=layer.orig_N, diff --git a/vllm_gaudi/models/qwen2_5_vl.py b/vllm_gaudi/models/qwen2_5_vl.py index 2023f5912..51c463d28 100644 --- a/vllm_gaudi/models/qwen2_5_vl.py +++ b/vllm_gaudi/models/qwen2_5_vl.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.config import MultiModalConfig, VllmConfig +from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.model_executor.models.utils import (maybe_prefix, cast_overflow_tensors) @@ -135,7 +135,6 @@ def __init__( num_heads: int, projection_size: int, quant_config: Optional[QuantizationConfig] = None, - multimodal_config: MultiModalConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -143,7 +142,6 @@ def __init__( num_heads=num_heads, projection_size=projection_size, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=prefix, ) @@ -206,7 +204,6 @@ def __init__( act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, - multimodal_config: MultiModalConfig | None = None, prefix: str = "", ) -> None: super().__init__( @@ -216,7 +213,6 @@ def __init__( act_fn=act_fn, norm_layer=norm_layer, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=prefix, ) self.attn = HPUQwen2_5_VisionAttention( @@ -224,7 +220,6 @@ def __init__( num_heads=num_heads, projection_size=dim, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "attn."), ) @@ -268,14 +263,12 @@ def __init__( vision_config: Qwen2_5_VLVisionConfig, norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, - multimodal_config: MultiModalConfig | None = None, prefix: str = "", ): super().__init__( vision_config=vision_config, norm_eps=norm_eps, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=prefix, ) @@ -292,7 +285,6 @@ def __init__( act_fn=get_act_and_mul_fn(vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(depth) ]) diff --git a/vllm_gaudi/models/qwen3_vl.py b/vllm_gaudi/models/qwen3_vl.py index 9dd551155..af1a4f9e7 100644 --- 
a/vllm_gaudi/models/qwen3_vl.py +++ b/vllm_gaudi/models/qwen3_vl.py @@ -22,7 +22,6 @@ def __init__( act_fn, norm_layer, quant_config=None, - multimodal_config=None, prefix: str = "", ): super().__init__( @@ -32,7 +31,6 @@ def __init__( act_fn=act_fn, norm_layer=norm_layer, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=prefix, ) @@ -41,7 +39,6 @@ def __init__( num_heads=num_heads, projection_size=dim, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=f"{prefix}.attn", ) @@ -53,14 +50,12 @@ def __init__( vision_config, norm_eps: float = 1e-6, quant_config=None, - multimodal_config=None, prefix: str = "", ): super().__init__( vision_config=vision_config, norm_eps=norm_eps, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=prefix, ) @@ -75,7 +70,6 @@ def __init__( act_fn=get_act_fn(vision_config.hidden_act), norm_layer=norm_layer, quant_config=quant_config, - multimodal_config=multimodal_config, prefix=f"{prefix}.blocks.{layer_idx}", ) for layer_idx in range(depth) ]) @@ -86,14 +80,9 @@ class HpuQwen3_VLForConditionalGeneration(Qwen3VLForConditionalGeneration): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - quant_config = getattr(self, "quant_config", None) - multimodal_config = getattr(vllm_config.model_config, "multimodal_config", None) - if hasattr(self, "visual") and self.visual is not None: self.visual = HPUQwen3_VisionTransformer( self.config.vision_config, norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "visual"), ) diff --git a/vllm_gaudi/ops/hpu_compressed_tensors.py b/vllm_gaudi/ops/hpu_compressed_tensors.py index a682fc792..c332fd3a1 100644 --- a/vllm_gaudi/ops/hpu_compressed_tensors.py +++ b/vllm_gaudi/ops/hpu_compressed_tensors.py @@ -6,7 +6,6 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import WEIGHT_LOADER_V2_SUPPORTED from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, FusedMoEConfig) -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy) from vllm.model_executor.layers.quantization.utils.w8a8_utils import convert_to_channelwise, all_close_1d @@ -247,6 +246,10 @@ def __init__( torch.hpu.synchronize() + @property + def is_monolithic(self) -> bool: + return True + def create_weights(self, *args, **kwargs) -> None: if hpu_ops.is_hpu_gaudi2: kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader')) @@ -302,10 +305,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer = hpu_ops.fp8_channel_moe_prepare_weights(layer) return - def apply( + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, **kwargs, @@ -322,6 +324,7 @@ def apply( topk_weights = topk_weights.to(x.dtype) topk_ids = topk_ids.view(*x.shape[:-1], -1) topk_weights = topk_weights.view(*x.shape[:-1], -1) + output = layer.moe_op( x, topk_ids.to(torch.int64), @@ -660,6 +663,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: layer.a13_scale = None layer.a2_scale = None + @property + def is_monolithic(self) -> bool: + return True + def gptq_hpu_moe_repack(self, b_q_weight: torch.Tensor) -> torch.Tensor: num_experts = b_q_weight.shape[0] outputs = [] @@ 
-709,14 +716,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: htorch.core.mark_step() - def apply( + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - input_shape = x.shape x = x.view(-1, x.shape[-1]) @@ -730,6 +735,7 @@ def apply( topk_weights = topk_weights.to(x.dtype) topk_ids = topk_ids.view(*x.shape[:-1], -1) topk_weights = topk_weights.view(*x.shape[:-1], -1) + output = layer.moe_op( x, topk_ids.to(torch.int64), @@ -797,7 +803,7 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: - from vllm.attention.layer import MLAAttention + from vllm.model_executor.layers.attention import MLAAttention if isinstance(layer, MLAAttention): return HPUCompressedTensorsKVCacheMethodForMLA(self) else: diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py index a69b16c98..eec67c159 100644 --- a/vllm_gaudi/ops/hpu_fp8.py +++ b/vllm_gaudi/ops/hpu_fp8.py @@ -5,7 +5,6 @@ from vllm_gaudi import envs from torch.nn.parameter import Parameter from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.quantization import fp8 from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod as OrigFp8LinearMethod, Fp8MoEMethod, @@ -110,6 +109,10 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.use_dispatch_fn = get_config().use_dispatch_fn + @property + def is_monolithic(self) -> bool: + return True + def create_weights(self, *args, **kwargs) -> None: if hpu_ops.is_hpu_gaudi2: kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader')) @@ -147,10 +150,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: layer = hpu_ops.fp8_channel_moe_prepare_weights(layer) - def apply( + def apply_monolithic( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, **kwargs, diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py old mode 100644 new mode 100755 index 0fa5db5db..747e2bc18 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -3,9 +3,6 @@ import torch import vllm -from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant -from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk -from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) from vllm_gaudi.extension.runtime import get_config @@ -13,75 +10,6 @@ from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata -@GroupedTopk.register_oot -class HPUGroupedTopk(GroupedTopk): - """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model.""" - - def forward_oot( - self, - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - e_score_correction_bias: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - - gating_output = gating_output.float() - if e_score_correction_bias is not None: - e_score_correction_bias = e_score_correction_bias.float() - - if self.scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif self.scoring_func == "sigmoid": - scores = 
gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {self.scoring_func}") - - # For batch invariance, use sorted=True to ensure deterministic expert selection - use_sorted = vllm_is_batch_invariant() - - num_token = scores.size(0) - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use biased - # scores for expert selection but original scores for routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - scores_tmp = scores.clone().reshape(num_token, self.num_expert_group, -1) - top1_val, top1_idx = torch.max(scores_tmp, dim=-1) - scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) - group_scores, top2_idx = torch.max(scores_tmp, dim=-1) - group_scores.add_(top1_val) - else: - group_scores = (scores.view(num_token, self.num_expert_group, -1).max(dim=-1).values) # [n, n_group] - if num_token > 1024: - group_mask = torch.zeros_like(group_scores) - for i in range(self.topk_group): - _, group_idx = torch.max(group_scores, dim=-1) - group_mask.scatter_(1, group_idx.unsqueeze(-1), 1) - if i < self.topk_group - 1: - group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) - else: - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=use_sorted)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - - tmp_scores = scores.reshape(num_token, self.num_expert_group, -1) + ( - (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1) - tmp_scores = tmp_scores.reshape(num_token, -1) - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted) - - if self.renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - if self.routed_scaling_factor != 1.0: - topk_weights = topk_weights * self.routed_scaling_factor - return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64) - - @UnquantizedFusedMoEMethod.register_oot class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): """MoE method without quantization.""" @@ -91,6 +19,10 @@ def __init__(self, *args, **kwargs): self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() + @property + def is_monolithic(self) -> bool: + return True + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) # custom handling for HPU @@ -116,10 +48,58 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id]) layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id]) + def apply_monolithic( + self, + layer: FusedMoE, + x: torch.Tensor, + router_logits: torch.Tensor, + **kwargs, + ): + input_shape = x.shape + x = x.view(-1, x.shape[-1]) + if layer.use_grouped_topk or getattr(layer, "custom_routing_function", None) is not None: + topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits) + else: + import torch.nn.functional as F + topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, 
dim=-1) + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(x.dtype) + + if not layer.use_grouped_topk: + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) + + if layer.dp_size > 1: + dp_metadata = get_hpu_dp_metadata() + if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn): + hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None + x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) + + topk_ids_across_dp = dp_metadata.topk_ids_across_dp if dp_metadata is not None else None + topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel) + + topk_weights_across_dp = dp_metadata.topk_weights_across_dp if dp_metadata is not None else None + topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel) + + topk_ids = topk_ids.view(-1, topk_ids.shape[-1]) + topk_weights = topk_weights.view(-1, topk_weights.shape[-1]) + + output = layer.moe_op( + x, + topk_ids, + topk_weights, + permuted_weights=True, + activation=layer.activation, + ) + if layer.dp_size > 1: + return output.view(*(output.size(0), *input_shape[1:])) + else: + return output.view(*input_shape) + def forward_oot( self, layer: FusedMoE, - router: FusedMoERouter, x: torch.Tensor, router_logits: torch.Tensor, **kwargs, diff --git a/vllm_gaudi/ops/hpu_grouped_topk_router.py b/vllm_gaudi/ops/hpu_grouped_topk_router.py new file mode 100644 index 000000000..c6ae43bea --- /dev/null +++ b/vllm_gaudi/ops/hpu_grouped_topk_router.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import vllm + +from vllm import envs as envs +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, ) +from vllm.model_executor.utils import maybe_disable_graph_partition +from vllm.platforms import current_platform +from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (GroupedTopk, fused_grouped_topk) + + +# This is used by the Deepseek-V2 and Deepseek-V3 model +@torch.compile( + dynamic=True, + backend=current_platform.simple_compile_backend, + options=maybe_disable_graph_partition(current_platform.simple_compile_backend), +) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if (envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and current_platform.is_cuda() and num_expert_group <= 32 and topk <= 32 + and e_score_correction_bias is not None): + return fused_grouped_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + ) + + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + 
else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + + num_token = scores.size(0) + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1) + top1_val, top1_idx = torch.max(scores_tmp, dim=-1) + scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) + group_scores, top2_idx = torch.max(scores_tmp, dim=-1) + group_scores.add_(top1_val) + else: + group_scores = (scores.view(num_token, num_expert_group, -1).max(dim=-1).values) # [n, n_group] + if num_token > 1024: + group_mask = torch.zeros_like(group_scores) + for i in range(topk_group): + _, group_idx = torch.max(group_scores, dim=-1) + group_mask.scatter_(1, group_idx.unsqueeze(-1), 1) + if i < topk_group - 1: + group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) + else: + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + + tmp_scores = scores.reshape(num_token, num_expert_group, -1) + ( + (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1) + tmp_scores = tmp_scores.reshape(num_token, -1) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if routed_scaling_factor != 1.0: + topk_weights = topk_weights * routed_scaling_factor + return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64) + + +# --8<-- [start:grouped_topk] +@GroupedTopk.register_oot +class HPUGroupedTopk(GroupedTopk): + """GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.native_impl = grouped_topk + + def forward_oot( + self, + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + e_score_correction_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + + gating_output = gating_output.float() + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.float() + + if self.scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif self.scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {self.scoring_func}") + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_is_batch_invariant() + + num_token = scores.size(0) + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. 
We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + scores_tmp = scores.clone().reshape(num_token, self.num_expert_group, -1) + top1_val, top1_idx = torch.max(scores_tmp, dim=-1) + scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) + group_scores, top2_idx = torch.max(scores_tmp, dim=-1) + group_scores.add_(top1_val) + else: + group_scores = (scores.view(num_token, self.num_expert_group, -1).max(dim=-1).values) # [n, n_group] + if num_token > 1024: + group_mask = torch.zeros_like(group_scores) + for i in range(self.topk_group): + _, group_idx = torch.max(group_scores, dim=-1) + group_mask.scatter_(1, group_idx.unsqueeze(-1), 1) + if i < self.topk_group - 1: + group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min) + else: + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=use_sorted)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + + tmp_scores = scores.reshape(num_token, self.num_expert_group, -1) + ( + (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1) + tmp_scores = tmp_scores.reshape(num_token, -1) + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, k=self.topk, dim=-1, sorted=use_sorted) + + if self.renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if self.routed_scaling_factor != 1.0: + topk_weights = topk_weights * self.routed_scaling_factor + return topk_weights.to(hidden_states.dtype), topk_ids.to(torch.int64) + + +vllm.model_executor.layers.fused_moe.router.grouped_topk_router.grouped_topk = grouped_topk diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 3bc59870e..aabd5f39c 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -37,8 +37,8 @@ from vllm_gaudi.v1.worker.hpu_dp_utils import set_hpu_dp_metadata from vllm.v1.attention.backend import AttentionType -from vllm.attention.layer import Attention -from vllm.attention.layer import MLAAttention +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention import MLAAttention from vllm.v1.attention.selector import get_attn_backend from vllm.config import (VllmConfig, update_config) @@ -1002,7 +1002,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: for layer_name, attn_module in forward_ctx.items(): kv_sharing_target_layer_name = getattr(attn_module, 'kv_sharing_target_layer_name', None) if kv_sharing_target_layer_name is not None: - from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target + from vllm.model_executor.layers.attention.attention import validate_kv_sharing_target try: validate_kv_sharing_target( layer_name,
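
For reference, the grouped top-k routing that this patch centralizes in hpu_grouped_topk_router.py can be exercised standalone. The sketch below is a minimal illustration of the softmax path without e_score_correction_bias (and without the large-batch loop), using only plain PyTorch; the token, expert, and group counts are made-up toy values, not taken from any model config:

import torch

# Toy dimensions (illustrative assumptions only): 4 tokens, 16 experts
# split into 4 groups, keep the best 2 groups, then pick 2 experts.
num_tokens, num_experts, num_expert_group, topk_group, topk = 4, 16, 4, 2, 2

gating_output = torch.randn(num_tokens, num_experts)
scores = torch.softmax(gating_output.float(), dim=-1)

# Score each group by its best expert and keep the top `topk_group` groups.
group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

# Push experts in discarded groups to -inf, then take the global top-k.
tmp_scores = scores.view(num_tokens, num_expert_group, -1) + (
    (1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
topk_weights, topk_ids = torch.topk(tmp_scores.reshape(num_tokens, -1), k=topk, dim=-1)

# Renormalize the selected weights, as the renormalize branch above does.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
print(topk_ids, topk_weights)

The correction-bias branch in the patch differs in that each group is scored by the sum of its top-2 biased expert scores, and the final routing weights are gathered from the original unbiased scores rather than taken from the top-k values.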