diff --git a/lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py b/lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py
new file mode 100644
index 000000000..51853c8a9
--- /dev/null
+++ b/lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py
@@ -0,0 +1,155 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_fused_add_rmsnorm(
+    original,
+    residual,
+    weight,
+    original_stride0,
+    original_stride1,
+    residual_stride0,
+    residual_stride1,
+    N,  # number of columns in X
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_id = tl.program_id(0)
+    # base address of this program's row
+    _original = original + block_id * original_stride0
+    _residual = residual + block_id * residual_stride0
+
+    # Single-pass path: when the whole row fits into one block, every element
+    # is loaded from global memory exactly once, which performs better for
+    # fairly large rows than the looped path below.
+    if N <= BLOCK_SIZE:
+        # column offsets of this block (named `cols` to avoid shadowing the
+        # builtin `range`, which the looped path below still needs)
+        cols = tl.arange(0, BLOCK_SIZE)
+        _original_offset = cols * original_stride1
+        _residual_offset = cols * residual_stride1
+        _weight_offset = cols
+
+        # pointers of this block
+        _original_ptr = _original + _original_offset
+        _residual_ptr = _residual + _residual_offset
+        _weight_ptr = weight + _weight_offset
+
+        # load data from memory
+        mask = cols < N
+        original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
+        residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
+        weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)
+
+        # store (original + residual) to original
+        original_cache = original_cache + residual_cache
+        tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)
+
+        # compute variance
+        var = tl.sum(original_cache * original_cache) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        residual_cache = original_cache * rstd * weight_cache
+
+        # store rmsnorm(original + residual) back to residual
+        tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask)
+    else:
+        # Two-pass path: the first loop writes (original + residual) back to
+        # original while accumulating the sum of squares; the second loop
+        # applies the normalization with the precomputed rstd.
+        sum_of_squares = tl.zeros([], dtype=tl.float32)
+        for block_offset in range(0, N, BLOCK_SIZE):
+            # column offsets of this block
+            cols = tl.arange(0, BLOCK_SIZE) + block_offset
+            _original_offset = cols * original_stride1
+            _residual_offset = cols * residual_stride1
+
+            # pointers of this block
+            _original_ptr = _original + _original_offset
+            _residual_ptr = _residual + _residual_offset
+
+            # load data from memory
+            mask = cols < N
+            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
+            residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
+
+            # store (original + residual) to original
+            original_cache = original_cache + residual_cache
+            tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)
+
+            # accumulate the sum of squares
+            sum_of_squares += tl.sum(original_cache * original_cache)
+
+        # compute variance
+        var = sum_of_squares / N
+        rstd = 1 / tl.sqrt(var + eps)
+
+        for block_offset in range(0, N, BLOCK_SIZE):
+            # column offsets of this block
+            cols = tl.arange(0, BLOCK_SIZE) + block_offset
+            _original_offset = cols * original_stride1
+            _residual_offset = cols * residual_stride1
+            _weight_offset = cols
+
+            # pointers of this block
+            _original_ptr = _original + _original_offset
+            _residual_ptr = _residual + _residual_offset
+            _weight_ptr = weight + _weight_offset
+
+            # load data from memory
+            mask = cols < N
+            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
+            weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)
+
+            # apply rmsnorm using the precomputed rstd
+            original_cache = original_cache * rstd * weight_cache
+
+            # store rmsnorm(original + residual) back to residual
+            tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask)
+
+
+def fused_add_rmsnorm_inplace(
+    original: torch.Tensor,  # [num_tokens, hidden_size]
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+):
+    """
+    Perform a fused add & rmsnorm.
+
+    Suppose the skip-connection result is H(x) = F(x) + x; then F(x) is the
+    residual and x is the original. On return, original holds
+    (residual + original) and residual holds rmsnorm(residual + original).
+    At the first layer, residual should be all zeros.
+    """
+    # reshape the inputs into 2D tensors
+    original_arg = original.view(-1, original.shape[-1])
+    residual_arg = residual.view(-1, residual.shape[-1])
+
+    assert original.data_ptr() == original_arg.data_ptr()
+    assert residual.data_ptr() == residual_arg.data_ptr()
+
+    M, N = original_arg.shape
+    # less than 64KB per feature: enqueue the fused kernel
+    MAX_FUSED_SIZE = 65536 // original.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+
+    if N > BLOCK_SIZE:
+        raise RuntimeError("This rms norm doesn't support feature dims >= 64KB.")
+
+    # heuristics for the number of warps
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
+    num_warps = triton.next_power_of_2(num_warps)
+    if BLOCK_SIZE > 16384:
+        BLOCK_SIZE = 16384
+
+    # enqueue kernel
+    _fwd_fused_add_rmsnorm[(M,)](
+        original_arg,
+        residual_arg,
+        weight,
+        original_arg.stride(0),
+        original_arg.stride(1),
+        residual_arg.stride(0),
+        residual_arg.stride(1),
+        N,  # number of columns in X
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
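
For reference, the in-place contract of the wrapper can be expressed in a few lines of plain PyTorch. The sketch below is illustrative only (the helper name `add_rmsnorm_reference` is not part of this PR); it mirrors the kernel's float32 arithmetic:

# Hypothetical pure-PyTorch reference for fused_add_rmsnorm_inplace:
# after the call, `original` holds (original + residual) and `residual`
# holds rmsnorm(original + residual), both computed in float32.
import torch

def add_rmsnorm_reference(original, residual, weight, eps=1e-6):
    s = original.float() + residual.float()
    rstd = torch.rsqrt(s.pow(2).mean(dim=-1, keepdim=True) + eps)
    original.copy_(s.to(original.dtype))
    residual.copy_((s * rstd * weight.float()).to(residual.dtype))
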
diff --git a/lightllm/utils/custom_kernel_utis.py b/lightllm/utils/custom_kernel_utis.py
index 9a7578a24..cf7d26adf 100644
--- a/lightllm/utils/custom_kernel_utis.py
+++ b/lightllm/utils/custom_kernel_utis.py
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from typing import List
+from typing import List, Callable
 
 
 def custom_cat(tensors):
@@ -125,3 +125,126 @@ def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
     out[0:origin_batch_size, :] = input
     out[origin_batch_size:, :] = input[0:1, :]
     return out
+
+
+def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
+    """
+    Compute the SNR error between y_pred and y_real.
+
+    For a single element the SNR error is calculated as:
+
+        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
+
+    For tensors, both inputs are flattened and the error is the ratio of the
+    total noise power to the total signal power:
+
+        SNR(pred, real) = sum((pred - real) ^ 2) / sum(real ^ 2)
+
+    Args:
+        y_pred (torch.Tensor): the predicted (approximate) values.
+        y_real (torch.Tensor): the reference values.
+
+    Raises:
+        ValueError: if the two tensors have different shapes.
+
+    Returns:
+        float: the SNR error.
+    """
+    y_pred = torch.flatten(y_pred).float()
+    y_real = torch.flatten(y_real).float()
+
+    if y_pred.shape != y_real.shape:
+        raise ValueError(
+            f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})"
+        )
+
+    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
+    signal_power = torch.pow(y_real, 2).sum(dim=-1)
+    snr = noise_power / (signal_power + 1e-7)
+    return snr.item()
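
As a quick sanity check of the metric, a small relative perturbation should give a small SNR error. The numbers below are illustrative, not taken from a real run:

# Hypothetical usage of error(): compare a tensor against a slightly
# perturbed copy; the SNR error is roughly the squared relative noise.
import torch
from lightllm.utils.custom_kernel_utis import error

exact = torch.randn(1024)
approx = exact + 1e-3 * torch.randn(1024)
print(error(approx, exact))  # on the order of 1e-6
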
+
+
+def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
+    """
+    A helper function to assist in performance testing of CUDA operations.
+
+    This function will:
+    1. Collect every argument of `func`, and every output of `func`, that is
+       of type `torch.Tensor`.
+    2. Compute the GPU memory traffic of those input and output tensors
+       (based on their data type and `torch.numel()`).
+    3. Warm up, then execute `func` repeatedly for `steps` iterations.
+    4. Record the execution time of these iterations with CUDA events.
+    5. Use the information above to report compute performance (TFLOPS) and
+       memory throughput (GB/s).
+
+    Args:
+        func (Callable): The function to benchmark.
+        shape (List[int]): The problem shape, used only to label the output.
+        tflops (float): The computational workload (in TFLOPs) per call of `func`.
+        steps (int): The number of times the function is executed during benchmarking.
+        *args: Positional arguments to be passed to `func`.
+        **kwargs: Keyword arguments to be passed to `func`.
+    """
+
+    # Ensure CUDA is available
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for benchmarking.")
+
+    # Collect torch.Tensor inputs
+    input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
+    input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]
+
+    def calculate_memory(tensor: torch.Tensor):
+        """Calculate the memory usage in bytes of a tensor."""
+        return tensor.numel() * tensor.element_size()
+
+    input_memory = sum(calculate_memory(t) for t in input_tensors)
+
+    # Execute the function once to inspect its outputs
+    with torch.no_grad():
+        output = func(*args, **kwargs)
+
+    output_memory = 0
+    if isinstance(output, torch.Tensor):
+        output_memory = calculate_memory(output)
+    elif isinstance(output, (list, tuple)):
+        output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))
+
+    total_memory = input_memory + output_memory
+
+    # Warm-up
+    for _ in range(10):
+        func(*args, **kwargs)
+
+    torch.cuda.synchronize()  # Ensure no pending operations
+
+    # Benchmark the function
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    start_event.record()
+    for _ in range(steps):
+        func(*args, **kwargs)
+    end_event.record()
+
+    torch.cuda.synchronize()  # Ensure all operations are finished
+    elapsed_time_ms = start_event.elapsed_time(end_event)  # Time in milliseconds
+
+    # Calculate performance metrics
+    elapsed_time_s = elapsed_time_ms / 1000  # Convert to seconds
+    avg_time_per_step = elapsed_time_s / steps
+    compute_performance = tflops / avg_time_per_step  # TFLOPS
+    memory_throughput = (total_memory * steps / (1024 ** 3)) / elapsed_time_s  # GB/s
+
+    # Print performance metrics
+    print(f"Function: {func.__name__}{shape}")
+    print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
+    print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
+    print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
+    print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
+    print("")  # print a blank line
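
A usage sketch for `benchmark` (the matmul workload and its FLOP count are illustrative assumptions, not part of this PR):

# Hypothetical benchmark() call: time a 4096x4096 fp16 matmul for 100 steps.
# FLOPs per call = 2 * M * N * K = 2 * 4096**3, i.e. about 0.137 TFLOPs.
import torch
from lightllm.utils.custom_kernel_utis import benchmark

a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
benchmark(torch.matmul, [4096, 4096, 4096], 2 * 4096**3 / 1e12, 100, a, b)
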
diff --git a/unit_tests/models/llama/test_fused_add_rmsnorm.py b/unit_tests/models/llama/test_fused_add_rmsnorm.py
new file mode 100644
index 000000000..7999876a9
--- /dev/null
+++ b/unit_tests/models/llama/test_fused_add_rmsnorm.py
@@ -0,0 +1,63 @@
+import unittest
+import torch
+from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
+from lightllm.utils.custom_kernel_utis import benchmark, error
+
+
+class TestFusedAddRmsNormInplace(unittest.TestCase):
+    def setUp(self):
+        """Set up common test parameters."""
+        self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
+        self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768]
+        self.device = "cuda"
+        self.dtype = torch.bfloat16
+
+    def torch_add_rmsnorm(self, X, R, W):
+        X.add_(R)
+        return torch.nn.functional.rms_norm(X, (X.shape[1],), W, eps=1e-6)
+
+    def test_accuracy(self):
+        """Test the accuracy of fused_add_rmsnorm_inplace against the torch reference."""
+        for token_num in self.tokens:
+            for dim in self.dims:
+                with self.subTest(shape=[token_num, dim]):
+                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
+                    _X = X.clone()
+                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
+                    _R = R.clone()
+                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)
+
+                    r_real = self.torch_add_rmsnorm(_X, _R, W)
+                    fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
+                    r_pred = R
+                    self.assertTrue(
+                        error(r_pred, r_real) < 0.01,
+                        f"Accuracy test failed for size {token_num}, {dim}. r_real={r_real}, r_pred={r_pred}",
+                    )
+                    print(f"{error(r_pred, r_real) = }")
+
+                    x_real = _X
+                    x_pred = X
+                    self.assertTrue(
+                        error(x_pred, x_real) < 0.01,
+                        f"Accuracy test failed for size {token_num}, {dim}. x_real={x_real}, x_pred={x_pred}",
+                    )
+                    print(f"{error(x_pred, x_real) = }")
+
+    def test_performance(self):
+        """Test the performance of fused_add_rmsnorm_inplace using benchmark."""
+        for token_num in self.tokens:
+            for dim in self.dims:
+                with self.subTest(shape=[token_num, dim]):
+                    X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
+                    R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
+                    W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)
+
+                    shape = [token_num, dim]
+                    tflops = 0.0
+                    benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
+                    benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()
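
The new tests assume a CUDA device is available; since the file defines a `__main__` entry point, they can be run directly or through unittest, e.g.:

python -m unittest unit_tests.models.llama.test_fused_add_rmsnorm -v
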