Add layernorm op for inputs of mixed dtype #1926
@@ -14,3 +14,4 @@ Kernels for normalization layers.
 fused_add_rmsnorm
 gemma_rmsnorm
 gemma_fused_add_rmsnorm
+layernorm
@@ -244,3 +244,43 @@ def _gemma_fused_add_rmsnorm_fake(
     enable_pdl: Optional[bool] = None,
 ) -> None:
     pass
+
+
+@register_custom_op("flashinfer::layernorm", mutates_args=())
+def layernorm(
+    input: torch.Tensor,
+    gemma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    r"""Layer normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size). Needs to be bfloat16.
+    gemma: torch.Tensor
+        Gemma tensor, shape (hidden_size,). Needs to be float32.
+    beta: torch.Tensor
+        Beta tensor, shape (hidden_size,). Needs to be float32.
+    eps: float
+        Epsilon for numerical stability.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Layer-normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
+    """
+    out = torch.empty_like(input)
+    get_norm_module().layernorm(out, input, gemma, beta, eps)
+    return out
Comment on lines +250 to +275

Contributor:

The parameter `gemma` should be named `gamma`, the conventional name for the LayerNorm scale parameter. Suggested change:

def layernorm(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Needs to be bfloat16.
    gamma: torch.Tensor
        Gamma tensor, shape (hidden_size,). Needs to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Needs to be float32.
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    output: torch.Tensor
        Layer-normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gamma, beta, eps)
    return out
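For context, here is a minimal usage sketch of the new op under the mixed-dtype contract described in the docstring (bfloat16 activations with float32 scale and shift). The import path `flashinfer.norm` and the CUDA device placement are assumptions for illustration, not taken from this PR; arguments are passed positionally so the `gemma`/`gamma` naming question does not affect the call.

import torch
# Assumed import path for illustration; the public re-export may differ.
from flashinfer.norm import layernorm

batch_size, hidden_size = 4, 4096
# Mixed dtypes as documented: bfloat16 input, float32 scale (gamma) and shift (beta).
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(hidden_size, dtype=torch.float32, device="cuda")
beta = torch.zeros(hidden_size, dtype=torch.float32, device="cuda")

out = layernorm(x, gamma, beta, eps=1e-6)
print(out.shape, out.dtype)  # (4, 4096), torch.bfloat16 -- same shape and dtype as the input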
+@register_fake_op("flashinfer::layernorm")
+def _layernorm_fake(
+    input: torch.Tensor,
+    gemma: torch.Tensor,
+    beta: torch.Tensor,
+    eps: float = 1e-6,
+) -> torch.Tensor:
Comment on lines +279 to +284

Contributor:

For consistency with the proposed change in `layernorm` above, rename the `gemma` parameter here to `gamma` as well.
+    b, k = input.shape
+    return input.new_empty([b, k])
For better readability and maintainability, please keep the imports from the same module sorted alphabetically.
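As a quick sanity check on the mixed-dtype contract above, the new op could be compared against a float32 PyTorch reference. This is a minimal sketch, not the PR's actual test; the import path, tensor sizes, and tolerances are illustrative assumptions.

import torch
import torch.nn.functional as F
from flashinfer.norm import layernorm  # assumed import path, as above

def reference_layernorm(x, gamma, beta, eps=1e-6):
    # Compute in float32 for accuracy, then cast back, matching the documented
    # contract that the output has the same dtype as the input.
    y = F.layer_norm(x.float(), (x.shape[-1],), weight=gamma, bias=beta, eps=eps)
    return y.to(x.dtype)

x = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
gamma = torch.randn(4096, dtype=torch.float32, device="cuda")
beta = torch.randn(4096, dtype=torch.float32, device="cuda")

torch.testing.assert_close(
    layernorm(x, gamma, beta),
    reference_layernorm(x, gamma, beta),
    rtol=1e-2, atol=1e-2,  # loose tolerances for bfloat16; illustrative only
)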