🐛 Describe the bug
As the title says, the numerical difference is very large after compiling LigerLayerNorm with torch.compile: the compiled module's output is all zeros (see the statistics below).
=== Numerical Differences ===
Liger vs PyTorch LayerNorm:
  Max diff: 0.01562500
  Mean diff: 0.00000447
Liger compiled vs Liger:
  Max diff: 4.90625000
  Mean diff: 0.79687500
Liger compiled vs PyTorch:
  Max diff: 4.90625000
  Mean diff: 0.79687500

=== Output Statistics ===
PyTorch        : mean=-0.000007, std=1.000000
Liger          : mean=-0.000007, std=1.000000
Liger+Compile  : mean=0.000000, std=0.000000
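Note that mean=0.000000 with std=0.000000 means the compiled module returns an all-zeros tensor, not just a numerically noisier result. One way to probe this (a debugging sketch on my side, not part of the report above) is to compile with fullgraph=True, which turns any graph break around the custom Triton kernel into a hard error instead of a silent fallback:

# Debugging sketch (assumption: same shapes/dtype as the repro below).
# fullgraph=True makes torch.compile raise on graph breaks, which helps tell
# "kernel traced incorrectly" apart from "graph break + broken fallback".
import torch
from liger_kernel.transformers import LigerLayerNorm

ln = LigerLayerNorm(768).to("cuda").to(torch.bfloat16)
ln_fullgraph = torch.compile(ln, fullgraph=True)

x = torch.randn(4, 128, 768, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    out = ln_fullgraph(x)  # an error here would point at a graph break
# For a correct LayerNorm output we would expect mean ~0.0 and std ~1.0
print(out.float().mean().item(), out.float().std().item())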
Reproduce
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple test script for liger_kernel LayerNorm vs PyTorch LayerNorm in bfloat16
Tests: LigerLayerNorm, LigerLayerNorm+compile, PyTorch LayerNorm
"""
import os
import warnings

# Suppress warnings
os.environ["TRITON_PRINT_AUTOTUNING"] = "0"
os.environ["TRITON_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn


def test_bfloat16_comparison():
    """Test numerical differences in bfloat16"""
    try:
        from liger_kernel.transformers import LigerLayerNorm
        print("✓ Successfully imported LigerLayerNorm")
    except ImportError as e:
        print(f"✗ Failed to import LigerLayerNorm: {e}")
        return

    # Test parameters
    batch_size, seq_len, hidden_size = 4, 128, 768
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    print(f"Device: {device}, dtype: {dtype}")
    print(f"Shape: ({batch_size}, {seq_len}, {hidden_size})")

    # Create test input
    torch.manual_seed(42)
    x = torch.randn(batch_size, seq_len, hidden_size, device=device, dtype=dtype)

    # Create models
    torch_ln = nn.LayerNorm(hidden_size).to(device).to(dtype)
    liger_ln = LigerLayerNorm(hidden_size).to(device).to(dtype)
    liger_ln_compiled = torch.compile(liger_ln)

    # Forward pass
    with torch.no_grad():
        torch_out = torch_ln(x)
        liger_out = liger_ln(x)
        liger_compiled_out = liger_ln_compiled(x)

    # Calculate differences
    diff_liger_vs_torch = torch.abs(liger_out - torch_out)
    diff_compiled_vs_liger = torch.abs(liger_compiled_out - liger_out)
    diff_compiled_vs_torch = torch.abs(liger_compiled_out - torch_out)

    print("\n=== Numerical Differences ===")
    print("Liger vs PyTorch LayerNorm:")
    print(f"  Max diff: {diff_liger_vs_torch.max().item():.8f}")
    print(f"  Mean diff: {diff_liger_vs_torch.mean().item():.8f}")
    print("Liger compiled vs Liger:")
    print(f"  Max diff: {diff_compiled_vs_liger.max().item():.8f}")
    print(f"  Mean diff: {diff_compiled_vs_liger.mean().item():.8f}")
    print("Liger compiled vs PyTorch:")
    print(f"  Max diff: {diff_compiled_vs_torch.max().item():.8f}")
    print(f"  Mean diff: {diff_compiled_vs_torch.mean().item():.8f}")

    # Output statistics
    print("\n=== Output Statistics ===")
    for name, out in [("PyTorch", torch_out), ("Liger", liger_out), ("Liger+Compile", liger_compiled_out)]:
        print(f"{name:15}: mean={out.mean().item():.6f}, std={out.std().item():.6f}")


if __name__ == "__main__":
    test_bfloat16_comparison()
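As an interim workaround (my own sketch, not an official recommendation, and I have not verified it sidesteps this bug in a full model), the Liger module can be kept out of the compiled graph with the documented torch.compiler.disable decorator, so its Triton kernel runs eagerly while the rest of the model is still compiled:

# Workaround sketch: exclude the Liger call from Dynamo tracing.
import torch
from liger_kernel.transformers import LigerLayerNorm

ln = LigerLayerNorm(768).to("cuda").to(torch.bfloat16)

@torch.compiler.disable
def ln_eager(x):
    # Runs the Liger Triton kernel outside the compiled graph.
    return ln(x)

def model(x):
    # Stand-in for the rest of the model, which remains compiled.
    return ln_eager(x) * 2.0

compiled = torch.compile(model)
x = torch.randn(4, 128, 768, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    out = compiled(x)
print(out.float().std().item())  # should be ~2.0 here, not 0.0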
Versions
Environment Report:
Operating System: Linux-5.4.250-2-velinux1u1-amd64-x86_64-with-glibc2.35
Python version: 3.12.2
Liger Kernel version: 0.6.2
PyTorch version: 2.8.0+cu128
CUDA version: 12.8
HIP(ROCm) version: Not available
Triton version: 3.4.0
Transformers version: 4.55.0
XPU version: XPU Not Available