Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/flashinfer_norm_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

void rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl);

void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
bool enable_pdl);
void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, TensorView scale,
double eps, bool enable_pdl);

void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps,
bool enable_pdl);

void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
TensorView weight, double scale, double eps, bool enable_pdl);
TensorView weight, TensorView scale, double eps, bool enable_pdl);

void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps,
bool enable_pdl);
Expand Down
16 changes: 10 additions & 6 deletions csrc/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps,
}
}

void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
bool enable_pdl) {
void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, TensorView scale,
double eps, bool enable_pdl) {
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
CHECK_DEVICE(input, weight);
CHECK_DEVICE(input, scale);
CHECK_DIM(1, weight); // weight: (hidden_size)
TVM_FFI_ICHECK_EQ(scale.numel(), 1);

auto input_ndim = input.ndim();
if (input_ndim == 2) {
Expand All @@ -103,7 +105,7 @@ void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, doubl
cudaError_t status = norm::RMSNormQuant(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
static_cast<o_type*>(output.data_ptr()), batch_size, hidden_size, input.stride(0),
output.stride(0), static_cast<float>(scale), eps, enable_pdl, stream);
output.stride(0), static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RMSNormQuant failed with error code " << cudaGetErrorString(status);
return true;
Expand Down Expand Up @@ -145,14 +147,15 @@ void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight,
}

void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
TensorView weight, double scale, double eps, bool enable_pdl) {
TensorView weight, TensorView scale, double eps, bool enable_pdl) {
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
CHECK_DEVICE(input, residual);
CHECK_DEVICE(input, weight);
CHECK_DEVICE(input, output);
CHECK_DEVICE(input, scale);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
Expand All @@ -162,6 +165,7 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
TVM_FFI_ICHECK_EQ(residual.size(0), batch_size);
TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size);
TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size);
TVM_FFI_ICHECK_EQ(scale.numel(), 1);
ffi::CUDADeviceGuard device_guard(input.device().device_id);
const cudaStream_t stream = get_stream(input.device());

Expand All @@ -170,8 +174,8 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
cudaError_t status = norm::FusedAddRMSNormQuant(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), static_cast<o_type*>(output.data_ptr()),
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale,
eps, enable_pdl, stream);
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0),
static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status);
Expand Down
2 changes: 2 additions & 0 deletions docs/api/norm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ Kernels for normalization layers.
:toctree: ../generated

rmsnorm
rmsnorm_quant
fused_add_rmsnorm
fused_add_rmsnorm_quant
gemma_rmsnorm
gemma_fused_add_rmsnorm
layernorm
31 changes: 13 additions & 18 deletions flashinfer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def rmsnorm_quant(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: float,
scale: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Root mean square normalization.
) -> None:
r"""Root mean square normalization + fp8 quantization.

``out[i] = (input[i] / RMS(input)) * weight[i]``
``out[i] = ((input[i] / RMS(input)) * weight[i]).to(fp8)``

Parameters
----------
Expand All @@ -114,18 +114,13 @@ def rmsnorm_quant(
Input tensor, 2D shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
scale: float
Scale factor for quantization.
scale: torch.Tensor
Scale factor for quantization: a float32 tensor holding exactly one element (e.g. shape (1,)), on the same device as ``input``.
eps: float
Epsilon for numerical stability.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_

Returns
-------
output: torch.Tensor
Normalized tensor, 2D shape (batch_size, hidden_size).
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
Expand All @@ -137,7 +132,7 @@ def _rmsnorm_quant_fake(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: float,
scale: torch.Tensor,
eps: float,
enable_pdl: Optional[bool],
) -> None:
Expand Down Expand Up @@ -200,17 +195,17 @@ def fused_add_rmsnorm_quant(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: float,
scale: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Fused add root mean square normalization.
r"""Fused add root mean square normalization + fp8 quantization.

Step 1:
``residual[i] += input[i]``

Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
``out[i] = ((residual[i] / RMS(residual)) * weight[i]).to(fp8)``

Parameters
----------
Expand All @@ -222,8 +217,8 @@ def fused_add_rmsnorm_quant(
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
scale: float
Scale factor for quantization.
scale: torch.Tensor
Scale factor for quantization: a float32 tensor holding exactly one element (e.g. shape (1,)), on the same device as ``input``.
eps: float
Epsilon for numerical stability.
enable_pdl: bool
Expand All @@ -243,7 +238,7 @@ def _fused_add_rmsnorm_quant_fake(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: float,
scale: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ template <uint32_t VEC_SIZE, typename T, typename O>
__global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight,
O* __restrict__ output, const uint32_t d,
const uint32_t stride_input, const uint32_t stride_output,
float weight_bias, float scale, float eps) {
float weight_bias, float* scale, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
Expand All @@ -158,7 +158,7 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
const float scale_inv = 1.0f / scale;
const float scale_inv = 1.0f / scale[0];
extern __shared__ float smem[];

float sum_sq = 0.f;
Expand Down Expand Up @@ -228,7 +228,7 @@ __global__ void RMSNormQuantKernel(T* __restrict__ input, T* __restrict__ weight

template <typename T, typename O>
cudaError_t RMSNormQuant(T* input, T* weight, O* output, uint32_t batch_size, uint32_t d,
uint32_t stride_input, uint32_t stride_output, float scale,
uint32_t stride_input, uint32_t stride_output, float* scale,
float eps = 1e-5, bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

Expand Down Expand Up @@ -519,15 +519,15 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
const uint32_t d, const uint32_t stride_input,
const uint32_t stride_residual,
const uint32_t stride_output, float weight_bias,
float scale, float eps) {
float* scale, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
const float scale_inv = 1.0f / scale;
const float scale_inv = 1.0f / scale[0];
extern __shared__ float smem[];
float* smem_x = smem + ceil_div(num_warps, 4) * 4;

Expand Down Expand Up @@ -613,7 +613,7 @@ __global__ void FusedAddRMSNormQuantKernel(T* __restrict__ input, T* __restrict_
template <typename T, typename O>
cudaError_t FusedAddRMSNormQuant(T* input, T* residual, T* weight, O* output, uint32_t batch_size,
uint32_t d, uint32_t stride_input, uint32_t stride_residual,
uint32_t stride_output, float scale, float eps = 1e-5,
uint32_t stride_output, float* scale, float eps = 1e-5,
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

Expand Down
12 changes: 10 additions & 2 deletions tests/utils/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def test_norm_quant(

y_ref = llama_rms_norm_quant(x, w, quant_scale)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
flashinfer.norm.rmsnorm_quant(y, x, w, quant_scale, enable_pdl=enable_pdl)
flashinfer.norm.rmsnorm_quant(
y, x, w, torch.tensor(quant_scale, device="cuda"), enable_pdl=enable_pdl
)

torch.testing.assert_close(y_ref.float(), y.float(), rtol=1, atol=1)

Expand Down Expand Up @@ -250,7 +252,13 @@ def test_fused_add_rmsnorm_quant(
residual_fused = residual.clone()
y = torch.empty_like(x, dtype=torch.float8_e4m3fn, device="cuda")
flashinfer.norm.fused_add_rmsnorm_quant(
y, x_fused, residual_fused, weight, quant_scale, eps, enable_pdl=enable_pdl
y,
x_fused,
residual_fused,
weight,
torch.tensor(quant_scale, device="cuda"),
eps,
enable_pdl=enable_pdl,
)

torch.testing.assert_close(y.float(), x_native.float(), rtol=1, atol=1)
Expand Down
Loading