From c72fefe88403a6e50142c29a8dc67fbb0f5756c0 Mon Sep 17 00:00:00 2001 From: ElleElleWu <1608928702@qq.com> Date: Tue, 3 Feb 2026 22:49:55 -0800 Subject: [PATCH 1/5] [Model] Support HunyuanImage3 Diffusion Model for GPU and NPU Co-authored-by: skf1999 <13234016272@163.com> Co-authored-by: Just-it <1161406585@qq.com> Co-authored-by: Semmer2 Signed-off-by: ElleElleWu <1608928702@qq.com> --- .../models/hunyuan_image_3/__init__.py | 2 + .../hunyuan_image_3/hunyuan_fused_moe.py | 184 ++++++++++++++++++ .../hunyuan_image_3_transformer.py | 21 +- 3 files changed, 188 insertions(+), 19 deletions(-) create mode 100644 vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py index 98a3ac07b1c..5535dcb417c 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py @@ -9,9 +9,11 @@ from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import ( HunyuanImage3Pipeline, ) +from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE __all__ = [ "HunyuanImage3Pipeline", "HunyuanImage3Model", "HunyuanImage3Text2ImagePipeline", + "HunyuanFusedMoE" ] diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py new file mode 100644 index 00000000000..2205cc2017b --- /dev/null +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +from typing import Any, Optional +import torch + +from vllm.config import VllmConfig +from vllm.distributed import get_ep_group +from vllm.distributed.parallel_state import ( + init_model_parallel_group as vllm_init_model_parallel_group, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_world_size +) +import vllm.forward_context as _vllm_fc + +from vllm_omni.platforms import current_omni_platform +from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx +from vllm_omni.diffusion.distributed.parallel_state import ( + get_data_parallel_world_size, + get_world_group, +) + +logger = logging.getLogger(__name__) + +_impl_class: type | None = None + +def _init_mc2_group_for_diffusion_npu( + world_size: int, + data_parallel_size: int, + tensor_parallel_size: int, + backend: str, + local_rank: int, +) -> None: + import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state + + if getattr(vllm_ascend_parallel_state, "_MC2", None) is not None: + return + all_ranks = torch.arange(world_size).reshape( + -1, data_parallel_size * tensor_parallel_size + ) + group_ranks = all_ranks.unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + vllm_ascend_parallel_state._MC2 = vllm_init_model_parallel_group( + group_ranks, + local_rank, + backend, + group_name="mc2", + ) + + +def _get_impl_class() -> type: + global _impl_class + if _impl_class is not None: + return _impl_class + if current_omni_platform.is_npu(): + _impl_class = _get_npu_impl_class() + elif current_omni_platform.is_cuda(): + _impl_class = _get_cuda_impl_class() + else: + raise NotImplementedError( + f"HunyuanFusedMoE is not implemented for current_omni_platform: " + f"{current_omni_platform!r}" + ) + return _impl_class + + +def _get_cuda_impl_class() -> type: + from vllm.model_executor.layers.fused_moe import SharedFusedMoE + + class HunyuanFusedMoECuda(SharedFusedMoE): + def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: + super().__init__(prefix=prefix, **kwargs) + self._prefix = prefix + self._init_hook_handle = self.register_forward_pre_hook( + self._initialize_kernel_hook, with_kwargs=True + ) + + def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: + if self.quant_method: + self.quant_method.process_weights_after_loading(self) + self._init_hook_handle.remove() + + def forward(self, hidden_states: Any, router_logits: Any) -> Any: + return super().forward(hidden_states, router_logits) + + return HunyuanFusedMoECuda + + +def _get_npu_impl_class() -> type: + from vllm_ascend.ops.fused_moe.fused_moe import AscendSharedFusedMoE + from vllm_ascend.ops.fused_moe.moe_comm_method import _MoECommMethods + from vllm_ascend.ascend_forward_context import MoECommType + from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type + + # Workaround for vllm-ascend: mc2_group must be initialized to prevent errors, despite being unused in FusedMoE communication. + world_size = torch.distributed.get_world_size() + data_parallel_size = get_data_parallel_world_size() + tensor_parallel_size = get_tensor_model_parallel_world_size() + backend = torch.distributed.get_backend(get_world_group().device_group) + local_rank = get_world_group().local_rank + _init_mc2_group_for_diffusion_npu( + world_size=world_size, + data_parallel_size=data_parallel_size, + tensor_parallel_size=tensor_parallel_size, + backend=backend, + local_rank=local_rank, + ) + + if not hasattr(_vllm_fc.ForwardContext, "moe_comm_method"): + _vllm_fc.ForwardContext.__annotations__["in_profile_run"] = bool + _vllm_fc.ForwardContext.in_profile_run = False + + def select_moe_comm_method(vllm_config: VllmConfig) -> MoECommType | None: + soc_version = get_ascend_device_type() + if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A2}: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A3}: + moe_comm_type = MoECommType.ALLTOALL + elif soc_version in {AscendDeviceType._310P}: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A5}: + moe_comm_type = MoECommType.ALLTOALL + else: + raise ValueError(f"Unsupported soc_version: {soc_version}") + return moe_comm_type + + + class HunyuanFusedMoENPU(AscendSharedFusedMoE): + def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: + super().__init__(prefix=prefix, **kwargs) + self._prefix = prefix + self._init_hook_handle = self.register_forward_pre_hook( + self._initialize_kernel_hook, with_kwargs=True + ) + + _vllm_fc.ForwardContext.moe_comm_type = select_moe_comm_method(vllm_config=omni_get_ctx().vllm_config) + _vllm_fc.ForwardContext.moe_comm_method=_MoECommMethods.get(_vllm_fc.ForwardContext.moe_comm_type) + _vllm_fc.ForwardContext.flash_comm_v1_enabled=False + + + def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: + if self.quant_method: + self.quant_method.process_weights_after_loading(self) + self._init_hook_handle.remove() + + def forward(self, hidden_states: Any, router_logits: Any) -> Any: + return super().forward(hidden_states, router_logits) + + def __del__(self): + if vllm_ascend_parallel_state._MC2: + vllm_ascend_parallel_state._MC2.destroy() + vllm_ascend_parallel_state._MC2 = None + + return HunyuanFusedMoENPU + + +class HunyuanFusedMoE: + def __new__(cls, *, prefix: str = "", **kwargs: Any) -> Any: + impl = _get_impl_class() + return impl(prefix=prefix, **kwargs) + + @classmethod + def make_expert_params_mapping( + cls, + model: Any, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + num_redundant_experts: int = 0, + ) -> list[tuple[str, str, int, str]]: + return _get_impl_class().make_expert_params_mapping( + model, + ckpt_gate_proj_name=ckpt_gate_proj_name, + ckpt_down_proj_name=ckpt_down_proj_name, + ckpt_up_proj_name=ckpt_up_proj_name, + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, + ) \ No newline at end of file diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py index 0c2f9e290ac..08a229d68b2 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py @@ -63,6 +63,7 @@ from vllm_omni.diffusion.distributed.parallel_state import get_pp_group from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.layers.rope import RotaryEmbedding +from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE logger = logging.getLogger(__name__) @@ -1417,7 +1418,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states[0] + final_hidden_states[1] if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) return final_hidden_states.view(orig_shape) @@ -1564,23 +1565,6 @@ def forward( output = output.reshape(bsz, q_len, -1) return output, None, past_key_value - -class HunyuanFusedMoE(SharedFusedMoE): - def __init__(self, *, prefix: str = "", **kwargs): - super().__init__(prefix=prefix, **kwargs) - self._prefix = prefix - - self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) - - def _initialize_kernel_hook(self, module, args, kwargs): - if self.quant_method: - self.quant_method.process_weights_after_loading(self) - self._init_hook_handle.remove() - - def forward(self, hidden_states, router_logits): - return super().forward(hidden_states, router_logits) - - class HunyuanImage3DecoderLayer(nn.Module): def __init__(self, config: HunyuanImage3Config, layer_idx: int, prefix: str = ""): super().__init__() @@ -2454,7 +2438,6 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From bac1937fe420a2c0c08647804b0eeefd4dab45ea Mon Sep 17 00:00:00 2001 From: ElleElleWu <1608928702@qq.com> Date: Fri, 6 Mar 2026 15:30:12 +0000 Subject: [PATCH 2/5] fix is_moe type and threshold, add UT for is_moe, Hunyuan_fused_moe Signed-off-by: ElleElleWu <1608928702@qq.com> --- docs/models/supported_models.md | 1 + .../hunyuan_image_3/test_hunyuan_fused_moe.py | 109 ++++++++++++++++++ tests/diffusion/test_data_is_moe.py | 78 +++++++++++++ vllm_omni/diffusion/data.py | 6 +- .../models/hunyuan_image_3/__init__.py | 9 +- .../hunyuan_image_3/hunyuan_fused_moe.py | 49 ++++---- .../hunyuan_image_3_transformer.py | 3 +- 7 files changed, 216 insertions(+), 39 deletions(-) create mode 100644 tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py create mode 100644 tests/diffusion/test_data_is_moe.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 67a49c5755c..52e75989d2f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -65,6 +65,7 @@ th { |--------------|--------|-------------------| | `Qwen3OmniMoeForConditionalGeneration` | Qwen3-Omni | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | | `Qwen2_5OmniForConditionalGeneration` | Qwen2.5-Omni | `Qwen/Qwen2.5-Omni-7B`, `Qwen/Qwen2.5-Omni-3B`| +| `HunyuanImage3ForCausalMM` | HunyuanImage3.0 (DiT-only) | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | | `QwenImagePipeline` | Qwen-Image | `Qwen/Qwen-Image` | | `QwenImagePipeline` | Qwen-Image-2512 | `Qwen/Qwen-Image-2512` | | `QwenImageEditPipeline` | Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | diff --git a/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py new file mode 100644 index 00000000000..76d4b5456ca --- /dev/null +++ b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for HunyuanFusedMoE (Support HunyuanImage3 Diffusion Model, 5a779b4).""" + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class TestHunyuanFusedMoEPlatformDispatch: + """Test platform dispatch and NotImplementedError for unknown platform.""" + + def test_unknown_platform_raises_not_implemented_error(self, mocker): + """HunyuanFusedMoE should raise NotImplementedError when platform is not NPU or CUDA.""" + import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe + + # Clear cached impl so _get_impl_class() is re-run with mocked platform + hunyuan_moe._impl_class = None + + mock_platform = mocker.MagicMock() + mock_platform.is_npu.return_value = False + mock_platform.is_cuda.return_value = False + mock_platform.__repr__ = lambda self: "UnknownPlatform" + + mocker.patch.object( + hunyuan_moe, + "current_omni_platform", + mock_platform, + ) + + from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( + HunyuanFusedMoE, + ) + + with pytest.raises(NotImplementedError) as exc_info: + HunyuanFusedMoE(prefix="") + + assert "HunyuanFusedMoE is not implemented" in str(exc_info.value) + assert "current_omni_platform" in str(exc_info.value) + + +class TestHunyuanFusedMoEFactory: + """Test HunyuanFusedMoE factory __new__ and make_expert_params_mapping delegation.""" + + def test_new_delegates_to_impl_class(self, mocker): + """HunyuanFusedMoE(prefix=..., **kwargs) should instantiate and return impl instance.""" + import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe + + hunyuan_moe._impl_class = None + + # Use a simple mock class as impl so we don't need CUDA/NPU + class MockImpl: + def __init__(self, *, prefix: str = "", **kwargs): + self.prefix = prefix + self.kwargs = kwargs + + mock_impl_class = mocker.MagicMock(return_value=MockImpl(prefix="test", a=1)) + mocker.patch.object( + hunyuan_moe, + "_get_impl_class", + return_value=mock_impl_class, + ) + + from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( + HunyuanFusedMoE, + ) + + result = HunyuanFusedMoE(prefix="test", a=1) + + assert isinstance(result, MockImpl) + assert result.prefix == "test" + assert result.kwargs == {"a": 1} + mock_impl_class.assert_called_once_with(prefix="test", a=1) + + def test_make_expert_params_mapping_delegates_to_impl(self, mocker): + """make_expert_params_mapping should delegate to impl class method.""" + import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe + + expected_mapping = [("a", "b", 0, "c")] + mock_impl_class = mocker.MagicMock() + mock_impl_class.make_expert_params_mapping = mocker.MagicMock(return_value=expected_mapping) + mocker.patch.object( + hunyuan_moe, + "_get_impl_class", + return_value=mock_impl_class, + ) + + from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( + HunyuanFusedMoE, + ) + + result = HunyuanFusedMoE.make_expert_params_mapping( + model=None, + ckpt_gate_proj_name="gate", + ckpt_down_proj_name="down", + ckpt_up_proj_name="up", + num_experts=4, + num_redundant_experts=0, + ) + + assert result == expected_mapping + mock_impl_class.make_expert_params_mapping.assert_called_once_with( + None, + ckpt_gate_proj_name="gate", + ckpt_down_proj_name="down", + ckpt_up_proj_name="up", + num_experts=4, + num_redundant_experts=0, + ) diff --git a/tests/diffusion/test_data_is_moe.py b/tests/diffusion/test_data_is_moe.py new file mode 100644 index 00000000000..25fa59ef1db --- /dev/null +++ b/tests/diffusion/test_data_is_moe.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for OmniDiffusionConfig.is_moe (fix is_moe type and threshold, 6663c0b).""" + +import pytest + +from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +class TestOmniDiffusionConfigIsMoE: + """Tests for OmniDiffusionConfig.is_moe property. + + Covers commit 6663c0b: fix is_moe type and threshold + - num_experts must be (list, tuple, int); otherwise return False. + - Threshold: is_moe is True when num_experts > 0 (not > 1). + """ + + def test_is_moe_missing_num_experts_returns_false(self): + """When num_experts is absent, is_moe should be False.""" + tf_config = TransformerConfig.from_dict({}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False + + def test_is_moe_none_num_experts_returns_false(self): + """When num_experts is explicitly None (e.g. in params), is_moe should be False.""" + tf_config = TransformerConfig.from_dict({"num_experts": None}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False + + def test_is_moe_non_allowed_type_returns_false(self): + """When num_experts is not int/list/tuple (e.g. str), is_moe should be False.""" + tf_config = TransformerConfig.from_dict({"num_experts": "2"}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False + + def test_is_moe_int_zero_returns_false(self): + """num_experts int 0 should yield is_moe False (threshold > 0).""" + tf_config = TransformerConfig.from_dict({"num_experts": 0}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False + + def test_is_moe_int_one_returns_true(self): + """num_experts int 1 should yield is_moe True (threshold > 0, not > 1).""" + tf_config = TransformerConfig.from_dict({"num_experts": 1}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is True + + def test_is_moe_int_gt_one_returns_true(self): + """num_experts int > 1 should yield is_moe True.""" + tf_config = TransformerConfig.from_dict({"num_experts": 2}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is True + + def test_is_moe_list_all_zero_returns_false(self): + """num_experts list with all <= 0 should yield is_moe False.""" + tf_config = TransformerConfig.from_dict({"num_experts": [0]}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False + + def test_is_moe_list_has_positive_returns_true(self): + """num_experts list with any int > 0 should yield is_moe True.""" + tf_config = TransformerConfig.from_dict({"num_experts": [0, 1]}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is True + + def test_is_moe_tuple_has_positive_returns_true(self): + """num_experts tuple with any int > 0 should yield is_moe True.""" + tf_config = TransformerConfig.from_dict({"num_experts": (0, 2)}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is True + + def test_is_moe_list_non_int_ignored(self): + """num_experts list with only non-int entries should yield is_moe False.""" + tf_config = TransformerConfig.from_dict({"num_experts": ["a", 0.0]}) + config = OmniDiffusionConfig(model="test", tf_model_config=tf_config) + assert config.is_moe is False diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 38366469eb1..634cb329414 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -463,11 +463,13 @@ class OmniDiffusionConfig: @property def is_moe(self) -> bool: num_experts = self.tf_model_config.get("num_experts", None) + if not isinstance(num_experts, (list, tuple, int)): + return False if isinstance(num_experts, int): - return num_experts > 1 + return num_experts > 0 if isinstance(num_experts, (list, tuple)): - return any(isinstance(n, int) and n > 1 for n in num_experts) + return any(isinstance(n, int) and n > 0 for n in num_experts) return False diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py index 5535dcb417c..cbc6a8ad1f4 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Hunyuan Image 3 diffusion model components.""" +from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_image_3_transformer import ( HunyuanImage3Model, HunyuanImage3Text2ImagePipeline, @@ -9,11 +10,5 @@ from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import ( HunyuanImage3Pipeline, ) -from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE -__all__ = [ - "HunyuanImage3Pipeline", - "HunyuanImage3Model", - "HunyuanImage3Text2ImagePipeline", - "HunyuanFusedMoE" -] +__all__ = ["HunyuanImage3Pipeline", "HunyuanImage3Model", "HunyuanImage3Text2ImagePipeline", "HunyuanFusedMoE"] diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py index 2205cc2017b..4c1be77965e 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py @@ -2,30 +2,29 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from typing import Any, Optional -import torch +from typing import Any +import torch +import vllm.forward_context as _vllm_fc from vllm.config import VllmConfig from vllm.distributed import get_ep_group +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import ( init_model_parallel_group as vllm_init_model_parallel_group, ) -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size -) -import vllm.forward_context as _vllm_fc -from vllm_omni.platforms import current_omni_platform -from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx from vllm_omni.diffusion.distributed.parallel_state import ( get_data_parallel_world_size, get_world_group, ) +from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx +from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) _impl_class: type | None = None + def _init_mc2_group_for_diffusion_npu( world_size: int, data_parallel_size: int, @@ -37,9 +36,7 @@ def _init_mc2_group_for_diffusion_npu( if getattr(vllm_ascend_parallel_state, "_MC2", None) is not None: return - all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size * tensor_parallel_size - ) + all_ranks = torch.arange(world_size).reshape(-1, data_parallel_size * tensor_parallel_size) group_ranks = all_ranks.unbind(0) group_ranks = [x.tolist() for x in group_ranks] @@ -61,8 +58,7 @@ def _get_impl_class() -> type: _impl_class = _get_cuda_impl_class() else: raise NotImplementedError( - f"HunyuanFusedMoE is not implemented for current_omni_platform: " - f"{current_omni_platform!r}" + f"HunyuanFusedMoE is not implemented for current_omni_platform: {current_omni_platform!r}" ) return _impl_class @@ -74,9 +70,7 @@ class HunyuanFusedMoECuda(SharedFusedMoE): def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: super().__init__(prefix=prefix, **kwargs) self._prefix = prefix - self._init_hook_handle = self.register_forward_pre_hook( - self._initialize_kernel_hook, with_kwargs=True - ) + self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: if self.quant_method: @@ -90,12 +84,13 @@ def forward(self, hidden_states: Any, router_logits: Any) -> Any: def _get_npu_impl_class() -> type: + from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.fused_moe import AscendSharedFusedMoE from vllm_ascend.ops.fused_moe.moe_comm_method import _MoECommMethods - from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - # Workaround for vllm-ascend: mc2_group must be initialized to prevent errors, despite being unused in FusedMoE communication. + # Workaround for vllm-ascend: mc2_group must be initialized to prevent errors, + # despite being unused in FusedMoE communication. world_size = torch.distributed.get_world_size() data_parallel_size = get_data_parallel_world_size() tensor_parallel_size = get_tensor_model_parallel_world_size() @@ -120,28 +115,24 @@ def select_moe_comm_method(vllm_config: VllmConfig) -> MoECommType | None: elif soc_version in {AscendDeviceType.A2}: moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendDeviceType.A3}: - moe_comm_type = MoECommType.ALLTOALL + moe_comm_type = MoECommType.ALLTOALL elif soc_version in {AscendDeviceType._310P}: moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendDeviceType.A5}: - moe_comm_type = MoECommType.ALLTOALL + moe_comm_type = MoECommType.ALLTOALL else: raise ValueError(f"Unsupported soc_version: {soc_version}") return moe_comm_type - class HunyuanFusedMoENPU(AscendSharedFusedMoE): def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: super().__init__(prefix=prefix, **kwargs) self._prefix = prefix - self._init_hook_handle = self.register_forward_pre_hook( - self._initialize_kernel_hook, with_kwargs=True - ) + self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) _vllm_fc.ForwardContext.moe_comm_type = select_moe_comm_method(vllm_config=omni_get_ctx().vllm_config) - _vllm_fc.ForwardContext.moe_comm_method=_MoECommMethods.get(_vllm_fc.ForwardContext.moe_comm_type) - _vllm_fc.ForwardContext.flash_comm_v1_enabled=False - + _vllm_fc.ForwardContext.moe_comm_method = _MoECommMethods.get(_vllm_fc.ForwardContext.moe_comm_type) + _vllm_fc.ForwardContext.flash_comm_v1_enabled = False def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: if self.quant_method: @@ -152,6 +143,8 @@ def forward(self, hidden_states: Any, router_logits: Any) -> Any: return super().forward(hidden_states, router_logits) def __del__(self): + import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state + if vllm_ascend_parallel_state._MC2: vllm_ascend_parallel_state._MC2.destroy() vllm_ascend_parallel_state._MC2 = None @@ -181,4 +174,4 @@ def make_expert_params_mapping( ckpt_up_proj_name=ckpt_up_proj_name, num_experts=num_experts, num_redundant_experts=num_redundant_experts, - ) \ No newline at end of file + ) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py index 08a229d68b2..a89931550de 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py @@ -32,10 +32,8 @@ from vllm.config import CacheConfig from vllm.distributed import ( get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -1565,6 +1563,7 @@ def forward( output = output.reshape(bsz, q_len, -1) return output, None, past_key_value + class HunyuanImage3DecoderLayer(nn.Module): def __init__(self, config: HunyuanImage3Config, layer_idx: int, prefix: str = ""): super().__init__() From 6603fc7293006e106d0fc682dffff13c36f1f1b0 Mon Sep 17 00:00:00 2001 From: ElleElleWu <1608928702@qq.com> Date: Thu, 12 Mar 2026 17:40:42 +0800 Subject: [PATCH 3/5] fix hunyuan_fused_moe for xpu and other devices Signed-off-by: ElleElleWu <1608928702@qq.com> --- .../models/hunyuan_image_3/hunyuan_fused_moe.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py index 4c1be77965e..6d9f0bb752d 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py @@ -54,19 +54,15 @@ def _get_impl_class() -> type: return _impl_class if current_omni_platform.is_npu(): _impl_class = _get_npu_impl_class() - elif current_omni_platform.is_cuda(): - _impl_class = _get_cuda_impl_class() else: - raise NotImplementedError( - f"HunyuanFusedMoE is not implemented for current_omni_platform: {current_omni_platform!r}" - ) + _impl_class = _get_default_impl_class() return _impl_class -def _get_cuda_impl_class() -> type: +def _get_default_impl_class() -> type: from vllm.model_executor.layers.fused_moe import SharedFusedMoE - class HunyuanFusedMoECuda(SharedFusedMoE): + class HunyuanFusedMoEDefault(SharedFusedMoE): def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: super().__init__(prefix=prefix, **kwargs) self._prefix = prefix @@ -80,7 +76,7 @@ def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: def forward(self, hidden_states: Any, router_logits: Any) -> Any: return super().forward(hidden_states, router_logits) - return HunyuanFusedMoECuda + return HunyuanFusedMoEDefault def _get_npu_impl_class() -> type: From 79d77c97f19096b2a6fa193112f66bcb164a1e52 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Thu, 12 Mar 2026 11:10:34 +0000 Subject: [PATCH 4/5] Refactor hardware dispathc Signed-off-by: gcanlin --- .../hunyuan_image_3/test_hunyuan_fused_moe.py | 55 +++--- .../hunyuan_image_3/hunyuan_fused_moe.py | 161 +++--------------- vllm_omni/platforms/interface.py | 11 ++ vllm_omni/platforms/npu/models/__init__.py | 2 + .../platforms/npu/models/hunyuan_fused_moe.py | 108 ++++++++++++ vllm_omni/platforms/npu/platform.py | 19 +++ 6 files changed, 191 insertions(+), 165 deletions(-) create mode 100644 vllm_omni/platforms/npu/models/__init__.py create mode 100644 vllm_omni/platforms/npu/models/hunyuan_fused_moe.py diff --git a/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py index 76d4b5456ca..2aa1adf4449 100644 --- a/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py +++ b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py @@ -8,35 +8,38 @@ class TestHunyuanFusedMoEPlatformDispatch: - """Test platform dispatch and NotImplementedError for unknown platform.""" + """Test platform dispatch via platform qualname hooks.""" - def test_unknown_platform_raises_not_implemented_error(self, mocker): - """HunyuanFusedMoE should raise NotImplementedError when platform is not NPU or CUDA.""" + def test_default_platform_uses_default_impl_qualname(self, mocker): + """HunyuanFusedMoE should resolve the impl class from the platform hook.""" import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe - # Clear cached impl so _get_impl_class() is re-run with mocked platform - hunyuan_moe._impl_class = None - mock_platform = mocker.MagicMock() - mock_platform.is_npu.return_value = False - mock_platform.is_cuda.return_value = False - mock_platform.__repr__ = lambda self: "UnknownPlatform" + mock_platform.get_diffusion_model_impl_qualname.return_value = ( + "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + ) mocker.patch.object( hunyuan_moe, "current_omni_platform", mock_platform, ) + mock_resolve = mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname") + mock_impl = mocker.MagicMock() + mock_resolve.return_value = mock_impl from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( HunyuanFusedMoE, ) - with pytest.raises(NotImplementedError) as exc_info: - HunyuanFusedMoE(prefix="") + HunyuanFusedMoE(prefix="") - assert "HunyuanFusedMoE is not implemented" in str(exc_info.value) - assert "current_omni_platform" in str(exc_info.value) + mock_platform.prepare_diffusion_op_runtime.assert_called_once_with("hunyuan_fused_moe") + mock_platform.get_diffusion_model_impl_qualname.assert_called_once_with("hunyuan_fused_moe") + mock_resolve.assert_called_once_with( + "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + ) + mock_impl.assert_called_once_with(prefix="") class TestHunyuanFusedMoEFactory: @@ -46,20 +49,17 @@ def test_new_delegates_to_impl_class(self, mocker): """HunyuanFusedMoE(prefix=..., **kwargs) should instantiate and return impl instance.""" import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe - hunyuan_moe._impl_class = None - - # Use a simple mock class as impl so we don't need CUDA/NPU class MockImpl: def __init__(self, *, prefix: str = "", **kwargs): self.prefix = prefix self.kwargs = kwargs + mock_platform = mocker.MagicMock() + mock_platform.get_diffusion_model_impl_qualname.return_value = "mock.impl.Qualname" + mocker.patch.object(hunyuan_moe, "current_omni_platform", mock_platform) + mock_impl_class = mocker.MagicMock(return_value=MockImpl(prefix="test", a=1)) - mocker.patch.object( - hunyuan_moe, - "_get_impl_class", - return_value=mock_impl_class, - ) + mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class) from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( HunyuanFusedMoE, @@ -70,6 +70,8 @@ def __init__(self, *, prefix: str = "", **kwargs): assert isinstance(result, MockImpl) assert result.prefix == "test" assert result.kwargs == {"a": 1} + mock_platform.prepare_diffusion_op_runtime.assert_called_once_with("hunyuan_fused_moe") + mock_platform.get_diffusion_model_impl_qualname.assert_called_once_with("hunyuan_fused_moe") mock_impl_class.assert_called_once_with(prefix="test", a=1) def test_make_expert_params_mapping_delegates_to_impl(self, mocker): @@ -77,13 +79,13 @@ def test_make_expert_params_mapping_delegates_to_impl(self, mocker): import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe expected_mapping = [("a", "b", 0, "c")] + mock_platform = mocker.MagicMock() + mock_platform.get_diffusion_model_impl_qualname.return_value = "mock.impl.Qualname" + mocker.patch.object(hunyuan_moe, "current_omni_platform", mock_platform) + mock_impl_class = mocker.MagicMock() mock_impl_class.make_expert_params_mapping = mocker.MagicMock(return_value=expected_mapping) - mocker.patch.object( - hunyuan_moe, - "_get_impl_class", - return_value=mock_impl_class, - ) + mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class) from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import ( HunyuanFusedMoE, @@ -99,6 +101,7 @@ def test_make_expert_params_mapping_delegates_to_impl(self, mocker): ) assert result == expected_mapping + mock_platform.get_diffusion_model_impl_qualname.assert_called_once_with("hunyuan_fused_moe") mock_impl_class.make_expert_params_mapping.assert_called_once_with( None, ckpt_gate_proj_name="gate", diff --git a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py index 6d9f0bb752d..5ada5ceb848 100644 --- a/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py +++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py @@ -1,156 +1,36 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import logging from typing import Any -import torch -import vllm.forward_context as _vllm_fc -from vllm.config import VllmConfig -from vllm.distributed import get_ep_group -from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import ( - init_model_parallel_group as vllm_init_model_parallel_group, -) +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm_omni.diffusion.distributed.parallel_state import ( - get_data_parallel_world_size, - get_world_group, -) -from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx from vllm_omni.platforms import current_omni_platform -logger = logging.getLogger(__name__) -_impl_class: type | None = None +class HunyuanFusedMoEDefault(SharedFusedMoE): + def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: + super().__init__(prefix=prefix, **kwargs) + self._prefix = prefix + self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) + def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: + if self.quant_method: + self.quant_method.process_weights_after_loading(self) + self._init_hook_handle.remove() -def _init_mc2_group_for_diffusion_npu( - world_size: int, - data_parallel_size: int, - tensor_parallel_size: int, - backend: str, - local_rank: int, -) -> None: - import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state - - if getattr(vllm_ascend_parallel_state, "_MC2", None) is not None: - return - all_ranks = torch.arange(world_size).reshape(-1, data_parallel_size * tensor_parallel_size) - group_ranks = all_ranks.unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - - vllm_ascend_parallel_state._MC2 = vllm_init_model_parallel_group( - group_ranks, - local_rank, - backend, - group_name="mc2", - ) - - -def _get_impl_class() -> type: - global _impl_class - if _impl_class is not None: - return _impl_class - if current_omni_platform.is_npu(): - _impl_class = _get_npu_impl_class() - else: - _impl_class = _get_default_impl_class() - return _impl_class - - -def _get_default_impl_class() -> type: - from vllm.model_executor.layers.fused_moe import SharedFusedMoE - - class HunyuanFusedMoEDefault(SharedFusedMoE): - def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: - super().__init__(prefix=prefix, **kwargs) - self._prefix = prefix - self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) - - def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: - if self.quant_method: - self.quant_method.process_weights_after_loading(self) - self._init_hook_handle.remove() - - def forward(self, hidden_states: Any, router_logits: Any) -> Any: - return super().forward(hidden_states, router_logits) - - return HunyuanFusedMoEDefault - - -def _get_npu_impl_class() -> type: - from vllm_ascend.ascend_forward_context import MoECommType - from vllm_ascend.ops.fused_moe.fused_moe import AscendSharedFusedMoE - from vllm_ascend.ops.fused_moe.moe_comm_method import _MoECommMethods - from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type - - # Workaround for vllm-ascend: mc2_group must be initialized to prevent errors, - # despite being unused in FusedMoE communication. - world_size = torch.distributed.get_world_size() - data_parallel_size = get_data_parallel_world_size() - tensor_parallel_size = get_tensor_model_parallel_world_size() - backend = torch.distributed.get_backend(get_world_group().device_group) - local_rank = get_world_group().local_rank - _init_mc2_group_for_diffusion_npu( - world_size=world_size, - data_parallel_size=data_parallel_size, - tensor_parallel_size=tensor_parallel_size, - backend=backend, - local_rank=local_rank, - ) - - if not hasattr(_vllm_fc.ForwardContext, "moe_comm_method"): - _vllm_fc.ForwardContext.__annotations__["in_profile_run"] = bool - _vllm_fc.ForwardContext.in_profile_run = False - - def select_moe_comm_method(vllm_config: VllmConfig) -> MoECommType | None: - soc_version = get_ascend_device_type() - if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1: - moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendDeviceType.A2}: - moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendDeviceType.A3}: - moe_comm_type = MoECommType.ALLTOALL - elif soc_version in {AscendDeviceType._310P}: - moe_comm_type = MoECommType.ALLGATHER - elif soc_version in {AscendDeviceType.A5}: - moe_comm_type = MoECommType.ALLTOALL - else: - raise ValueError(f"Unsupported soc_version: {soc_version}") - return moe_comm_type - - class HunyuanFusedMoENPU(AscendSharedFusedMoE): - def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: - super().__init__(prefix=prefix, **kwargs) - self._prefix = prefix - self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) - - _vllm_fc.ForwardContext.moe_comm_type = select_moe_comm_method(vllm_config=omni_get_ctx().vllm_config) - _vllm_fc.ForwardContext.moe_comm_method = _MoECommMethods.get(_vllm_fc.ForwardContext.moe_comm_type) - _vllm_fc.ForwardContext.flash_comm_v1_enabled = False - - def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: - if self.quant_method: - self.quant_method.process_weights_after_loading(self) - self._init_hook_handle.remove() - - def forward(self, hidden_states: Any, router_logits: Any) -> Any: - return super().forward(hidden_states, router_logits) - - def __del__(self): - import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state - - if vllm_ascend_parallel_state._MC2: - vllm_ascend_parallel_state._MC2.destroy() - vllm_ascend_parallel_state._MC2 = None - - return HunyuanFusedMoENPU + def forward(self, hidden_states: Any, router_logits: Any) -> Any: + return super().forward(hidden_states, router_logits) class HunyuanFusedMoE: def __new__(cls, *, prefix: str = "", **kwargs: Any) -> Any: - impl = _get_impl_class() + op_name = "hunyuan_fused_moe" + current_omni_platform.prepare_diffusion_op_runtime(op_name) + impl = resolve_obj_by_qualname( + current_omni_platform.get_diffusion_model_impl_qualname(op_name), + ) return impl(prefix=prefix, **kwargs) @classmethod @@ -163,7 +43,10 @@ def make_expert_params_mapping( num_experts: int, num_redundant_experts: int = 0, ) -> list[tuple[str, str, int, str]]: - return _get_impl_class().make_expert_params_mapping( + impl = resolve_obj_by_qualname( + current_omni_platform.get_diffusion_model_impl_qualname("hunyuan_fused_moe"), + ) + return impl.make_expert_params_mapping( model, ckpt_gate_proj_name=ckpt_gate_proj_name, ckpt_down_proj_name=ckpt_down_proj_name, diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py index f90c69e3a53..314cb3219e5 100644 --- a/vllm_omni/platforms/interface.py +++ b/vllm_omni/platforms/interface.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import Any import torch from vllm.platforms import Platform @@ -52,6 +53,16 @@ def get_omni_generation_worker_cls(cls) -> str: def get_default_stage_config_path(cls) -> str: raise NotImplementedError + @classmethod + def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: + if op_name == "hunyuan_fused_moe": + return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault" + raise NotImplementedError(f"Unsupported diffusion model op: {op_name}") + + @classmethod + def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None: + return None + @classmethod def get_diffusion_attn_backend_cls( cls, diff --git a/vllm_omni/platforms/npu/models/__init__.py b/vllm_omni/platforms/npu/models/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/platforms/npu/models/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py new file mode 100644 index 00000000000..274ab772443 --- /dev/null +++ b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import torch +import vllm.forward_context as _vllm_fc +from vllm.config import VllmConfig +from vllm.distributed import get_ep_group +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import ( + init_model_parallel_group as vllm_init_model_parallel_group, +) +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ops.fused_moe.fused_moe import AscendSharedFusedMoE +from vllm_ascend.ops.fused_moe.moe_comm_method import _MoECommMethods +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type + +from vllm_omni.diffusion.distributed.parallel_state import ( + get_data_parallel_world_size, + get_world_group, +) +from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx + + +def _init_mc2_group_for_diffusion_npu( + world_size: int, + data_parallel_size: int, + tensor_parallel_size: int, + backend: str, + local_rank: int, +) -> None: + import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state + + if getattr(vllm_ascend_parallel_state, "_MC2", None) is not None: + return + all_ranks = torch.arange(world_size).reshape(-1, data_parallel_size * tensor_parallel_size) + group_ranks = all_ranks.unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + vllm_ascend_parallel_state._MC2 = vllm_init_model_parallel_group( + group_ranks, + local_rank, + backend, + group_name="mc2", + ) + + +def _select_moe_comm_method(vllm_config: VllmConfig) -> MoECommType | None: + soc_version = get_ascend_device_type() + if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A2}: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A3}: + moe_comm_type = MoECommType.ALLTOALL + elif soc_version in {AscendDeviceType._310P}: + moe_comm_type = MoECommType.ALLGATHER + elif soc_version in {AscendDeviceType.A5}: + moe_comm_type = MoECommType.ALLTOALL + else: + raise ValueError(f"Unsupported soc_version: {soc_version}") + return moe_comm_type + + +def prepare_hunyuan_fused_moe_runtime() -> None: + world_size = torch.distributed.get_world_size() + data_parallel_size = get_data_parallel_world_size() + tensor_parallel_size = get_tensor_model_parallel_world_size() + backend = torch.distributed.get_backend(get_world_group().device_group) + local_rank = get_world_group().local_rank + _init_mc2_group_for_diffusion_npu( + world_size=world_size, + data_parallel_size=data_parallel_size, + tensor_parallel_size=tensor_parallel_size, + backend=backend, + local_rank=local_rank, + ) + + if not hasattr(_vllm_fc.ForwardContext, "moe_comm_method"): + _vllm_fc.ForwardContext.__annotations__["in_profile_run"] = bool + _vllm_fc.ForwardContext.in_profile_run = False + + _vllm_fc.ForwardContext.moe_comm_type = _select_moe_comm_method(vllm_config=omni_get_ctx().vllm_config) + _vllm_fc.ForwardContext.moe_comm_method = _MoECommMethods.get(_vllm_fc.ForwardContext.moe_comm_type) + _vllm_fc.ForwardContext.flash_comm_v1_enabled = False + + +class HunyuanFusedMoENPU(AscendSharedFusedMoE): + def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: + super().__init__(prefix=prefix, **kwargs) + self._prefix = prefix + self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) + + def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: + if self.quant_method: + self.quant_method.process_weights_after_loading(self) + self._init_hook_handle.remove() + + def forward(self, hidden_states: Any, router_logits: Any) -> Any: + return super().forward(hidden_states, router_logits) + + def __del__(self): + import vllm_ascend.distributed.parallel_state as vllm_ascend_parallel_state + + if vllm_ascend_parallel_state._MC2: + vllm_ascend_parallel_state._MC2.destroy() + vllm_ascend_parallel_state._MC2 = None diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py index 3c2495c3d35..e2988b9485b 100644 --- a/vllm_omni/platforms/npu/platform.py +++ b/vllm_omni/platforms/npu/platform.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + import torch from vllm.logger import init_logger from vllm_ascend.platform import NPUPlatform @@ -33,6 +35,23 @@ def get_omni_generation_worker_cls(cls) -> str: def get_default_stage_config_path(cls) -> str: return "vllm_omni/platforms/npu/stage_configs" + @classmethod + def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: + if op_name == "hunyuan_fused_moe": + return "vllm_omni.platforms.npu.models.hunyuan_fused_moe.HunyuanFusedMoENPU" + return super().get_diffusion_model_impl_qualname(op_name) + + @classmethod + def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None: + if op_name != "hunyuan_fused_moe": + return + + from vllm_omni.platforms.npu.models.hunyuan_fused_moe import ( + prepare_hunyuan_fused_moe_runtime, + ) + + prepare_hunyuan_fused_moe_runtime() + @classmethod def get_diffusion_attn_backend_cls( cls, From 022e40798d6fb7727cb4c958a31366b7bb0affa8 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Thu, 12 Mar 2026 11:19:53 +0000 Subject: [PATCH 5/5] rename Signed-off-by: gcanlin --- vllm_omni/platforms/npu/models/hunyuan_fused_moe.py | 6 +++--- vllm_omni/platforms/npu/platform.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py index 274ab772443..46c76a65290 100644 --- a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py +++ b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py @@ -23,7 +23,7 @@ from vllm_omni.diffusion.forward_context import get_forward_context as omni_get_ctx -def _init_mc2_group_for_diffusion_npu( +def _init_mc2_group_for_diffusion( world_size: int, data_parallel_size: int, tensor_parallel_size: int, @@ -69,7 +69,7 @@ def prepare_hunyuan_fused_moe_runtime() -> None: tensor_parallel_size = get_tensor_model_parallel_world_size() backend = torch.distributed.get_backend(get_world_group().device_group) local_rank = get_world_group().local_rank - _init_mc2_group_for_diffusion_npu( + _init_mc2_group_for_diffusion( world_size=world_size, data_parallel_size=data_parallel_size, tensor_parallel_size=tensor_parallel_size, @@ -86,7 +86,7 @@ def prepare_hunyuan_fused_moe_runtime() -> None: _vllm_fc.ForwardContext.flash_comm_v1_enabled = False -class HunyuanFusedMoENPU(AscendSharedFusedMoE): +class AscendHunyuanFusedMoE(AscendSharedFusedMoE): def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: super().__init__(prefix=prefix, **kwargs) self._prefix = prefix diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py index e2988b9485b..bda4e4f6155 100644 --- a/vllm_omni/platforms/npu/platform.py +++ b/vllm_omni/platforms/npu/platform.py @@ -38,7 +38,7 @@ def get_default_stage_config_path(cls) -> str: @classmethod def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: if op_name == "hunyuan_fused_moe": - return "vllm_omni.platforms.npu.models.hunyuan_fused_moe.HunyuanFusedMoENPU" + return "vllm_omni.platforms.npu.models.hunyuan_fused_moe.AscendHunyuanFusedMoE" return super().get_diffusion_model_impl_qualname(op_name) @classmethod