diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
index 63519cf5a58..1a3196c14ea 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
@@ -45,7 +45,9 @@ def _template_moe(
         valid_mask, selected_experts, torch.full_like(selected_experts, num_experts)
     )
     # Create one-hot encoding with an extra class.
-    one_hot = F.one_hot(selected_experts_fixed, num_classes=num_experts + 1)
+    # NOTE: `F.one_hot` only accepts `LongTensor` as an input, and will throw an error if the tensor is of another
+    # dtype, even if `torch.int32`.
+    one_hot = F.one_hot(selected_experts_fixed.long(), num_classes=num_experts + 1)
     expert_mask = one_hot[..., :num_experts].permute(2, 1, 0)
 
     for expert_idx in range(num_experts):
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index 7450d4e0393..00cde0dd314 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -222,6 +222,7 @@ def _build_model(self, device: DeviceLikeType) -> nn.Module:
         if custom_model_cls is not None:
             # `_from_config` has some behavior we would like to use where possible. It is
             # defined in the `PreTrainedModel` mixin.
+            ad_logger.info(f"Using custom model implementation {custom_model_cls}")
             if not hasattr(custom_model_cls, "_from_config"):
                 raise ValueError(
                     f"`{custom_model_cls.__name__}` must have a `_from_config` class method. "
diff --git a/tensorrt_llm/_torch/auto_deploy/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/auto_deploy/models/modeling_nemotron_h.py
index 75d3b0f199a..6a54617497e 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/modeling_nemotron_h.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/modeling_nemotron_h.py
@@ -32,6 +32,11 @@
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput
 
+from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
+    _nemotron_h_moe_forward,
+    _nemotron_h_topk_router_forward,
+)
+
 
 class MambaRMSNormGated(torch.nn.Module):
     def __init__(self, hidden_size, group_size, eps=1e-5):
@@ -261,6 +266,8 @@ def __init__(self, config, layer_idx):
             self.mixer = NemotronHAttention(config, layer_idx=layer_idx)
         elif self.block_type == "mlp":
             self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
+        elif self.block_type == "moe":
+            self.mixer = NemotronHMOE(config, layer_idx=layer_idx)
         else:
             raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")
 
@@ -277,12 +284,12 @@ def forward(self, hidden_states):
 
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int):
+    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size
+        self.intermediate_size = intermediate_size or config.intermediate_size
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
@@ -291,6 +298,50 @@ def forward(self, x):
         return self.down_proj(self.act_fn(self.up_proj(x)))
 
 
