diff --git a/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py
index 088b1ce415..e4eb04c9fc 100644
--- a/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py
+++ b/aiter/ops/triton/_triton_kernels/fused_fp8_quant.py
@@ -31,6 +31,112 @@ def _fp8_quant_op(
     return x, scale_out
 
 
+@triton.jit
+def _fused_rms_fp8_per_tensor_static_quant_kernel(
+    inp1_ptr,
+    weight1_ptr,
+    inp2_ptr,
+    weight2_ptr,
+    res1_ptr,
+    out1_fp8_ptr,
+    out2_ptr,
+    out_res1_ptr,
+    out1_ptr,
+    scale_ptr,
+    eps1,
+    eps2,
+    n_rows,
+    inp1_n_cols,
+    inp2_n_cols,
+    inp1_row_stride,
+    inp2_row_stride,
+    inp1_col_stride,
+    inp2_col_stride,
+    res1_row_stride,
+    res1_col_stride,
+    out1_fp8_row_stride,
+    out1_fp8_col_stride,
+    out2_row_stride,
+    out2_col_stride,
+    out_res1_row_stride,
+    out_res1_col_stride,
+    out1_row_stride,
+    out1_col_stride,
+    BLOCK_SIZE_N: tl.constexpr,
+    DTYPE_MAX: tl.constexpr,
+    DTYPE_MIN: tl.constexpr,
+    HAVE_SECOND_INPUT: tl.constexpr,
+    FIRST_INPUT_RES: tl.constexpr,
+    FIRST_INPUT_OUT: tl.constexpr,
+):
+    m_pid = tl.program_id(0)  # one program per row; BLOCK_SIZE_N must cover both row widths
+    n_offs = tl.arange(0, BLOCK_SIZE_N)
+
+    mask1 = n_offs < inp1_n_cols
+    inp1 = tl.load(
+        inp1_ptr + m_pid * inp1_row_stride + n_offs * inp1_col_stride,
+        mask=mask1,
+        other=0.0,
+        cache_modifier=".cg",
+    ).to(tl.float32)
+
+    if FIRST_INPUT_RES:
+        res1 = tl.load(
+            res1_ptr + m_pid * res1_row_stride + n_offs * res1_col_stride,
+            mask=mask1,
+            other=0.0,
+            cache_modifier=".cg",
+        ).to(tl.float32)
+        inp1 = inp1 + res1
+
+    w1 = tl.load(weight1_ptr + n_offs, mask=mask1, other=0.0).to(tl.float32)
+    norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1)
+
+    if FIRST_INPUT_OUT:
+        mask1 = n_offs < inp1_n_cols
+        tl.store(
+            out1_ptr + m_pid * out1_row_stride + n_offs * out1_col_stride,
+            norm1,
+            mask=mask1,
+        )
+
+    # apply quantization (per-tensor static: single precomputed scale)
+    scale = tl.load(scale_ptr).to(tl.float32)
+    scale_recip = 1.0 / scale
+    out1_fp8 = tl.clamp(norm1 * scale_recip, DTYPE_MIN, DTYPE_MAX)
+
+    # store the results
+    tl.store(
+        out1_fp8_ptr + m_pid * out1_fp8_row_stride + n_offs * out1_fp8_col_stride,
+        out1_fp8.to(out1_fp8_ptr.dtype.element_ty),
+        mask=mask1,
+    )
+
+    if HAVE_SECOND_INPUT:
+        mask2 = n_offs < inp2_n_cols
+        inp2 = tl.load(
+            inp2_ptr + m_pid * inp2_row_stride + n_offs * inp2_col_stride,
+            mask=mask2,
+            other=0.0,
+            cache_modifier=".cg",
+        ).to(tl.float32)
+        w2 = tl.load(weight2_ptr + n_offs, mask=mask2, other=0.0).to(tl.float32)
+        norm2 = _rmsmorm_op(inp2, w2, inp2_n_cols, eps2)
+        tl.store(
+            out2_ptr + m_pid * out2_row_stride + n_offs * out2_col_stride,
+            norm2,
+            mask=mask2,
+        )
+
+    if FIRST_INPUT_RES:
+        inp1 = inp1.to(out_res1_ptr.dtype.element_ty)  # inp1 holds inp1 + res1 here
+        tl.store(
+            out_res1_ptr + m_pid * out_res1_row_stride + n_offs * out_res1_col_stride,
+            inp1,
+            mask=mask1,
+        )
+
+
 @triton.jit
 def _fused_rms_fp8_group_quant_kernel(
     inp1_ptr,
diff --git a/aiter/ops/triton/fused_fp8_quant.py b/aiter/ops/triton/fused_fp8_quant.py
index 39e1c58777..709a2d4f23 100644
--- a/aiter/ops/triton/fused_fp8_quant.py
+++ b/aiter/ops/triton/fused_fp8_quant.py
@@ -3,6 +3,7 @@ import triton
 import aiter
 from aiter.ops.triton._triton_kernels.fused_fp8_quant import (
+    _fused_rms_fp8_per_tensor_static_quant_kernel,
     _fused_rms_fp8_group_quant_kernel,
     _fused_flatten_fp8_group_quant_kernel,
     _fused_reduce_act_mul_fp8_group_quant,
@@ -18,6 +19,141 @@
 fp8_dtype = aiter.dtypes.fp8
 
 
+def fused_rms_fp8_per_tensor_static_quant(
+    inp1,
+    inp1_weight,
+    inp1_epsilon,
+    inp1_scale,
+    inp2=None,
+    inp2_weight=None,
+    inp2_epsilon=None,
+    dtype_quant=fp8_dtype,
+    res1=None,
+    output_unquantized_inp1=False,
+):
+    """
+    This op contains several steps:
+    1. if res1 is not None, inp1 = inp1 + res1, and store inp1 to out_res1
+    2. perform RMS norm along the last dimension for inp1
+    3. if inp2 is not None, perform RMS norm along the last dimension for inp2
+    4. perform fp8 quantization for inp1 only
+
+    Key parameters:
+    - inp1: Matrix with shape (M, N1); inp2 (optional) has shape (M, N2).
+
+    Returns:
+    - out1_fp8: The fp8-quantized RMS norm of inp1, shape (M, N1).
+    - out1: The unquantized RMS norm of inp1, shape (M, N1), or None.
+    - out2: The RMS norm of inp2, shape (M, N2), or None.
+    - out_res1: inp1 + res1, shape (M, N1), or None.
+    - Optional outputs are None unless the corresponding argument
+      (output_unquantized_inp1 / inp2 / res1) was provided.
+    """
+    M, N1 = inp1.shape
+    BLOCK_SIZE_N = triton.next_power_of_2(N1)
+    if inp2 is not None:
+        M2, N2 = inp2.shape
+        BLOCK_SIZE_N = max(BLOCK_SIZE_N, triton.next_power_of_2(N2))  # cover BOTH widths; plain next_power_of_2(N2) truncated inp1 when N2 < N1
+        assert (
+            M == M2
+        ), "The leading dimension should be identical between inp1 and inp2"
+    else:
+        N2 = 0
+    out1_fp8 = torch.empty((M, N1), dtype=dtype_quant, device=inp1.device)
+
+    out2 = None
+    out2_row_stride = 0
+    out2_col_stride = 0
+    inp2_row_stride = 0
+    inp2_col_stride = 0
+    if inp2 is not None:
+        out2 = torch.empty((M, N2), dtype=inp1.dtype, device=inp1.device)
+        inp2_row_stride = inp2.stride(0)
+        inp2_col_stride = inp2.stride(1)
+        out2_row_stride = out2.stride(0)
+        out2_col_stride = out2.stride(1)
+
+    out1 = None
+    out1_row_stride = 0
+    out1_col_stride = 0
+    if output_unquantized_inp1:
+        out1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device)
+        out1_row_stride = out1.stride(0)
+        out1_col_stride = out1.stride(1)  # fixed: was out2.stride(1) — wrong tensor, crashes when inp2 is None
+
+    out_res1 = None
+    res1_row_stride = 0
+    res1_col_stride = 0
+    out_res1_row_stride = 0
+    out_res1_col_stride = 0
+    if res1 is not None:
+        Mr, Nr = res1.shape
+        assert (
+            M == Mr and N1 == Nr
+        ), "The shape should be identical between inp1 and res1"
+        out_res1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device)
+        res1_row_stride = res1.stride(0)
+        res1_col_stride = res1.stride(1)
+        out_res1_row_stride = out_res1.stride(0)
+        out_res1_col_stride = out_res1.stride(1)
+
+    if BLOCK_SIZE_N <= 512:
+        num_warps = 1
+    elif BLOCK_SIZE_N <= 2048:
+        num_warps = 4
+    elif BLOCK_SIZE_N <= 4096:
+        num_warps = 8
+    else:
+        num_warps = 16
+
+    DTYPE_MAX = (
+        torch.finfo(out1_fp8.dtype).max
+        if torch.is_floating_point(out1_fp8)
+        else torch.iinfo(out1_fp8.dtype).max
+    )
+
+    _fused_rms_fp8_per_tensor_static_quant_kernel[(M,)](
+        inp1,
+        inp1_weight,
+        inp2,
+        inp2_weight,
+        res1,
+        out1_fp8,
+        out2,
+        out_res1,
+        out1,
+        inp1_scale,
+        inp1_epsilon,
+        inp2_epsilon if inp2_epsilon is not None else 0.0,  # kernel scalar args must not be None
+        M,
+        N1,
+        N2,
+        inp1.stride(0),
+        inp2_row_stride,
+        inp1.stride(1),
+        inp2_col_stride,
+        res1_row_stride,
+        res1_col_stride,
+        out1_fp8.stride(0),
+        out1_fp8.stride(1),
+        out2_row_stride,
+        out2_col_stride,
+        out_res1_row_stride,
+        out_res1_col_stride,
+        out1_row_stride,
+        out1_col_stride,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        DTYPE_MAX=DTYPE_MAX,
+        DTYPE_MIN=-DTYPE_MAX,
+        HAVE_SECOND_INPUT=(inp2 is not None),
+        FIRST_INPUT_RES=(res1 is not None),
+        FIRST_INPUT_OUT=output_unquantized_inp1,
+        num_warps=num_warps,
+    )
+
+    return out1_fp8, out1, out2, out_res1
+
+
 def fused_rms_fp8_group_quant(
     inp1,
     inp1_weight,
diff --git a/op_tests/triton_tests/test_fused_fp8_quant.py b/op_tests/triton_tests/test_fused_fp8_quant.py
index 9621756f4f..c0ebaeae16 100644
--- a/op_tests/triton_tests/test_fused_fp8_quant.py
+++ b/op_tests/triton_tests/test_fused_fp8_quant.py
@@ -1,6 +1,7 @@
 import torch
 import pytest
 from aiter.ops.triton.fused_fp8_quant import (
+    fused_rms_fp8_per_tensor_static_quant,
     fused_rms_fp8_group_quant,
     fused_flatten_fp8_group_quant,
     fused_reduce_act_mul_fp8_group_quant,
@@ -37,6 +38,13 @@ def per_token_fp8_group_quant(x, dtype_quant, group_size=128):
     return x_quant, x_scale
 
 
+def per_tensor_fp8_static_quant(x, dtype_quant, x_scale):
+    DTYPE_MAX = torch.finfo(dtype_quant).max
+    scale_recip = 1.0 / x_scale
+    x_quant = torch.clamp(x * scale_recip, -DTYPE_MAX, DTYPE_MAX).to(dtype_quant)
+    return x_quant
+
+
 def upcast(x, s, dtype, group_size=128):
     x_N = x.shape[1]
     x = x.reshape(-1, x_N // group_size, group_size).to(torch.float32) * s.reshape(
@@ -65,6 +73,54 @@ def generate_fused_rms_quant_data(M, N1, N2, dtype=torch.bfloat16):
     return x1, w1, x2, w2, res1
 
 
+def run_torch_rms_fp8_per_tensor_static_quant(
+    x1, w1, eps1, x2, w2, eps2, res1, dtype_quant, x1_scale
+):
+    s = x1 + res1
+    y1 = rmsnorm(s, w1, eps1)
+    y2 = rmsnorm(x2, w2, eps2)
+    y1_q = per_tensor_fp8_static_quant(y1, dtype_quant, x1_scale)
+    return y1_q, y1.to(x1.dtype), y2.to(x1.dtype), s.to(x1.dtype)
+
+
+@pytest.mark.parametrize("M", [1, 32, 256])
+@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 128), (7168, 7168)])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_rms_fp8_per_tensor_static_quant(M: int, N1: int, N2: int, dtype):
+    dtype_quant = aiter.dtypes.fp8
+    scale = torch.rand(1, dtype=torch.float32, device="cuda") + 0.1  # positive, bounded away from 0
+    x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype)
+
+    y1_q_torch, y1_torch, y2_torch, y1_res_torch = (
+        run_torch_rms_fp8_per_tensor_static_quant(
+            x1, w1, 1e-6, x2, w2, 1e-6, res1, dtype_quant, scale
+        )
+    )
+
+    y1_q_triton, y1_triton, y2_triton, y1_res_triton = (
+        fused_rms_fp8_per_tensor_static_quant(
+            x1,
+            w1,
+            1e-6,
+            scale,
+            inp2=x2,
+            inp2_weight=w2,
+            inp2_epsilon=1e-6,
+            dtype_quant=dtype_quant,
+            res1=res1,
+            output_unquantized_inp1=True,
+        )
+    )
+
+    torch.testing.assert_close(y1_torch, y1_triton, atol=0.1, rtol=0.1)
+    torch.testing.assert_close(y2_torch, y2_triton, atol=0.1, rtol=0.1)
+    torch.testing.assert_close(y1_res_torch, y1_res_triton, atol=0.1, rtol=0.1)
+
+    y1_upcast_torch = y1_q_torch.to(torch.float32) * scale
+    y1_upcast_triton = y1_q_triton.to(torch.float32) * scale
+    torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1)
+
+
 @pytest.mark.parametrize("M", [1, 32, 256])
 @pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])