155 changes: 155 additions & 0 deletions lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py
@@ -0,0 +1,155 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_fused_add_rmsnorm(
    original,
    residual,
    weight,
    original_stride0,
    original_stride1,
    residual_stride0,
    residual_stride1,
    N,  # number of columns (hidden size)
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    block_id = tl.program_id(0)
    # base addresses of the row handled by this program
    _original = original + block_id * original_stride0
    _residual = residual + block_id * residual_stride0

    # Single-pass path: when the whole row fits in one block, load it once and
    # avoid a second trip through global memory.
    if N <= BLOCK_SIZE:
        # column offsets of this row
        offs = tl.arange(0, BLOCK_SIZE)
        _original_offset = offs * original_stride1
        _residual_offset = offs * residual_stride1
        _weight_offset = offs

        # pointers for this row
        _original_ptr = _original + _original_offset
        _residual_ptr = _residual + _residual_offset
        _weight_ptr = weight + _weight_offset

        # load data from memory
        mask = offs < N
        original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
        residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
        weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

        # store (original + residual) to original
        original_cache = original_cache + residual_cache
        tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

        # compute variance
        var = tl.sum(original_cache * original_cache) / N
        rstd = 1 / tl.sqrt(var + eps)
        residual_cache = original_cache * rstd * weight_cache

        # store rmsnorm(original + residual) back to residual
        tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask)
    else:
        # Two-pass path for rows wider than BLOCK_SIZE: first accumulate the sum of
        # squares (writing original + residual as we go), then normalize in a second sweep.
        sum_of_squares = tl.zeros([], dtype=tl.float32)
        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this tile
            offs = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = offs * original_stride1
            _residual_offset = offs * residual_stride1

            # pointers for this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset

            # load data from memory
            mask = offs < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)

            # store (original + residual) to original
            original_cache = original_cache + residual_cache
            tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

            # accumulate sum of squares
            sum_of_squares += tl.sum(original_cache * original_cache)

        # compute variance
        var = sum_of_squares / N
        rstd = 1 / tl.sqrt(var + eps)

        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this tile
            offs = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = offs * original_stride1
            _residual_offset = offs * residual_stride1
            _weight_offset = offs

            # pointers for this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset
            _weight_ptr = weight + _weight_offset

            # reload (original + residual) and the weights
            mask = offs < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

            # apply rmsnorm using the pre-computed rstd
            original_cache = original_cache * rstd * weight_cache

            # store rmsnorm(original + residual) back to residual
            tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask)


def fused_add_rmsnorm_inplace(
    original: torch.Tensor,  # [num_tokens, hidden_size]
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
):
    """
    Perform a fused residual add + rmsnorm, in place.

    If the skip connection is H(x) = F(x) + x, then F(x) is the residual and x is the original.
    After this call, `original` holds (original + residual) and `residual` holds
    rmsnorm(original + residual).
    At the first layer, residual should be all zeros.
    """
    # reshape input data into 2D views (no copy)
    original_arg = original.view(-1, original.shape[-1])
    residual_arg = residual.view(-1, residual.shape[-1])

    assert original.data_ptr() == original_arg.data_ptr()
    assert residual.data_ptr() == residual_arg.data_ptr()

    M, N = original_arg.shape
    # Less than 64KB per feature: enqueue fused kernel (e.g. 32768 elements for bfloat16)
    MAX_FUSED_SIZE = 65536 // original.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))

    if N > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
    num_warps = triton.next_power_of_2(num_warps)
    if BLOCK_SIZE > 16384:
        BLOCK_SIZE = 16384
Comment on lines +133 to +140


critical

The check if N > BLOCK_SIZE: on line 133 disables the multi-block execution path (the else block) in the Triton kernel. The kernel is designed to handle feature dimensions N larger than BLOCK_SIZE, but this check prevents it. This makes the else branch in the kernel dead code and limits the kernel's applicability to large feature dimensions. The hard cap if BLOCK_SIZE > 16384: on line 139 is also overly restrictive. Remove the error check and the hard cap.

Suggested change
-if N > BLOCK_SIZE:
-    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-# heuristics for number of warps
-num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
-num_warps = triton.next_power_of_2(num_warps)
-if BLOCK_SIZE > 16384:
-    BLOCK_SIZE = 16384
+# heuristics for number of warps
+num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+num_warps = triton.next_power_of_2(num_warps)


    # enqueue kernel: one program per row
    _fwd_fused_add_rmsnorm[(M,)](
        original_arg,
        residual_arg,
        weight,
        original_arg.stride(0),
        original_arg.stride(1),
        residual_arg.stride(0),
        residual_arg.stride(1),
        N,  # number of columns (hidden size)
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
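For orientation, a minimal usage sketch of the in-place contract described in the wrapper's docstring. It is not part of the PR; the shapes, dtype, and eps value are illustrative, and it assumes a CUDA device with Triton available.

import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)  # original (x)
r = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)  # residual (F(x))
w = torch.ones(1024, device="cuda", dtype=torch.bfloat16)      # rmsnorm weight

fused_add_rmsnorm_inplace(x, r, w, eps=1e-6)
# After the call: x holds (x + r), and r holds rmsnorm(x + r) scaled by w.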
125 changes: 124 additions & 1 deletion lightllm/utils/custom_kernel_utis.py
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from typing import List
from typing import List, Callable


def custom_cat(tensors):
@@ -125,3 +125,126 @@ def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
    out[0:origin_batch_size, :] = input
    out[origin_batch_size:, :] = input[0:1, :]
    return out


def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor:
    """
    Compute SNR between y_pred (tensor) and y_real (tensor)

    SNR can be calculated as the following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    If x and y are matrices, the SNR error over the matrix should be the mean value of
    the SNR error over all elements.

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): _description_
        y_real (torch.Tensor): _description_
        reduction (str, optional): _description_. Defaults to 'mean'.

    Raises:
        ValueError: _description_
        ValueError: _description_

    Returns:
        torch.Tensor: _description_
    """
Comment on lines +130 to +153


medium

The return type hint for the error function is -> torch.Tensor, but the function returns a float via .item(). It should be -> float. Also, the docstring contains placeholder text and mentions a reduction parameter that doesn't exist. The formula in the docstring is for an element-wise ratio mean, but the implementation calculates a ratio of summed powers. Update the function signature and the docstring for clarity and correctness.

def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute SNR error between y_pred(tensor) and y_real(tensor).

    The SNR error is calculated as the ratio of noise power to signal power:
    `sum((y_pred - y_real)^2) / sum(y_real^2)`

    Args:
        y_pred (torch.Tensor): The predicted tensor.
        y_real (torch.Tensor): The ground truth tensor.

    Raises:
        ValueError: If tensors have different shapes.

    Returns:
        float: The computed SNR error value.
    """

    y_pred = torch.flatten(y_pred).float()
    y_real = torch.flatten(y_real).float()

    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})"
        )

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)
    return snr.item()
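A small worked example of what the implementation computes — the ratio of summed powers, rather than the element-wise mean described in the docstring; the numbers are illustrative only:

pred = torch.tensor([1.0, 2.0, 3.0])
real = torch.tensor([1.0, 2.0, 4.0])
# noise power = 0^2 + 0^2 + 1^2 = 1.0; signal power = 1 + 4 + 16 = 21.0
# so error(pred, real) ~= 1.0 / 21.0 ~= 0.0476
print(error(pred, real))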


def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A decorator function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
    4. Record the execution time during these iterations.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        function result
    """
Comment on lines +168 to +191


medium

The docstring for the benchmark function is inaccurate. It's described as a "decorator function", but it's a regular utility function. It claims to "Establish a CUDA graph", but the implementation uses torch.cuda.Event for timing, not CUDA graphs. The Returns section says it returns the "function result", but it actually returns None and only prints the metrics. Update the docstring to accurately describe the function's behavior.

def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A utility function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Execute `func` repeatedly for `steps` iterations after a warm-up period.
    4. Record the execution time during these iterations using CUDA events.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        None
    """


    # Ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for benchmarking.")

    # Check for torch.Tensor in inputs and outputs
    input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
    input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]

    def calculate_memory(tensor: torch.Tensor):
        """Calculate memory usage in bytes for a tensor."""
        return tensor.numel() * tensor.element_size()

    input_memory = sum(calculate_memory(t) for t in input_tensors)

    # Execute the function to inspect outputs
    with torch.no_grad():
        output = func(*args, **kwargs)
Comment on lines +208 to +209


critical

The initial call to func is problematic for functions that perform in-place operations, as it will modify the input tensors. All subsequent calls during the warm-up and measurement phases will then operate on this modified data, leading to incorrect benchmark results. Ensure that each timed call operates on a fresh, unmodified copy of the input data.
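One possible shape for that fix, sketched here rather than taken from the PR: snapshot the tensor arguments once, then restore them before every call so in-place operations always see the original data. The helper name `_restore_inputs` is hypothetical.

# Hypothetical sketch: snapshot tensor args once, restore them before each call.
saved_inputs = [(t, t.detach().clone()) for t in input_tensors]

def _restore_inputs():
    for live, backup in saved_inputs:
        live.copy_(backup)

# then, around every invocation of func (inspection, warm-up, and timed loop):
_restore_inputs()
output = func(*args, **kwargs)

Restoring inside the timed loop does add copy traffic to the measurement, so an alternative is to time each call individually and exclude the restore step.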


    output_memory = 0
    if isinstance(output, torch.Tensor):
        output_memory = calculate_memory(output)
    elif isinstance(output, (list, tuple)):
        output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))

    total_memory = input_memory + output_memory

    # Warm-up
    for _ in range(10):
        func(*args, **kwargs)

    torch.cuda.synchronize()  # Ensure no pending operations

    # Benchmark the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(steps):
        func(*args, **kwargs)
    end_event.record()

    torch.cuda.synchronize()  # Ensure all operations are finished
    elapsed_time_ms = start_event.elapsed_time(end_event)  # Time in milliseconds

    # Calculate performance metrics
    elapsed_time_s = elapsed_time_ms / 1000  # Convert to seconds
    avg_time_per_step = elapsed_time_s / steps
    compute_performance = tflops / avg_time_per_step  # TFLOPS
    memory_throughput = (total_memory * steps / (1024 ** 3)) / elapsed_time_s  # GB/s

    # Print performance metrics
    print(f"Function: {func.__name__}{shape}")
    print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
    print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
    print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
    print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
    print("")  # print a blank line.
63 changes: 63 additions & 0 deletions unit_tests/models/llama/test_fused_add_rmsnorm.py
@@ -0,0 +1,63 @@
import unittest
import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
from lightllm.utils.custom_kernel_utis import benchmark, error


class TestFusedAddRmsNormInplace(unittest.TestCase):
    def setUp(self):
        """Set up common test parameters."""
        self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
        self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]  # [512, 1024, 1032, 1536, 3200, 6144, 12800]


high

The dimension 32768 will cause the test to fail with the current implementation of fused_add_rmsnorm_inplace because the wrapper function raises a RuntimeError for dimensions greater than 16384. Remove or adjust this test case.

self.device = "cuda"
self.dtype = torch.bfloat16

def torch_add_rmsnorm(self, X, R, W):
X.add_(R)
return torch.nn.functional.rms_norm(X, (X.shape[1],), W, eps=1e-6)

def test_accuracy(self):
"""Test the accuracy of fused_add_rmsnorm_inplace against torch.rmsnorm."""
for token_num in self.tokens:
for dim in self.dims:
with self.subTest(shape=[token_num, dim]):
X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
_X = X.clone()
R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
_R = R.clone()
W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

r_real = self.torch_add_rmsnorm(_X, _R, W)
fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
r_pred = R
self.assertTrue(
error(r_pred, r_real) < 0.01,
f"Accuracy test failed for size {token_num}, {dim}. r_real={r_real}, r_pred={r_pred}",
)
print(f"{error(r_pred, r_real) = }")

x_real = _X
x_pred = X
self.assertTrue(
error(x_pred, x_real) < 0.01,
f"Accuracy test failed for size {token_num}, {dim}. x_real={x_real}, x_pred={x_pred}",
)
print(f"{error(x_pred, x_real) = }")

def test_performance(self):
"""Test the performance of rmsnorm using benchmark."""
for token_num in self.tokens:
for dim in self.dims:
with self.subTest(shape=[token_num, dim]):
X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

shape = [token_num, dim]
tflops = 0.0
benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)
Comment on lines +47 to +59


medium

This performance test uses the benchmark utility, which is flawed for in-place operations like torch_add_rmsnorm and fused_add_rmsnorm_inplace. The benchmark function modifies the input tensors on its first call, causing subsequent warm-up and measurement runs to use altered data. This will lead to unreliable performance results. The benchmark function in custom_kernel_utis.py needs to be fixed to handle functions with side effects correctly before this performance test can be considered reliable.



if __name__ == "__main__":
    unittest.main()