3 changes: 3 additions & 0 deletions csrc/flashinfer_norm_binding.cu
@@ -26,7 +26,10 @@ void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps,
void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps,
bool enable_pdl);

void layernorm(Tensor out, Tensor input, Tensor gamma, Tensor beta, double eps);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm, rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm, fused_add_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_rmsnorm, gemma_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_fused_add_rmsnorm, gemma_fused_add_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm, layernorm);
33 changes: 33 additions & 0 deletions csrc/norm.cu
@@ -160,3 +160,36 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight,
return true;
});
}

void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double eps) {
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(gamma);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(beta);
  CHECK_DEVICE(input, gamma);
  CHECK_DEVICE(input, beta);
  CHECK_DIM(2, input);  // input: (batch_size, hidden_size)
  CHECK_DIM(1, gamma);  // gamma: (hidden_size)
  CHECK_DIM(1, beta);   // beta: (hidden_size)
  TVM_FFI_ICHECK_EQ(input->shape[1], gamma->shape[0]);
  TVM_FFI_ICHECK_EQ(input->shape[1], beta->shape[0]);
  unsigned int batch_size = input->shape[0];
  unsigned int hidden_size = input->shape[1];
  TVM_FFI_ICHECK_EQ(output->shape[0], batch_size);
  TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size);
  cudaSetDevice(input->device.device_id);
  const cudaStream_t stream = get_stream(input->device);
  // TODO(kaixih): This is currently our only use case; Add more if needed.
  TVM_FFI_ICHECK_EQ(input->dtype, dl_bfloat16) << "input must be bfloat16";
  TVM_FFI_ICHECK_EQ(gamma->dtype, dl_float32) << "gamma must be float32";
  TVM_FFI_ICHECK_EQ(beta->dtype, dl_float32) << "beta must be float32";

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
    cudaError_t status =
        norm::LayerNorm(static_cast<c_type*>(input->data), static_cast<float*>(gamma->data),
                        static_cast<float*>(beta->data), static_cast<c_type*>(output->data),
                        batch_size, hidden_size, eps, stream);
    TVM_FFI_ICHECK(status == cudaSuccess)
        << "LayerNorm failed with error code " << cudaGetErrorString(status);
    return true;
  });
}
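
For reference, a minimal PyTorch sketch of the computation this binding dispatches to, inferred from the dtype checks above (bfloat16 input, float32 gamma/beta). The sketch assumes normalization is performed in float32 and the result is cast back to the input dtype; the exact accumulation order inside norm::LayerNorm may differ, and the helper name below is only illustrative.

import torch

def layernorm_reference(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor,
                        eps: float = 1e-6) -> torch.Tensor:
    # x: (batch_size, hidden_size) bfloat16; gamma/beta: (hidden_size,) float32.
    x_f32 = x.float()
    mean = x_f32.mean(dim=-1, keepdim=True)
    var = x_f32.var(dim=-1, unbiased=False, keepdim=True)
    # Normalize in float32, apply the affine parameters, then cast back to the input dtype.
    y = (x_f32 - mean) * torch.rsqrt(var + eps) * gamma + beta
    return y.to(x.dtype)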
1 change: 1 addition & 0 deletions docs/api/norm.rst
@@ -14,3 +14,4 @@ Kernels for normalization layers.
fused_add_rmsnorm
gemma_rmsnorm
gemma_fused_add_rmsnorm
layernorm
1 change: 1 addition & 0 deletions flashinfer/__init__.py
@@ -90,6 +90,7 @@
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
Comment on lines 92 to 96 (severity: medium)
For better readability and maintainability, please keep the imports from the same module sorted alphabetically.

Suggested change
-from .norm import fused_add_rmsnorm as fused_add_rmsnorm
-from .norm import layernorm as layernorm
-from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
-from .norm import gemma_rmsnorm as gemma_rmsnorm
-from .norm import rmsnorm as rmsnorm
+from .norm import fused_add_rmsnorm as fused_add_rmsnorm
+from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
+from .norm import gemma_rmsnorm as gemma_rmsnorm
+from .norm import layernorm as layernorm
+from .norm import rmsnorm as rmsnorm

5 changes: 5 additions & 0 deletions flashinfer/jit/norm.py
@@ -19,10 +19,15 @@


def gen_norm_module() -> JitSpec:
    nvcc_flags = [
        "-DENABLE_BF16",
        "-DENABLE_FP8",
    ]
    return gen_jit_spec(
        "norm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "norm.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_binding.cu",
        ],
        extra_cuda_cflags=nvcc_flags,
    )
40 changes: 40 additions & 0 deletions flashinfer/norm.py
@@ -244,3 +244,43 @@ def _gemma_fused_add_rmsnorm_fake(
    enable_pdl: Optional[bool] = None,
) -> None:
    pass


@register_custom_op("flashinfer::layernorm", mutates_args=())
def layernorm(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gemma: torch.Tensor
        Gemma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gemma, beta, eps)
    return out
Comment on lines +250 to +275 (severity: medium)
The parameter gemma is confusingly named. In the context of layer normalization, this parameter is standardly referred to as gamma. Using gemma could be misleading, especially since "Gemma" is also the name of a popular model family. The C++ binding in csrc/norm.cu already uses gamma. For consistency and clarity, please rename gemma to gamma in the function signature, docstring, and the call to the backend module.

def layernorm(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gamma: torch.Tensor
        Gamma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gamma, beta, eps)
    return out



@register_fake_op("flashinfer::layernorm")
def _layernorm_fake(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
Comment on lines +279 to +284 (severity: medium)
For consistency with the proposed change in layernorm, please also rename gemma to gamma in this fake operator implementation.

Suggested change
-def _layernorm_fake(
-    input: torch.Tensor,
-    gemma: torch.Tensor,
-    beta: torch.Tensor,
-    eps: float = 1e-6,
-) -> torch.Tensor:
+def _layernorm_fake(
+    input: torch.Tensor,
+    gamma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:

    b, k = input.shape
    return input.new_empty([b, k])
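
A short usage sketch for the new Python entry point, assuming a CUDA device and a build with bfloat16 enabled. Arguments are passed positionally to sidestep the gemma/gamma naming discussed above, and the float32 comparison against torch.nn.functional.layer_norm (with loose tolerances) is illustrative only, not part of this PR.

import torch
import flashinfer

batch_size, hidden_size = 16, 4096
x = torch.randn(batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
gamma = torch.randn(hidden_size, device="cuda", dtype=torch.float32)
beta = torch.randn(hidden_size, device="cuda", dtype=torch.float32)

# Returns a new tensor with the same shape and dtype as the input.
out = flashinfer.layernorm(x, gamma, beta, 1e-6)

# Illustrative float32 reference; the tolerances are an assumption, not from the PR.
ref = torch.nn.functional.layer_norm(
    x.float(), (hidden_size,), weight=gamma, bias=beta, eps=1e-6
).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)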