add rmsnorm-add fusion kernel #996
base: main
@@ -0,0 +1,155 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_fused_add_rmsnorm(
    original,
    residual,
    weight,
    original_stride0,
    original_stride1,
    residual_stride0,
    residual_stride1,
    N,  # number of columns in X
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    block_id = tl.program_id(0)
    # base address of this block's row
    _original = original + block_id * original_stride0
    _residual = residual + block_id * residual_stride0

    # single-block path: avoids repeated loads from gmem to smem,
    # which gives better performance for some very large sizes
    if N <= BLOCK_SIZE:
        # column offsets of this block
        cols = tl.arange(0, BLOCK_SIZE)
        _original_offset = cols * original_stride1
        _residual_offset = cols * residual_stride1
        _weight_offset = cols

        # data pointers of this block
        _original_ptr = _original + _original_offset
        _residual_ptr = _residual + _residual_offset
        _weight_ptr = weight + _weight_offset

        # load data from memory
        mask = cols < N
        original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
        residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
        weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

        # store (original + residual) to original
        original_cache = original_cache + residual_cache
        tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

        # compute variance
        var = tl.sum(original_cache * original_cache) / N
        rstd = 1 / tl.sqrt(var + eps)
        residual_cache = original_cache * rstd * weight_cache

        # store rmsnorm(original + residual) back to residual
        tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask)
    else:
        sum_of_squares = tl.zeros([], dtype=tl.float32)
        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this block
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1

            # data pointers of this block
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset

            # load data from memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)

            # store (original + residual) to original
            original_cache = original_cache + residual_cache
            tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

            # accumulate the sum of squares
            sum_of_squares += tl.sum(original_cache * original_cache)

        # compute variance
        var = sum_of_squares / N
        rstd = 1 / tl.sqrt(var + eps)

        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this block
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1
            _weight_offset = cols

            # data pointers of this block
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset
            _weight_ptr = weight + _weight_offset

            # load data from memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

            # apply rmsnorm using the pre-computed rstd
            original_cache = original_cache * rstd * weight_cache

            # store rmsnorm(original + residual) back to residual
            tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask)


def fused_add_rmsnorm_inplace(
    original: torch.Tensor,  # [num_tokens, hidden_size]
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
):
    """
    Perform a fused add & rmsnorm.
    Suppose the skip-connection result is H(x) = F(x) + x;
    then F(x) is the residual and x is the original.
    On return, original holds (residual + original) and residual holds
    rmsnorm(residual + original).
    At the first layer, residual should be all zeros.
    """
    # reshape input data into a 2D tensor
    original_arg = original.view(-1, original.shape[-1])
    residual_arg = residual.view(-1, residual.shape[-1])

    assert original.data_ptr() == original_arg.data_ptr()
    assert residual.data_ptr() == residual_arg.data_ptr()

    M, N = original_arg.shape
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // original.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))

    if N > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # heuristics for the number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
    num_warps = triton.next_power_of_2(num_warps)
    if BLOCK_SIZE > 16384:
        BLOCK_SIZE = 16384

    # enqueue kernel
    _fwd_fused_add_rmsnorm[(M,)](
        original_arg,
        residual_arg,
        weight,
        original_arg.stride(0),
        original_arg.stride(1),
        residual_arg.stride(0),
        residual_arg.stride(1),
        N,  # number of columns in X
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from typing import List
from typing import List, Callable


def custom_cat(tensors):

@@ -125,3 +125,126 @@ def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
    out[0:origin_batch_size, :] = input
    out[origin_batch_size:, :] = input[0:1, :]
    return out
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor:
    """
    Compute SNR between y_pred(tensor) and y_real(tensor)

    SNR can be calculated with the following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    If x and y are matrices, the SNR error over a matrix is the mean SNR error over all elements:

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): _description_
        y_real (torch.Tensor): _description_
        reduction (str, optional): _description_. Defaults to 'mean'.

    Raises:
        ValueError: _description_
        ValueError: _description_

    Returns:
        torch.Tensor: _description_
    """
Comment on lines +130 to +153

The return type hint for the `error` function is `torch.Tensor`, but the function returns a Python `float` via `.item()`. The docstring is also out of date: it documents a `reduction` parameter that does not exist and still contains `_description_` placeholders. Consider:

def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute SNR error between y_pred(tensor) and y_real(tensor).

    The SNR error is calculated as the ratio of noise power to signal power:

        sum((y_pred - y_real)^2) / sum(y_real^2)

    Args:
        y_pred (torch.Tensor): The predicted tensor.
        y_real (torch.Tensor): The ground truth tensor.

    Raises:
        ValueError: If tensors have different shapes.

    Returns:
        float: The computed SNR error value.
    """
    y_pred = torch.flatten(y_pred).float()
    y_real = torch.flatten(y_real).float()

    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})"
        )

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)
    return snr.item()
def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A decorator function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
    4. Record the execution time during these iterations.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        function result
    """
Comment on lines +168 to +191

The docstring for the `benchmark` function does not match its behavior: it is not a decorator, it does not establish a CUDA graph (it uses warm-up iterations and CUDA events), and it returns `None`, not the function result. Consider:

def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
    """
    A utility function to assist in performance testing of CUDA operations.

    This function will:
    1. Automatically determine whether any parameters in the argument list,
       or the output of the `func`, are of type `torch.Tensor`.
    2. If so, calculate the memory usage of the input and output tensors
       on the GPU (based on their data type and `torch.numel()`).
    3. Execute `func` repeatedly for `steps` iterations after a warm-up period.
    4. Record the execution time during these iterations using CUDA events.
    5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

    Args:
        func (function): The function to benchmark.
        shape (list of int): The problem shape.
        tflops (float): The computational workload (in TFLOPS) per call of `func`.
        steps (int): The number of times the function is executed during benchmarking.
        *args: Positional arguments to be passed to the `func`.
        **kwargs: Keyword arguments to be passed to the `func`.

    Returns:
        None
    """
    # Ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for benchmarking.")

    # Check for torch.Tensor in inputs and outputs
    input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
    input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]

    def calculate_memory(tensor: torch.Tensor):
        """Calculate memory usage in bytes for a tensor."""
        return tensor.numel() * tensor.element_size()

    input_memory = sum(calculate_memory(t) for t in input_tensors)

    # Execute the function to inspect outputs
    with torch.no_grad():
        output = func(*args, **kwargs)
