diff --git a/tests/diffusion/layers/test_norm.py b/tests/diffusion/layers/test_norm.py new file mode 100644 index 00000000000..e420415285d --- /dev/null +++ b/tests/diffusion/layers/test_norm.py @@ -0,0 +1,453 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for LayerNorm and RMSNorm custom ops in diffusion layers.""" + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +# ── Import tests ── + + +def test_layernorm_import(): + """Verify LayerNorm can be imported from the norm module.""" + from vllm_omni.diffusion.layers.norm import LayerNorm # noqa: F401 + + +def test_rmsnorm_import(): + """Verify RMSNorm can be imported from the norm module.""" + from vllm_omni.diffusion.layers.norm import RMSNorm # noqa: F401 + + +# ── LayerNorm tests ── + + +def test_layernorm_forward_shape(): + """LayerNorm produces correct output shapes.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + batch = 2 + seq_len = 4 + norm = LayerNorm(dim) + + x = torch.randn(batch, seq_len, dim) + out = norm(x) + + assert out.shape == (batch, seq_len, dim) + + +def test_layernorm_forward_shape_2d(): + """LayerNorm works with 2D input tensors.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + batch = 2 + norm = LayerNorm(dim) + + x = torch.randn(batch, dim) + out = norm(x) + + assert out.shape == (batch, dim) + + +def test_layernorm_preserves_dtype_fp32(): + """LayerNorm preserves float32 dtype.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + + x = torch.randn(2, 4, dim, dtype=torch.float32) + out = norm(x) + + assert out.dtype == torch.float32 + + +def test_layernorm_preserves_dtype_fp16(): + """LayerNorm preserves float16 dtype.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + + x = torch.randn(2, 4, dim, dtype=torch.float16) + out = norm(x) + + assert out.dtype == torch.float16 + + +def test_layernorm_preserves_dtype_bf16(): + """LayerNorm preserves bfloat16 dtype.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + + x = torch.randn(2, 4, dim, dtype=torch.bfloat16) + out = norm(x) + + assert out.dtype == torch.bfloat16 + + +def test_layernorm_without_elementwise_affine(): + """LayerNorm works without elementwise_affine (no learned parameters).""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim, elementwise_affine=False) + + assert norm.weight is None + assert norm.bias is None + + x = torch.randn(2, 4, dim) + out = norm(x) + + assert out.shape == (2, 4, dim) + + +def test_layernorm_custom_eps(): + """LayerNorm accepts custom epsilon value.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + eps = 1e-5 + norm = LayerNorm(dim, eps=eps) + + assert norm.eps == eps + + +def test_layernorm_has_learnable_parameters(): + """LayerNorm has learnable weight and bias by default.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + + assert norm.weight is not None + assert norm.bias is not None + assert norm.weight.shape == (dim,) + assert norm.bias.shape == (dim,) + + +def test_layernorm_matches_fp32_reference(): + """Verify LayerNorm produces identical output to FP32 nn.LayerNorm.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + eps = 1e-6 + torch.manual_seed(42) + + ours = LayerNorm(dim, eps=eps) + ref = torch.nn.LayerNorm(dim, eps=eps) + + # Copy weights + ref.weight.data.copy_(ours.weight.data) + ref.bias.data.copy_(ours.bias.data) + + x = torch.randn(2, 4, dim) + + out_ours = ours(x) + out_ref = ref(x.float()).to(x.dtype) + + torch.testing.assert_close(out_ours, out_ref, atol=1e-5, rtol=1e-5) + + +def test_layernorm_matches_diffusers_fp32layernorm(): + """Verify LayerNorm produces identical output to diffusers FP32LayerNorm.""" + from diffusers.models.normalization import FP32LayerNorm + + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + eps = 1e-6 + torch.manual_seed(42) + + ours = LayerNorm(dim, eps=eps) + ref = FP32LayerNorm(dim, eps=eps) + + # Copy weights + ref.weight.data.copy_(ours.weight.data) + ref.bias.data.copy_(ours.bias.data) + + # Test with fp16 input to verify FP32 computation + x = torch.randn(2, 4, dim, dtype=torch.float16) + + out_ours = ours(x) + out_ref = ref(x) + + torch.testing.assert_close(out_ours, out_ref, atol=1e-3, rtol=1e-3) + + +# ── RMSNorm tests ── + + +def test_rmsnorm_forward_shape(): + """RMSNorm produces correct output shapes.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + batch = 2 + seq_len = 4 + norm = RMSNorm(hidden_size) + + x = torch.randn(batch, seq_len, hidden_size) + out = norm(x) + + assert out.shape == (batch, seq_len, hidden_size) + + +def test_rmsnorm_forward_shape_2d(): + """RMSNorm works with 2D input tensors.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + batch = 2 + norm = RMSNorm(hidden_size) + + x = torch.randn(batch, hidden_size) + out = norm(x) + + assert out.shape == (batch, hidden_size) + + +def test_rmsnorm_preserves_dtype_fp32(): + """RMSNorm preserves float32 dtype.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + + x = torch.randn(2, 4, hidden_size, dtype=torch.float32) + out = norm(x) + + assert out.dtype == torch.float32 + + +def test_rmsnorm_preserves_dtype_fp16(): + """RMSNorm preserves float16 dtype.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + + x = torch.randn(2, 4, hidden_size, dtype=torch.float16) + out = norm(x) + + assert out.dtype == torch.float16 + + +def test_rmsnorm_preserves_dtype_bf16(): + """RMSNorm preserves bfloat16 dtype.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + + x = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16) + out = norm(x) + + assert out.dtype == torch.bfloat16 + + +def test_rmsnorm_custom_eps(): + """RMSNorm accepts custom epsilon value.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + eps = 1e-5 + norm = RMSNorm(hidden_size, eps=eps) + + assert norm.variance_epsilon == eps + + +def test_rmsnorm_has_weight_parameter(): + """RMSNorm has learnable weight parameter initialized to ones.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + + assert norm.weight is not None + assert norm.weight.shape == (hidden_size,) + torch.testing.assert_close(norm.weight, torch.ones(hidden_size)) + + +def test_rmsnorm_numerical_correctness(): + """Verify RMSNorm produces numerically correct output.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + eps = 1e-6 + torch.manual_seed(42) + + norm = RMSNorm(hidden_size, eps=eps) + x = torch.randn(2, 4, hidden_size) + + # Compute expected output manually + x_fp32 = x.to(torch.float32) + variance = x_fp32.pow(2).mean(-1, keepdim=True) + expected = x_fp32 * torch.rsqrt(variance + eps) + expected = norm.weight.to(torch.float32) * expected + expected = expected.to(x.dtype) + + out = norm(x) + + torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5) + + +def test_rmsnorm_matches_reference_implementation(): + """Verify RMSNorm matches a reference implementation.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + def reference_rmsnorm(x, weight, eps): + """Reference RMSNorm implementation.""" + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + out = x * torch.rsqrt(variance + eps) + out = weight.to(torch.float32) * out + return out.to(input_dtype) + + hidden_size = 128 + eps = 1e-6 + torch.manual_seed(123) + + norm = RMSNorm(hidden_size, eps=eps) + + # Test with various dtypes + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + x = torch.randn(4, 8, hidden_size, dtype=dtype) + expected = reference_rmsnorm(x, norm.weight, eps) + out = norm(x) + torch.testing.assert_close(out, expected, atol=1e-3, rtol=1e-3) + + +# ── CustomOp dispatch tests ── + + +def test_layernorm_inherits_from_customop(): + """LayerNorm inherits from CustomOp for platform dispatch.""" + from vllm_omni.diffusion.layers.custom_op import CustomOp + from vllm_omni.diffusion.layers.norm import LayerNorm + + norm = LayerNorm(64) + assert isinstance(norm, CustomOp) + + +def test_rmsnorm_inherits_from_customop(): + """RMSNorm inherits from CustomOp for platform dispatch.""" + from vllm_omni.diffusion.layers.custom_op import CustomOp + from vllm_omni.diffusion.layers.norm import RMSNorm + + norm = RMSNorm(64) + assert isinstance(norm, CustomOp) + + +def test_layernorm_has_platform_methods(): + """LayerNorm has forward methods for each platform.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + norm = LayerNorm(64) + + assert hasattr(norm, "forward_cuda") + assert hasattr(norm, "forward_hip") + assert hasattr(norm, "forward_xpu") + assert hasattr(norm, "forward_npu") + assert hasattr(norm, "forward_native") + + +def test_rmsnorm_has_platform_methods(): + """RMSNorm has forward methods for each platform.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + norm = RMSNorm(64) + + assert hasattr(norm, "forward_cuda") + assert hasattr(norm, "forward_hip") + assert hasattr(norm, "forward_xpu") + assert hasattr(norm, "forward_npu") + assert hasattr(norm, "forward_native") + + +def test_layernorm_forward_native_directly(): + """LayerNorm.forward_native can be called directly.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + x = torch.randn(2, 4, dim) + + out = norm.forward_native(x) + + assert out.shape == (2, 4, dim) + + +def test_rmsnorm_forward_native_directly(): + """RMSNorm.forward_native can be called directly.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + x = torch.randn(2, 4, hidden_size) + + out = norm.forward_native(x) + + assert out.shape == (2, 4, hidden_size) + + +# ── Edge case tests ── + + +def test_layernorm_with_large_dim(): + """LayerNorm works with large hidden dimensions.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 4096 + norm = LayerNorm(dim) + x = torch.randn(1, 16, dim) + + out = norm(x) + + assert out.shape == (1, 16, dim) + + +def test_rmsnorm_with_large_dim(): + """RMSNorm works with large hidden dimensions.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 4096 + norm = RMSNorm(hidden_size) + x = torch.randn(1, 16, hidden_size) + + out = norm(x) + + assert out.shape == (1, 16, hidden_size) + + +def test_layernorm_with_single_element_batch(): + """LayerNorm works with batch size of 1.""" + from vllm_omni.diffusion.layers.norm import LayerNorm + + dim = 64 + norm = LayerNorm(dim) + x = torch.randn(1, 1, dim) + + out = norm(x) + + assert out.shape == (1, 1, dim) + + +def test_rmsnorm_with_single_element_batch(): + """RMSNorm works with batch size of 1.""" + from vllm_omni.diffusion.layers.norm import RMSNorm + + hidden_size = 64 + norm = RMSNorm(hidden_size) + x = torch.randn(1, 1, hidden_size) + + out = norm(x) + + assert out.shape == (1, 1, hidden_size) diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py index 4d70ed52f71..d147bdcfeb6 100644 --- a/vllm_omni/diffusion/layers/adalayernorm.py +++ b/vllm_omni/diffusion/layers/adalayernorm.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm_omni.diffusion.layers.custom_op import CustomOp +from vllm_omni.diffusion.layers.norm import LayerNorm if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import QuantizationConfig @@ -27,7 +28,7 @@ def __init__(self, hidden_size: int, elementwise_affine: bool = False, eps: floa self.eps = eps self.elementwise_affine = elementwise_affine self.hidden_size = hidden_size - self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps) + self.layernorm = LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps) def forward_cuda( self, diff --git a/vllm_omni/diffusion/layers/norm.py b/vllm_omni/diffusion/layers/norm.py new file mode 100644 index 00000000000..6096ad7c370 --- /dev/null +++ b/vllm_omni/diffusion/layers/norm.py @@ -0,0 +1,110 @@ +from importlib.util import find_spec + +import torch +import torch.nn as nn +import torch.nn.functional as F +from vllm.logger import init_logger + +from vllm_omni.diffusion.layers.custom_op import CustomOp + +logger = init_logger(__name__) + +_HAS_MINDIESD = find_spec("mindiesd") is not None + + +class LayerNorm(nn.LayerNorm, CustomOp): + """ + LayerNorm implementation that inherits from both ``nn.LayerNorm`` and ``CustomOp``. + NPU: + Uses ``mindiesd.fast_layernorm(self, x)`` when MindIE-SD is installed. + CUDA / HIP / XPU / native: + Falls back to FP32 nn.LayerNorm implementation. + """ + + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True): + super().__init__(normalized_shape=dim, eps=eps, elementwise_affine=elementwise_affine) + # CustomOp.__init__ cannot be called here because it would re-run + # nn.Module initialization and clear LayerNorm parameters. + self._forward_method = CustomOp.dispatch_forward(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._forward_method(x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_hip(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + def forward_npu(self, x: torch.Tensor) -> torch.Tensor: + if _HAS_MINDIESD: + try: + from mindiesd import fast_layernorm + + return fast_layernorm(self, x) + except ImportError as e: + logger.warning_once( + "mindiesd.fast_layernorm import failed, falling back to FP32 layer_norm: %s", + e, + ) + + return self.forward_native(x) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + origin_dtype = x.dtype + return F.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class RMSNorm(CustomOp): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_hip( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_npu( + self, + x: torch.Tensor, + ) -> torch.Tensor: + import torch_npu + + output = torch_npu.npu_rms_norm(x, gamma=self.weight, epsilon=self.variance_epsilon)[0] + + return output + + def forward_xpu( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return self.forward_native(x) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + out = x * torch.rsqrt(variance + self.variance_epsilon) + out = self.weight.to(torch.float32) * out + return out.to(input_dtype) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index b870193a140..d4d81b78eb8 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -11,7 +11,6 @@ from diffusers.models.attention import FeedForward from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps from diffusers.models.modeling_outputs import Transformer2DModelOutput -from diffusers.models.normalization import FP32LayerNorm from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -30,6 +29,7 @@ ) from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm +from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -236,9 +236,9 @@ class WanImageEmbedding(nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None): super().__init__() - self.norm1 = FP32LayerNorm(in_features) + self.norm1 = LayerNorm(in_features) self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") - self.norm2 = FP32LayerNorm(out_features) + self.norm2 = LayerNorm(out_features) if pos_embed_seq_len is not None: self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) else: @@ -378,8 +378,12 @@ def __init__( self.tp_inner_dim = self.num_heads * head_dim # QK normalization using vLLM's RMSNorm - self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) - self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps) self.to_out = RowParallelLinear( self.inner_dim, @@ -498,8 +502,12 @@ def __init__( self.tp_inner_dim = self.num_heads * head_dim # QK normalization - self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) - self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps) + self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps) # Optional added KV projections for I2V (image embeddings) self.added_kv_proj_dim = added_kv_proj_dim @@ -518,7 +526,10 @@ def __init__( gather_output=False, return_bias=False, ) - self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + if get_tensor_model_parallel_world_size() > 1: + self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) + else: + self.norm_added_k = RMSNorm(self.tp_inner_dim, eps=eps) else: self.add_k_proj = None self.add_v_proj = None @@ -637,7 +648,7 @@ def __init__( eps=eps, added_kv_proj_dim=added_kv_proj_dim, ) - self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. Feed-forward self.ffn = WanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim)