
Commit 9f25eee

akhilg-nv and kaixih authored
Add layernorm op for inputs of mixed dtype (#1926)
## 📌 Description

Adds a `layernorm` op for mixed-dtype inputs: bfloat16 activations normalized with float32 `gamma` and `beta`.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Continued from #1914 with small review fixes included, as the original author will be out for the next week.

---------

Co-authored-by: kaixih <[email protected]>
1 parent 8d708b2 commit 9f25eee

File tree

10 files changed: 852 additions & 0 deletions


csrc/flashinfer_norm_binding.cu

Lines changed: 3 additions & 0 deletions
```diff
@@ -26,7 +26,10 @@ void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double e
 void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps,
                              bool enable_pdl);
 
+void layernorm(Tensor out, Tensor input, Tensor gamma, Tensor beta, double eps);
+
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm, rmsnorm);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm, fused_add_rmsnorm);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_rmsnorm, gemma_rmsnorm);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_fused_add_rmsnorm, gemma_fused_add_rmsnorm);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm, layernorm);
```

csrc/norm.cu

Lines changed: 33 additions & 0 deletions
```diff
@@ -160,3 +160,36 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView w
     return true;
   });
 }
+
+void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double eps) {
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(gamma);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(beta);
+  CHECK_DEVICE(input, gamma);
+  CHECK_DEVICE(input, beta);
+  CHECK_DIM(2, input);  // input: (batch_size, hidden_size)
+  CHECK_DIM(1, gamma);  // gamma: (hidden_size)
+  CHECK_DIM(1, beta);   // beta: (hidden_size)
+  TVM_FFI_ICHECK_EQ(input->shape[1], gamma->shape[0]);
+  TVM_FFI_ICHECK_EQ(input->shape[1], beta->shape[0]);
+  unsigned int batch_size = input->shape[0];
+  unsigned int hidden_size = input->shape[1];
+  TVM_FFI_ICHECK_EQ(output->shape[0], batch_size);
+  TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size);
+  cudaSetDevice(input->device.device_id);
+  const cudaStream_t stream = get_stream(input->device);
+  // TODO(kaixih): This is currently our only use case; Add more if needed.
+  TVM_FFI_ICHECK_EQ(input->dtype, dl_bfloat16) << "input must be bfloat16";
+  TVM_FFI_ICHECK_EQ(gamma->dtype, dl_float32) << "gamma must be float32";
+  TVM_FFI_ICHECK_EQ(beta->dtype, dl_float32) << "beta must be float32";
+
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
+    cudaError_t status =
+        norm::LayerNorm(static_cast<c_type*>(input->data), static_cast<float*>(gamma->data),
+                        static_cast<float*>(beta->data), static_cast<c_type*>(output->data),
+                        batch_size, hidden_size, eps, stream);
+    TVM_FFI_ICHECK(status == cudaSuccess)
+        << "LayerNorm failed with error code " << cudaGetErrorString(status);
+    return true;
+  });
+}
```
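
The kernel itself (`norm::LayerNorm`) lives in FlashInfer's norm headers and is not part of this diff. As an editor-added illustration only, the sketch below spells out the per-row arithmetic the dtype checks above imply, assuming standard LayerNorm semantics with float32 accumulation and a bfloat16 result:

```python
# Editor's illustration (not from the commit): expected per-row arithmetic for the
# mixed-dtype LayerNorm -- bfloat16 activations, float32 gamma/beta, bfloat16 output.
import torch

def layernorm_row(x_bf16: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float) -> torch.Tensor:
    x = x_bf16.float()                       # upcast bf16 activations to float32
    mu = x.mean()                            # per-row mean in float32
    var = ((x - mu) ** 2).mean()             # per-row (biased) variance in float32
    y = (x - mu) * torch.rsqrt(var + eps)    # normalize
    return (y * gamma + beta).to(torch.bfloat16)  # float32 scale/shift, cast back to bf16
```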

docs/api/norm.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,3 +14,4 @@ Kernels for normalization layers.
     fused_add_rmsnorm
     gemma_rmsnorm
     gemma_fused_add_rmsnorm
+    layernorm
```

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -90,6 +90,7 @@
 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper
 from .norm import fused_add_rmsnorm as fused_add_rmsnorm
+from .norm import layernorm as layernorm
 from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
```

flashinfer/jit/norm.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -19,10 +19,15 @@
 
 
 def gen_norm_module() -> JitSpec:
+    nvcc_flags = [
+        "-DENABLE_BF16",
+        "-DENABLE_FP8",
+    ]
     return gen_jit_spec(
         "norm",
         [
             jit_env.FLASHINFER_CSRC_DIR / "norm.cu",
             jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_binding.cu",
         ],
+        extra_cuda_cflags=nvcc_flags,
     )
```
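
The new `-DENABLE_BF16`/`-DENABLE_FP8` flags only take effect when the `norm` extension is JIT-compiled; `gen_norm_module()` just describes the build. A small editor-added sketch, with the lazy-compile behaviour inferred from the `get_norm_module()` call in `flashinfer/norm.py` below:

```python
# Editor's sketch (not from the commit): gen_norm_module() returns the JitSpec that
# covers norm.cu and flashinfer_norm_binding.cu with the bf16/fp8 nvcc flags; the
# actual compilation is deferred until the module is first requested, e.g. via
# get_norm_module() inside the layernorm wrapper shown in the next file.
from flashinfer.jit.norm import gen_norm_module

spec = gen_norm_module()  # describes the build; nothing is compiled at this point
```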

flashinfer/norm.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -244,3 +244,43 @@ def _gemma_fused_add_rmsnorm_fake(
     enable_pdl: Optional[bool] = None,
 ) -> None:
     pass
+
+
+@register_custom_op("flashinfer::layernorm", mutates_args=())
+def layernorm(
+    input: torch.Tensor,
+    gemma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    r"""Layer normalization.
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
+    gemma: torch.Tensor
+        Gemma tensor, shape (hidden_size,). Need to be float32.
+    beta: torch.Tensor
+        Beta tensor, shape (hidden_size,). Need to be float32.
+    eps: float
+        Epsilon for numerical stability.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
+    """
+    out = torch.empty_like(input)
+    get_norm_module().layernorm(out, input, gemma, beta, eps)
+    return out
+
+
+@register_fake_op("flashinfer::layernorm")
+def _layernorm_fake(
+    input: torch.Tensor,
+    gemma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    b, k = input.shape
+    return input.new_empty([b, k])
```
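
An end-to-end usage sketch of the new public entry point (editor-added; shapes and tolerances are illustrative, and the cross-check assumes the CUDA kernel matches standard LayerNorm semantics):

```python
# Editor's usage sketch (not from the commit). flashinfer.layernorm is the wrapper
# added above: bfloat16 input, float32 scale/shift, bfloat16 output.
import torch
import flashinfer

batch_size, hidden_size, eps = 8, 4096, 1e-6
x = torch.randn(batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
gamma = torch.randn(hidden_size, device="cuda", dtype=torch.float32)
beta = torch.randn(hidden_size, device="cuda", dtype=torch.float32)

out = flashinfer.layernorm(x, gamma, beta, eps)  # same shape and dtype as x

# Cross-check against a float32 PyTorch reference, cast back to bfloat16.
ref = torch.nn.functional.layer_norm(
    x.float(), (hidden_size,), weight=gamma, bias=beta, eps=eps
).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```

Passing the arguments positionally sidesteps the wrapper's `gemma` keyword name for the scale tensor.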