Comment on lines +208 to +209

The initial call to `func` is used only to inspect the output for memory accounting, but it also executes the operation once outside the timed loop; for in-place operations this mutates the input tensors before benchmarking begins.
    output_memory = 0
    if isinstance(output, torch.Tensor):
        output_memory = calculate_memory(output)
    elif isinstance(output, (list, tuple)):
        output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))

    total_memory = input_memory + output_memory

    # Warm-up and CUDA graph creation
    for _ in range(10):  # Warm-up
        func(*args, **kwargs)

    torch.cuda.synchronize()  # Ensure no pending operations

    # Benchmark the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(steps):
        func(*args, **kwargs)
    end_event.record()

    torch.cuda.synchronize()  # Ensure all operations are finished
    elapsed_time_ms = start_event.elapsed_time(end_event)  # Time in milliseconds

    # Calculate performance metrics
    elapsed_time_s = elapsed_time_ms / 1000  # Convert to seconds
    avg_time_per_step = elapsed_time_s / steps
    compute_performance = tflops / avg_time_per_step  # TFLOPS
    memory_throughput = (total_memory * steps / (1024 ** 3)) / elapsed_time_s  # GB/s

    # Print performance metrics
    print(f"Function: {func.__name__}{shape}")
    print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
    print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
    print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
    print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
    print("")  # print a blank line.
@@ -0,0 +1,63 @@
import unittest
import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
from lightllm.utils.custom_kernel_utis import benchmark, error


class TestFusedAddRmsNormInplace(unittest.TestCase):
    def setUp(self):
        """Set up common test parameters."""
        self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
        self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]  # [512, 1024, 1032, 1536, 3200, 6144, 12800]
        self.device = "cuda"
        self.dtype = torch.bfloat16

    def torch_add_rmsnorm(self, X, R, W):
        X.add_(R)
        return torch.nn.functional.rms_norm(X, (X.shape[1],), W, eps=1e-6)

    def test_accuracy(self):
        """Test the accuracy of fused_add_rmsnorm_inplace against torch rms_norm."""
        for token_num in self.tokens:
            for dim in self.dims:
                with self.subTest(shape=[token_num, dim]):
                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    _X = X.clone()
                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    _R = R.clone()
                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                    r_real = self.torch_add_rmsnorm(_X, _R, W)
                    fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
                    r_pred = R
                    self.assertTrue(
                        error(r_pred, r_real) < 0.01,
                        f"Accuracy test failed for size {token_num}, {dim}. r_real={r_real}, r_pred={r_pred}",
                    )
                    print(f"{error(r_pred, r_real) = }")

                    x_real = _X
                    x_pred = X
                    self.assertTrue(
                        error(x_pred, x_real) < 0.01,
                        f"Accuracy test failed for size {token_num}, {dim}. x_real={x_real}, x_pred={x_pred}",
                    )
                    print(f"{error(x_pred, x_real) = }")

    def test_performance(self):
        """Test the performance of rmsnorm using benchmark."""
        for token_num in self.tokens:
            for dim in self.dims:
                with self.subTest(shape=[token_num, dim]):
                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

                    shape = [token_num, dim]
                    tflops = 0.0
                    benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
                    benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)
Comment on lines +47 to +59

This performance test uses the same input tensors for all 100 benchmark iterations. Both `torch_add_rmsnorm` and `fused_add_rmsnorm_inplace` modify `X` and `R` in place, so later iterations no longer run on representative data.
if __name__ == "__main__":
    unittest.main()
The check `if N > BLOCK_SIZE:` on line 133 disables the multi-block execution path (the `else` block) in the Triton kernel. The kernel is designed to handle feature dimensions `N` larger than `BLOCK_SIZE`, but this check prevents it. This makes the `else` branch in the kernel dead code and limits the kernel's applicability to large feature dimensions. The hard cap `if BLOCK_SIZE > 16384:` on line 139 is also overly restrictive. Remove the error check and the hard cap.