diff --git a/tests/compile/passes/test_fusion.py b/tests/compile/passes/test_fusion.py index 2feb0bc4f787..40f07549f065 100644 --- a/tests/compile/passes/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -40,7 +40,7 @@ TritonFp8BlockScaledMMKernel, _KernelT, ) -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, create_fp8_quant_key, @@ -441,3 +441,242 @@ def test_aiter_fusion_rmsnorm_quant( _run_fusion_test( model, fusion_pass, vllm_config, dtype, hidden_size, num_tokens ) + + +class TestGatedModel(torch.nn.Module): + """Model that uses RMSNormGated + reshape + group FP8 quant + linear. + + Mimics GatedDeltaNetAttention's output projection path where: + - RMSNormGated operates on per-head tensors (N*H, D) + - Output is reshaped to (N, H*D) before group quantization + linear + """ + + def __init__( + self, + num_heads: int, + head_dim: int, + eps: float, + force_kernel: type[_KernelT], + group_shape: GroupShape, + dtype: torch.dtype, + use_aiter_quant: bool = True, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hidden_dim = num_heads * head_dim + + self.norm = RMSNormGated( + head_dim, + eps=eps, + group_size=None, + norm_before_gate=True, + ) + + self.activation_quant_key = create_fp8_quant_key( + static=False, group_shape=group_shape + ) + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(group_shape.col, group_shape.col) + ) + + self.fp8_linear = TestFP8Layer( + weight_shape=(hidden_dim, hidden_dim), + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + force_kernel=force_kernel, + transpose_weights=True, + input_dtype=dtype, + ) + self.fp8_linear.kernel.quant_fp8.use_aiter = use_aiter_quant + + def forward(self, x, z): + num_heads = self.num_heads + head_dim = self.head_dim + hidden_dim = num_heads * head_dim + x = torch.relu(x) + z = torch.relu(z) + x_heads = x.reshape(-1, num_heads, head_dim).reshape(-1, head_dim) + z_heads = z.reshape(-1, num_heads, head_dim).reshape(-1, head_dim) + normed = self.norm(x_heads, z_heads) + merged = normed.reshape(-1, hidden_dim) + out = self.fp8_linear(merged) + return out + + def ops_in_model_after(self): + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + AiterRMSNormGatedFp8GroupQuantPattern, + ) + + return [AiterRMSNormGatedFp8GroupQuantPattern.FUSED_OP] + + +class _MockGDNLayer: + """Minimal mock to populate static_forward_context for pass discovery. + + Uses __class__ assignment to pass isinstance checks against + GatedDeltaNetAttention without requiring a full config-based init. + """ + + def __init__(self, num_v_heads: int, head_v_dim: int, tp_size: int = 1): + self.num_v_heads = num_v_heads + self.head_v_dim = head_v_dim + self.tp_size = tp_size + + from vllm.model_executor.layers.mamba.gdn_linear_attn import ( + GatedDeltaNetAttention, + ) + + self.__class__ = GatedDeltaNetAttention + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_heads", [2]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("num_tokens", [8]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.skipif( + (not current_platform.is_rocm() or not IS_AITER_FOUND), + reason="Only test on ROCm with aiter package installed", +) +def test_aiter_fusion_rmsnorm_gated_quant( + dtype: torch.dtype, + num_heads: int, + head_dim: int, + num_tokens: int, + eps: float, + monkeypatch: pytest.MonkeyPatch, +): + group_shape = GroupShape(1, 128) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["-rms_norm", "-silu_and_mul", "-quant_fp8"], + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + # Register a mock GDN layer so the pass discovers num_heads/head_dim + mock_gdn = _MockGDNLayer(num_v_heads=num_heads, head_v_dim=head_dim, tp_size=1) + vllm_config.compilation_config.static_forward_context["mock_gdn_layer"] = ( + mock_gdn + ) + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + + model = TestGatedModel( + num_heads=num_heads, + head_dim=head_dim, + eps=eps, + force_kernel=AiterFp8BlockScaledMMKernel, + group_shape=group_shape, + dtype=dtype, + use_aiter_quant=True, + ) + + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + backend2 = TestBackend(noop_pass, cleanup_pass) + + hidden_dim = num_heads * head_dim + x = torch.rand(num_tokens, hidden_dim) + z = torch.rand(num_tokens, hidden_dim) + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(z, 0) + + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x, z) + + model_unfused = torch.compile(model, backend=backend2) + result_unfused = model_unfused(x, z) + + torch.testing.assert_close(result_fused, result_unfused, atol=1e-2, rtol=1e-2) + + assert fusion_pass.matched_count == 1 + backend.check_after_ops(model.ops_in_model_after()) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_heads", [2]) +@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("num_tokens", [8]) +@pytest.mark.parametrize("eps", [1e-6]) +@pytest.mark.skipif( + (not current_platform.is_rocm() or not IS_AITER_FOUND), + reason="Only test on ROCm with aiter package installed", +) +def test_aiter_fusion_rmsnorm_gated_quant_no_gdn_layers( + dtype: torch.dtype, + num_heads: int, + head_dim: int, + num_tokens: int, + eps: float, + monkeypatch: pytest.MonkeyPatch, +): + """Verify that without GDN layers in static_forward_context, + the gated pattern is not registered and no matches occur.""" + group_shape = GroupShape(1, 128) + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["-rms_norm", "-silu_and_mul", "-quant_fp8"], + pass_config=PassConfig(fuse_norm_quant=True, eliminate_noops=True), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( + RocmAiterRMSNormQuantFusionPass, + ) + + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(1) + + # No mock GDN layer registered -- pass should not register gated pattern + fusion_pass = RocmAiterRMSNormQuantFusionPass(vllm_config) + + model = TestGatedModel( + num_heads=num_heads, + head_dim=head_dim, + eps=eps, + force_kernel=AiterFp8BlockScaledMMKernel, + group_shape=group_shape, + dtype=dtype, + use_aiter_quant=True, + ) + + noop_pass = NoOpEliminationPass(vllm_config) + cleanup_pass = PostCleanupPass(vllm_config) + + backend = TestBackend(noop_pass, fusion_pass, cleanup_pass) + + hidden_dim = num_heads * head_dim + x = torch.rand(num_tokens, hidden_dim) + z = torch.rand(num_tokens, hidden_dim) + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(z, 0) + + model_fused = torch.compile(model, backend=backend) + model_fused(x, z) + + assert fusion_pass.matched_count == 0 diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index e628c97835c6..c742e2eba8f2 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -913,6 +913,50 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake( ) +def _rocm_aiter_fused_rms_gated_fp8_group_quant_impl( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + z: torch.Tensor, + eps: float, + norm_before_gate: bool, + activation: str, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused gated-RMSNorm + FP8 group quantization via aiter Triton kernel.""" + from aiter.ops.triton.quant import fused_rms_gated_fp8_group_quant + + return fused_rms_gated_fp8_group_quant( + x, + weight, + bias, + z, + eps, + norm_before_gate=norm_before_gate, + activation=activation, + out_dtype=FP8_DTYPE, + group_size=group_size, + ) + + +def _rocm_aiter_fused_rms_gated_fp8_group_quant_fake( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + z: torch.Tensor, + eps: float, + norm_before_gate: bool, + activation: str, + group_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + M, N = x.shape + scale_shape = (M, (N + group_size - 1) // group_size) + return ( + torch.empty_like(x, dtype=FP8_DTYPE, device=x.device), + torch.empty(scale_shape, dtype=torch.float32, device=x.device), + ) + + def _rocm_aiter_group_fp8_quant_impl( x: torch.Tensor, group_size: int, @@ -1480,6 +1524,9 @@ def are_gdn_triton_kernels_available(cls) -> bool: try: import aiter.ops.triton.causal_conv1d_update_single_token # noqa: F401 import aiter.ops.triton.gated_delta_net # noqa: F401 + from aiter.ops.triton.quant import ( # noqa: F401 + fused_rms_gated_fp8_group_quant, + ) return True except (ImportError, ModuleNotFoundError): @@ -1598,6 +1645,12 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_rms_gated_fp8_group_quant", + op_func=_rocm_aiter_fused_rms_gated_fp8_group_quant_impl, + fake_impl=_rocm_aiter_fused_rms_gated_fp8_group_quant_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant", op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl, @@ -1689,6 +1742,11 @@ def get_rmsnorm_fused_dynamic_quant_op() -> OpOverload: def get_rmsnorm_group_fused_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default + @staticmethod + def get_fused_rms_gated_fp8_group_quant_op() -> OpOverload: + """Return the fused gated-RMSNorm + FP8 group quant custom op.""" + return torch.ops.vllm.rocm_aiter_fused_rms_gated_fp8_group_quant.default + @staticmethod def get_rmsnorm_group_add_fused_quant_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 94f0e62204c3..9f25a6805e93 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -10,6 +10,7 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNormGated from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -28,7 +29,6 @@ ) from vllm.platforms import current_platform -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default ROTARY_OP = torch.ops._C.rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default @@ -161,6 +161,67 @@ def forward_native( return result +class MatcherRMSNormGated(MatcherCustomOp): + """Matches RMSNormGated with norm_before_gate=True and group_size=None.""" + + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + norm_before_gate: bool = True, + group_size: int | None = None, + ) -> None: + if enabled is None: + enabled = RMSNormGated.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + self.norm_before_gate = norm_before_gate + self.group_size = group_size + + def inputs(self) -> list[torch.Tensor]: + x = self.empty(5, 16) + z = self.empty(5, 16) + weight = self.empty(16) + return [x, z, weight] + + def forward_custom( + self, + x: torch.Tensor, + z: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + from vllm.model_executor.layers.fla.ops.layernorm_guard import ( + rmsnorm_fn, + ) + + return rmsnorm_fn( + x, + weight, + bias=None, + z=z, + eps=self.epsilon, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) + + def forward_native( + self, + x: torch.Tensor, + z: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return RMSNormGated.forward_static( + x, + z, + weight, + self.epsilon, + self.model_dtype, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) + + class MatcherDeepseekScalingRotaryEmbedding(MatcherCustomOp): def __init__( self, diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 9a975c5fed4d..0ee60d01a815 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -12,7 +12,7 @@ import vllm.ir.ops import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -28,9 +28,12 @@ VllmInductorPass, VllmPatternMatcherPass, VllmPatternReplacement, + _fx_view_to_reshape, + fold_consecutive_reshapes, ) from .matcher_utils import ( MatcherQuantFP8, + MatcherRMSNormGated, MatcherSiluAndMul, ) from .rms_quant_fusion import ( @@ -449,6 +452,97 @@ def trace_with_view_to_reshape(*args: Any, **kwargs: Any) -> fx.GraphModule: ) +class AiterRMSNormGatedFp8GroupQuantPattern(AiterRMSNormQuantPattern): + """ + Matches decomposed RMSNormGated + reshape + group FP8 quant and replaces + with rocm_aiter_fused_rms_gated_fp8_group_quant. + + The norm operates per-head on (N*H, D) tensors. The compiler folds the + reshape chain so after norm the result goes through reshape->merge->quant. + The pattern reshapes from (N*H, D) to (N, H*D) before calling + MatcherQuantFP8 so that _quantize_group_native sees the full hidden dim + and computes the correct num_groups. + """ + + FUSED_OP = rocm_aiter_ops.get_fused_rms_gated_fp8_group_quant_op() + + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape, + num_heads: int, + head_dim: int, + match_aiter_quant: bool = True, + symmetric: bool = True, + ) -> None: + scale = ScaleDesc(torch.float32, False, group_shape) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) + super().__init__(epsilon, key, match_aiter_quant) + self.rmsnorm_gated_matcher = MatcherRMSNormGated(epsilon) + self.num_heads = num_heads + self.head_dim = head_dim + + def register(self, pm_pass: PatternMatcherPass) -> None: + num_heads = self.num_heads + head_dim = self.head_dim + hidden_dim = num_heads * head_dim + quant_matcher = self.quant_matcher + + def pattern( + x: torch.Tensor, + z: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + normed = self.rmsnorm_gated_matcher(x, z, weight) + merged = normed.reshape(-1, hidden_dim) + quant_out, scales_out = quant_matcher(merged) + return quant_out, scales_out + + def replacement( + x: torch.Tensor, + z: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + fused = self.FUSED_OP( + x=x, + weight=weight, + bias=None, + z=z, + eps=self.epsilon, + norm_before_gate=True, + activation="silu", + group_size=head_dim, + ) + fp8_out = fused[0] + scales_out = fused[1] + fp8_reshaped = fp8_out.reshape(-1, hidden_dim) + scales_reshaped = scales_out.reshape(-1, num_heads) + return fp8_reshaped, scales_reshaped + + n_tokens = 2 + x = self.empty(n_tokens * num_heads, head_dim) + z = self.empty(n_tokens * num_heads, head_dim) + w = self.empty(head_dim) + + def trace_fn(*args, **kwargs): + gm = pm.fwd_only(*args, **kwargs) + _fx_view_to_reshape(gm) + fold_consecutive_reshapes(gm) + return gm + + pm.register_replacement( + pattern, + replacement, + [x, z, w], + trace_fn, + pm_pass, + ) + + class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass): """ This pass fuses aiter rms_norm & vllm/aiter quant custom ops @@ -464,6 +558,19 @@ def __init__(self, config: VllmConfig) -> None: pass_name="rocm_aiter_rms_norm_quant_fusion_pass" ) + # Discover (num_heads, head_dim) pairs for gated RMSNorm patterns + # from GatedDeltaNetAttention layers in static_forward_context. + from vllm.model_executor.layers.mamba.gdn_linear_attn import ( + GatedDeltaNetAttention, + ) + + gdn_layers = get_layers_from_vllm_config(config, GatedDeltaNetAttention) + gated_norm_shapes: set[tuple[int, int]] = set() + for layer in gdn_layers.values(): + gated_norm_shapes.add( + (layer.num_v_heads // layer.tp_size, layer.head_v_dim) + ) + # Make sure fused add patterns are before simple rms norm, # as the latter is a subset of the former in torch ops. # The DoubleQuant patterns handle 1 rms_norm -> 2 group_fp8_quant @@ -517,6 +624,21 @@ def __init__(self, config: VllmConfig) -> None: epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant ).register(self.patterns) + # Fuse decomposed RMSNormGated + group fp8 quant. + # The replacement op (fused_rms_gated_fp8_group_quant) requires + # an aiter version that includes the GDN triton kernel renames. + if gated_norm_shapes and rocm_aiter_ops.are_gdn_triton_kernels_available(): + for num_heads, head_dim in gated_norm_shapes: + if head_dim != 128: + continue + AiterRMSNormGatedFp8GroupQuantPattern( + epsilon, + FP8_DTYPE, + GroupShape(1, 128), + num_heads=num_heads, + head_dim=head_dim, + ).register(self.patterns) + self.dump_patterns(config, self.patterns) @VllmInductorPass.time_and_log @@ -534,6 +656,7 @@ def uuid(self) -> str: AiterFusedAddRMSFp8GroupQuantPattern, DoubleAiterRMSFp8GroupQuantPattern, DoubleAiterRMSFp8GroupQuantViewPattern, + AiterRMSNormGatedFp8GroupQuantPattern, ] return self.hash_source(self, *fusion_patterns) diff --git a/vllm/compilation/passes/vllm_inductor_pass.py b/vllm/compilation/passes/vllm_inductor_pass.py index 46c3fe770869..4f90b2a27e1b 100644 --- a/vllm/compilation/passes/vllm_inductor_pass.py +++ b/vllm/compilation/passes/vllm_inductor_pass.py @@ -252,6 +252,33 @@ def _fx_view_to_reshape(gm: fx.GraphModule) -> None: view_to_reshape(gm) +def fold_consecutive_reshapes(gm: fx.GraphModule) -> None: + """Fold consecutive reshape ops into a single reshape. + + ``make_fx`` faithfully records every view/reshape the Python code performs, + so patterns like ``x.reshape(a, b).reshape(c, d)`` produce two reshape + nodes. Inductor's own optimisation would fold these, but + ``pm.register_replacement``'s ``trace_fn`` runs before Inductor, so we + must fold them ourselves for the pattern to match the compiled graph. + + When reshape(A, shape1) feeds only into reshape(result, shape2), + the first reshape is redundant -- replace with reshape(A, shape2). + """ + aten_reshape = torch.ops.aten.reshape.default + for node in list(gm.graph.nodes): + if not is_func(node, aten_reshape): + continue + inp = node.args[0] + if not isinstance(inp, fx.Node) or not is_func(inp, aten_reshape): + continue + if len(inp.users) != 1: + continue + original_input = inp.args[0] + node.args = (original_input, node.args[1]) + inp.replace_all_uses_with(original_input) + gm.graph.erase_node(inp) + + def _remove_noop_permutes(gm: fx.GraphModule) -> None: for node in gm.graph.nodes: if not is_func(node, torch.ops.aten.permute.default): diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a5d4e4db79fe..d5671eb9c1e6 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -229,56 +229,71 @@ def __init__( def reset_parameters(self): torch.nn.init.ones_(self.weight) - def forward_native( - self, x: torch.Tensor, z: torch.Tensor | None = None + @staticmethod + def forward_static( + x: torch.Tensor, + z: torch.Tensor | None, + weight: torch.Tensor, + epsilon: float, + orig_dtype: torch.dtype, + group_size: int | None = None, + norm_before_gate: bool = True, + activation: str = "swish", ) -> torch.Tensor: - """ - Native PyTorch implementation of RMS normalization with gating. - - Args: - x: Input tensor - z: Optional gating tensor + """Pure-PyTorch RMS normalization with optional gating. - Returns: - Normalized (and optionally gated) tensor + This static method contains the full native logic so that both + ``forward_native`` and ``MatcherRMSNormGated`` (used by the + compilation pattern matcher) can share the same implementation. - If z is not None: - - norm_before_gate=True: out = norm(x) * silu(z) - - norm_before_gate=False: out = norm(x * silu(z)) + If *z* is not None and *norm_before_gate* is True: + ``out = rms_norm(x) * act(z)`` + If *z* is not None and *norm_before_gate* is False: + ``out = rms_norm(x * act(z))`` """ - orig_dtype = x.dtype x = x.float() - weight = self.weight.float() - z = z.float() if z is not None else None + weight = weight.float() + if z is not None: + z = z.float() - assert self.activation in ["silu", "sigmoid", "swish"] - act_fn = F.sigmoid if self.activation == "sigmoid" else F.silu + assert activation in ["silu", "sigmoid", "swish"] + act_fn = F.sigmoid if activation == "sigmoid" else F.silu - # Apply gating before normalization if needed - if z is not None and not self.norm_before_gate: + if z is not None and not norm_before_gate: x = x * act_fn(z) - # RMS Normalization - if self.group_size is None: - # Standard RMS norm across the last dimension + if group_size is None: variance = x.pow(2).mean(dim=-1, keepdim=True) - x_normed = x * torch.rsqrt(variance + self.eps) + x_normed = x * torch.rsqrt(variance + epsilon) out = x_normed * weight else: - # Group RMS norm from einops import rearrange - x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size) + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) variance = x_group.pow(2).mean(dim=-1, keepdim=True) - x_normed = x_group * torch.rsqrt(variance + self.eps) + x_normed = x_group * torch.rsqrt(variance + epsilon) out = rearrange(x_normed, "... g d -> ... (g d)") * weight - # Apply gating after normalization if needed - if z is not None and self.norm_before_gate: + if z is not None and norm_before_gate: out = out * act_fn(z) return out.to(orig_dtype) + def forward_native( + self, x: torch.Tensor, z: torch.Tensor | None = None + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static( + x, + z, + self.weight, + self.eps, + x.dtype, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + activation=self.activation, + ) + def forward_cuda( self, x: torch.Tensor, z: torch.Tensor | None = None ) -> torch.Tensor: