Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/flashinfer_norm_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

void rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl);

void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
bool enable_pdl);
void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, TensorView scale,
double eps, bool enable_pdl);

void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps,
bool enable_pdl);

void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
TensorView weight, double scale, double eps, bool enable_pdl);
TensorView weight, TensorView scale, double eps, bool enable_pdl);

void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps,
bool enable_pdl);
Expand Down
16 changes: 10 additions & 6 deletions csrc/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps,
}
}

void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
bool enable_pdl) {
void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, TensorView scale,
double eps, bool enable_pdl) {
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
CHECK_DEVICE(input, weight);
CHECK_DEVICE(input, scale);
CHECK_DIM(1, weight); // weight: (hidden_size)
TVM_FFI_ICHECK_EQ(scale.numel(), 1);

auto input_ndim = input.ndim();
if (input_ndim == 2) {
Expand All @@ -103,7 +105,7 @@ void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, doubl
cudaError_t status = norm::RMSNormQuant(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
static_cast<o_type*>(output.data_ptr()), batch_size, hidden_size, input.stride(0),
output.stride(0), static_cast<float>(scale), eps, enable_pdl, stream);
output.stride(0), static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "RMSNormQuant failed with error code " << cudaGetErrorString(status);
return true;
Expand Down Expand Up @@ -145,14 +147,15 @@ void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight,
}

void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
TensorView weight, double scale, double eps, bool enable_pdl) {
TensorView weight, TensorView scale, double eps, bool enable_pdl) {
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
CHECK_DEVICE(input, residual);
CHECK_DEVICE(input, weight);
CHECK_DEVICE(input, output);
CHECK_DEVICE(input, scale);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
Expand All @@ -162,6 +165,7 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
TVM_FFI_ICHECK_EQ(residual.size(0), batch_size);
TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size);
TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size);
TVM_FFI_ICHECK_EQ(scale.numel(), 1);
ffi::CUDADeviceGuard device_guard(input.device().device_id);
const cudaStream_t stream = get_stream(input.device());

Expand All @@ -170,8 +174,8 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
cudaError_t status = norm::FusedAddRMSNormQuant(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), static_cast<o_type*>(output.data_ptr()),
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale,
eps, enable_pdl, stream);
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0),
static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status);
Expand Down
2 changes: 2 additions & 0 deletions docs/api/norm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ Kernels for normalization layers.
:toctree: ../generated

rmsnorm
rmsnorm_quant
fused_add_rmsnorm
fused_add_rmsnorm_quant
gemma_rmsnorm
gemma_fused_add_rmsnorm
layernorm
7 changes: 5 additions & 2 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
try:
from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
except (ImportError, AttributeError):
pass # nvidia-cutlass-dsl not installed
Comment on lines +115 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Expose quantized norm APIs at the package level.

Line 97-101 exports rmsnorm and fused_add_rmsnorm, but the new quantized variants (rmsnorm_quant, fused_add_rmsnorm_quant) from flashinfer.norm are still missing at the top level. Consider exporting them here so flashinfer.rmsnorm_quant works consistently.

✅ Suggested export additions
 from .norm import fused_add_rmsnorm as fused_add_rmsnorm
+from .norm import fused_add_rmsnorm_quant as fused_add_rmsnorm_quant
 from .norm import layernorm as layernorm
 from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
+from .norm import rmsnorm_quant as rmsnorm_quant

As per coding guidelines: Export new operations in flashinfer/__init__.py to make them available at package level.

🤖 Prompt for AI Agents
In `@flashinfer/__init__.py` around lines 103 - 107, The package-level exports for
the quantized norm variants are missing: import rmsnorm_fp4quant and
add_rmsnorm_fp4quant from flashinfer.norm (already attempted in the try block)
and then assign them to the public names used elsewhere (e.g., expose
rmsnorm_fp4quant as rmsnorm_quant and add_rmsnorm_fp4quant as
fused_add_rmsnorm_quant) so flashinfer.rmsnorm_quant and
flashinfer.fused_add_rmsnorm_quant resolve; update the try block in
flashinfer.__init__.py to perform these assignments (keep the existing
ImportError/AttributeError handling).

from .page import append_paged_kv_cache as append_paged_kv_cache
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
from .page import get_batch_indices_positions as get_batch_indices_positions
Expand Down
31 changes: 31 additions & 0 deletions flashinfer/cute_dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@
AddRMSNormFP4QuantKernel,
)

# Backwards-compatible re-exports from flashinfer.norm.kernels submodule
from ..norm.kernels import (
# Kernel classes
RMSNormKernel,
QKRMSNormKernel,
RMSNormQuantKernel,
FusedAddRMSNormKernel,
FusedAddRMSNormQuantKernel,
LayerNormKernel,
# Python API functions
rmsnorm_cute,
qk_rmsnorm_cute,
rmsnorm_quant_cute,
fused_add_rmsnorm_cute,
fused_add_rmsnorm_quant_cute,
layernorm_cute,
)

__all__ = [
# Utils (always available)
"is_cute_dsl_available",
Expand All @@ -79,4 +97,17 @@
# Add + RMSNorm + FP4 Quantization
"add_rmsnorm_fp4quant",
"AddRMSNormFP4QuantKernel",
# Norm kernels (CuTe DSL) - backwards-compatible re-exports
"RMSNormKernel",
"QKRMSNormKernel",
"RMSNormQuantKernel",
"FusedAddRMSNormKernel",
"FusedAddRMSNormQuantKernel",
"LayerNormKernel",
"rmsnorm_cute",
"qk_rmsnorm_cute",
"rmsnorm_quant_cute",
"fused_add_rmsnorm_cute",
"fused_add_rmsnorm_quant_cute",
"layernorm_cute",
]
4 changes: 2 additions & 2 deletions flashinfer/cute_dsl/add_rmsnorm_fp4quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,8 +1012,8 @@ def tensor_api(
s_tensor,
s_unswizzled.contiguous(),
global_scale,
Int32(M),
Float32(eps),
M,
eps,
)

return tensor_api
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/cute_dsl/rmsnorm_fp4quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,8 @@ def tensor_api(
y_uint8,
s_tensor,
global_scale,
Int32(M),
Float32(eps),
M,
eps,
)

return tensor_api
Expand Down
Loading
Loading