+class NemotronHMOE(nn.Module):
+    def __init__(self, config, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.experts = nn.ModuleList(
+            [
+                NemotronHMLP(
+                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                )
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = NemotronHTopkRouter(config)
+        self.shared_experts = NemotronHMLP(
+            config=config,
+            intermediate_size=config.moe_shared_expert_intermediate_size,
+            layer_idx=layer_idx,
+        )
+
+    # TODO: inline code from `_nemotron_h_moe_forward` when removing patches.
+    forward = _nemotron_h_moe_forward
+
+
+class NemotronHTopkRouter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.top_k = config.num_experts_per_tok
+        self.n_routed_experts = config.n_routed_experts
+        self.routed_scaling_factor = config.routed_scaling_factor
+        self.n_group = config.n_group
+        self.topk_group = config.topk_group
+        self.norm_topk_prob = config.norm_topk_prob
+
+        self.weight = nn.Parameter(
+            torch.empty((self.n_routed_experts, config.hidden_size), dtype=torch.float32)
+        )
+        self.register_buffer(
+            "e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32)
+        )
+
+    forward = _nemotron_h_topk_router_forward
+
+
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -544,4 +595,6 @@ def forward(
 
 
 # TODO: uncomment after removing patches (and make sure it is imported in `__init__.py`).
+# from tensorrt_llm._torch.auto_deploy.models.hf import AutoModelForCausalLMFactory
+#
 # AutoModelForCausalLMFactory.register_custom_model_cls("NemotronHConfig", NemotronHForCausalLM)
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py
index 2f861da65b5..3ef4e8eb54f 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_nemotron_h_patches.py
@@ -1,3 +1,4 @@
+import functools
 import types
 
 import pytest
@@ -5,15 +6,40 @@
 from _model_test_utils import _hf_model_dir_or_hub_id
 from transformers import AutoConfig
 
+from tensorrt_llm._torch.auto_deploy.models.modeling_nemotron_h import NemotronHForCausalLM
 from tensorrt_llm._torch.auto_deploy.models.patches.nemotron_h import (
     _from_config_original,
     _nemotron_h_moe_forward,
 )
 
-torch.manual_seed(42)
+_BATCH_AND_SEQUENCE_TEST_CASES = ((2, 6), (1, 8))
 
 
-def _load_nemotron_moe_layer(model_name_or_path: str):
+@pytest.fixture(scope="function", autouse=True)
+def set_seed():
+    torch.manual_seed(42)
+
+
+def skip_on_no_hf_access(func):
+    """Decorator for skipping tests that fail due to HF access issues.
+
+    This allows us to share the same test code for CI (where access may be restricted, especially for private
+    repositories) and locally.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except OSError as e:
+            if "not a valid model identifier" in str(e):
+                pytest.skip("Test skipped due to (no) HF access.")
+            raise
+
+    return wrapper
+
+
+def _load_nemotron_moe_layer(model_name_or_path: str, custom_model_cls=None):
     """
     Build a tiny NemotronH model (1 layer, small dims) and return the first NemotronHMOE module.
""" @@ -34,11 +60,14 @@ def _load_nemotron_moe_layer(model_name_or_path: str): cfg.num_key_value_heads = 2 cfg.ssm_state_size = 32 - model = _from_config_original(cfg, trust_remote_code=True) + if custom_model_cls is None: + model = _from_config_original(cfg, trust_remote_code=True) + else: + model = custom_model_cls._from_config(cfg) model.eval() nemotron_moe = None - for name, mod in model.named_modules(): + for _, mod in model.named_modules(): if type(mod).__name__ == "NemotronHMOE": nemotron_moe = mod break @@ -46,9 +75,22 @@ def _load_nemotron_moe_layer(model_name_or_path: str): if nemotron_moe is None: raise RuntimeError("NemotronHMOE layer not found. Check your model id or config.") + _set_gate_weights(nemotron_moe) + return nemotron_moe +def _set_gate_weights(module): + # This helper function is necessary because the `weight` parameter of the `NemotronHTopkRouter` + # is initialized as `torch.empty` in the original model code, which no manner of random seed + # setting will have any effect on. We therefore set it like the below to ensure the + # reproducibility of the tests. + for _, mod in module.named_modules(): + if type(mod).__name__ == "NemotronHTopkRouter": + if hasattr(mod, "weight"): + mod.weight = torch.nn.Parameter(torch.randn_like(mod.weight)) + + @pytest.mark.parametrize( "model_name", [ @@ -57,10 +99,11 @@ def _load_nemotron_moe_layer(model_name_or_path: str): ), ], ) -@pytest.mark.parametrize("B,S", [(2, 6), (1, 8)]) +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +@skip_on_no_hf_access def test_nemotronh_moe_patch_forward(model_name, B, S, dtype): - pytest.skip("Skipping due to permission issue") device = "cuda" module = _load_nemotron_moe_layer(model_name) @@ -69,12 +112,45 @@ def test_nemotronh_moe_patch_forward(model_name, B, S, dtype): H = module.config.hidden_size x = torch.randn(B, S, H, device=device, dtype=dtype) - with torch.no_grad(): - ref = module(x) + ref = module(x) module.forward = types.MethodType(_nemotron_h_moe_forward, module) - with torch.no_grad(): - test = module(x) + test = module(x) + + rtol = 0.05 + atol = 0.05 + + torch.testing.assert_close(test, ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize( + "model_name", + [ + _hf_model_dir_or_hub_id( + "NVIDIA-Nemotron-Nano-31B-A3-v3", "nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3" + ), + ], +) +@pytest.mark.parametrize("B,S", _BATCH_AND_SEQUENCE_TEST_CASES) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.no_grad() +@skip_on_no_hf_access +def test_nemotronh_moe_custom_implementation(model_name, B, S, dtype): + device = "cuda" + + module = _load_nemotron_moe_layer(model_name) + module.to(device) + + H = module.config.hidden_size + x = torch.randn(B, S, H, device=device, dtype=dtype) + + ref = module(x) + + new_module = _load_nemotron_moe_layer(model_name, custom_model_cls=NemotronHForCausalLM) + new_module.to(device) + new_module.load_state_dict(module.state_dict()) + + test = new_module(x) rtol = 0.05 atol = 0.05