diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 62cabd0a6e..9ea91d5508 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -107,10 +107,12 @@ from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper from .norm import fused_add_rmsnorm as fused_add_rmsnorm +from .norm import fused_add_rmsnorm_quant as fused_add_rmsnorm_quant from .norm import layernorm as layernorm from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm from .norm import gemma_rmsnorm as gemma_rmsnorm from .norm import rmsnorm as rmsnorm +from .norm import rmsnorm_quant as rmsnorm_quant try: from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant