torch.compile changes LigerLayerNorm's output #871

@techkang

Description

🐛 Describe the bug

As shown in the title, the numerical difference is very large after compiling LigerLayerNorm with torch.compile.
=== Numerical Differences ===
Liger vs PyTorch LayerNorm:
  Max diff: 0.01562500
  Mean diff: 0.00000447
Liger compiled vs Liger:
  Max diff: 4.90625000
  Mean diff: 0.79687500
Liger compiled vs PyTorch:
  Max diff: 4.90625000
  Mean diff: 0.79687500

=== Output Statistics ===
PyTorch        : mean=-0.000007, std=1.000000
Liger          : mean=-0.000007, std=1.000000
Liger+Compile  : mean=0.000000, std=0.000000
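
The zero mean and zero std suggest the compiled module returns an all-zero tensor, not merely a rounded result. A quick sanity check (a sketch, reusing liger_compiled_out from the script below; this check is not part of the original run) would confirm it:

# Hypothetical check: count nonzero elements of the compiled output.
# 0 would mean the kernel's result is lost entirely, pointing at a broken
# kernel launch under compile rather than ordinary bfloat16 rounding error.
print(torch.count_nonzero(liger_compiled_out).item())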

Reproduce

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple test script for liger_kernel LayerNorm vs PyTorch LayerNorm in bfloat16
Tests: LigerLayerNorm, LigerLayerNorm+compile, PyTorch LayerNorm
"""

import os
import warnings

# Suppress warnings
os.environ["TRITON_PRINT_AUTOTUNING"] = "0"
os.environ["TRITON_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn

def test_bfloat16_comparison():
    """Test numerical differences in bfloat16"""
    
    try:
        from liger_kernel.transformers import LigerLayerNorm
        print("✓ Successfully imported LigerLayerNorm")
    except ImportError as e:
        print(f"✗ Failed to import LigerLayerNorm: {e}")
        return
    
    # Test parameters
    batch_size, seq_len, hidden_size = 4, 128, 768
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    
    print(f"Device: {device}, dtype: {dtype}")
    print(f"Shape: ({batch_size}, {seq_len}, {hidden_size})")
    
    # Create test input
    torch.manual_seed(42)
    x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)
    
    # Create models
    torch_ln = nn.LayerNorm(hidden_size).to(device).to(dtype)
    liger_ln = LigerLayerNorm(hidden_size).to(device).to(dtype)
    liger_ln_compiled = torch.compile(liger_ln)
    
    # Forward pass
    with torch.no_grad():
        torch_out = torch_ln(x)
        liger_out = liger_ln(x)
        liger_compiled_out = liger_ln_compiled(x)
    
    # Calculate differences
    diff_liger_vs_torch = torch.abs(liger_out - torch_out)
    diff_compiled_vs_liger = torch.abs(liger_compiled_out - liger_out)
    diff_compiled_vs_torch = torch.abs(liger_compiled_out - torch_out)
    
    print(f"\n=== Numerical Differences ===")
    print(f"Liger vs PyTorch LayerNorm:")
    print(f"  Max diff: {diff_liger_vs_torch.max().item():.8f}")
    print(f"  Mean diff: {diff_liger_vs_torch.mean().item():.8f}")
    
    print(f"Liger compiled vs Liger:")
    print(f"  Max diff: {diff_compiled_vs_liger.max().item():.8f}")
    print(f"  Mean diff: {diff_compiled_vs_liger.mean().item():.8f}")
    
    print(f"Liger compiled vs PyTorch:")
    print(f"  Max diff: {diff_compiled_vs_torch.max().item():.8f}")
    print(f"  Mean diff: {diff_compiled_vs_torch.mean().item():.8f}")
    
    # Output statistics
    print(f"\n=== Output Statistics ===")
    for name, out in [("PyTorch", torch_out), ("Liger", liger_out), ("Liger+Compile", liger_compiled_out)]:
        print(f"{name:15}: mean={out.mean().item():.6f}, std={out.std().item():.6f}")

if __name__ == "__main__":
    test_bfloat16_comparison()
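
As a possible workaround until the root cause is found (a sketch, assuming the failure comes from torch.compile tracing into Liger's custom Triton kernel), the module's forward can be excluded from Dynamo with torch.compiler.disable so it runs eagerly even inside an otherwise compiled model:

import torch
from liger_kernel.transformers import LigerLayerNorm

liger_ln = LigerLayerNorm(768).to("cuda").to(torch.bfloat16)
# Keep Dynamo from tracing this forward; the surrounding model can still be
# compiled, with a graph break around this eager call.
liger_ln.forward = torch.compiler.disable(liger_ln.forward)
liger_ln_compiled = torch.compile(liger_ln)  # the LayerNorm body now runs eagerly

Whether this restores correct output here is untested; it only isolates the Triton kernel from compilation.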

Versions

Environment Report:

Operating System: Linux-5.4.250-2-velinux1u1-amd64-x86_64-with-glibc2.35
Python version: 3.12.2
Liger Kernel version: 0.6.2
PyTorch version: 2.8.0+cu128
CUDA version: 12.8
HIP(ROCm) version: Not available
Triton version: 3.4.0
Transformers version: 4.55.0
XPU version: XPU Not Available
