Add layernorm op for inputs of mixed dtype #1926
Conversation
Summary of Changes

Hello @akhilg-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a highly optimized Layer Normalization operation to the FlashInfer library, specifically designed to handle mixed-precision inputs (bfloat16 for the input, float32 for the parameters). The core of this feature is a new CUDA kernel that leverages advanced GPU programming techniques and integrates utility functions from TensorRT-LLM for efficient data handling and reductions. This enhancement aims to improve performance and flexibility for models utilizing Layer Normalization with mixed-precision training or inference.
Code Review
This pull request introduces a layernorm operation designed for mixed-precision inputs, specifically bfloat16 for the input tensor and float32 for the weight and bias. The implementation appears to be adapted from TensorRT-LLM, providing a generic kernel that includes currently disabled quantization features. The changes are well-structured across CUDA kernels, C++ bindings, and the Python API, and include appropriate tests. My primary feedback focuses on resolving a naming inconsistency where the scaling parameter is referred to as gemma instead of the standard gamma, which could cause confusion. I have also included a minor suggestion to improve code maintainability.
```python
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
```
For better readability and maintainability, please keep the imports from the same module sorted alphabetically.
Suggested change:

```python
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import layernorm as layernorm
from .norm import rmsnorm as rmsnorm
```
```python
def layernorm(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gemma: torch.Tensor
        Gemma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.
    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gemma, beta, eps)
    return out
```
The parameter gemma is confusingly named. In the context of layer normalization, this parameter is standardly referred to as gamma. Using gemma could be misleading, especially since "Gemma" is also the name of a popular model family. The C++ binding in csrc/norm.cu already uses gamma. For consistency and clarity, please rename gemma to gamma in the function signature, docstring, and the call to the backend module.
```python
def layernorm(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gamma: torch.Tensor
        Gamma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.
    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gamma, beta, eps)
    return out
```

```python
def _layernorm_fake(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
```
For consistency with the proposed change in layernorm, please also rename gemma to gamma in this fake operator implementation.
Suggested change:

```python
def _layernorm_fake(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
```
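For context, here is a minimal usage sketch that checks the op against a float32 PyTorch reference, assuming the gamma rename above is applied. The `flashinfer.layernorm` entry point is taken from the import added in this PR; the shapes and tolerances are illustrative assumptions, not values from the PR's tests.

```python
import torch
import flashinfer

batch_size, hidden_size = 8, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
gamma = torch.randn(hidden_size, dtype=torch.float32, device="cuda")
beta = torch.randn(hidden_size, dtype=torch.float32, device="cuda")

# Mixed-dtype layernorm: bf16 input, fp32 gamma/beta, bf16 output.
out = flashinfer.layernorm(x, gamma, beta, eps=1e-6)

# float32 reference via torch.nn.functional.layer_norm, cast back to bf16.
ref = torch.nn.functional.layer_norm(
    x.float(), (hidden_size,), weight=gamma, bias=beta, eps=1e-6
).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)  # tolerances are assumptions
```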
```cuda
__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, T const* gemma,
                                           T const* beta, int i) {
  Tf ret = (val - s_mean) * s_variance * cuda_cast<Tf>(gemma[i]);
  if (beta != nullptr) {
    ret = ret + cuda_cast<Tf>(beta[i]);
  }
  return ret;
}
```
The parameter gemma should be renamed to gamma for consistency and clarity. The standard term for this parameter in layer normalization is gamma.
```cuda
__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, T const* gamma,
                                           T const* beta, int i) {
  Tf ret = (val - s_mean) * s_variance * cuda_cast<Tf>(gamma[i]);
  if (beta != nullptr) {
    ret = ret + cuda_cast<Tf>(beta[i]);
  }
  return ret;
}
```
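For readers less familiar with the kernel internals: this helper applies the standard per-element LayerNorm transform. Judging from the multiplication in the code (an inference, not something stated in the diff), `s_variance` holds the reciprocal standard deviation rather than the raw variance:

$$
y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}\,\gamma_i + \beta_i,
\qquad \texttt{s\_mean} = \mu,
\qquad \texttt{s\_variance} \approx \frac{1}{\sqrt{\sigma^2 + \varepsilon}}.
$$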
```cuda
template <typename T, typename Tw, typename QuantT, bool USE_SHMEM,
          bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(T const* input, Tw const* gemma, Tw const* beta, T* normed_output,
                                 float const eps, int tokens, int hidden_dim,
                                 float const* clamp_ptr, float const* scale_orig_quant_per_tensor,
                                 float* scale_orig_quant_per_token, float* sum_per_token,
                                 QuantT* normed_output_quant, bool has_fp8_min_scaling) {
```
The parameter gemma should be renamed to gamma for consistency and clarity. Please update it in the function signature and its usage within the function body.
```cuda
template <typename T, typename Tw, typename QuantT, bool USE_SHMEM,
          bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(T const* input, Tw const* gamma, Tw const* beta, T* normed_output,
                                 float const eps, int tokens, int hidden_dim,
                                 float const* clamp_ptr, float const* scale_orig_quant_per_tensor,
                                 float* scale_orig_quant_per_token, float* sum_per_token,
                                 QuantT* normed_output_quant, bool has_fp8_min_scaling) {
```
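One note on the template flags, inferred from the flag name and the kernel's TensorRT-LLM origin rather than stated in this diff: `USE_DIFF_OF_SQUARES` presumably selects between the two standard ways of forming the per-token variance,

$$
\sigma^2 = \mathbb{E}\!\left[(x - \mu)^2\right]
\qquad \text{vs.} \qquad
\sigma^2 = \mathbb{E}\!\left[x^2\right] - \mu^2,
$$

where the second form needs only a single reduction pass over the row but is less numerically robust when the mean is large relative to the spread.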
```cuda
cudaError_t LayerNorm(T* input, Tw* gemma, Tw* beta, T* out, uint32_t tokens, uint32_t hidden_dim,
                      float eps = 1e-5, cudaStream_t stream = 0) {
```
The parameter gemma should be renamed to gamma for consistency and clarity. Please update it in the function signature and in the calls within the function body.
```cuda
cudaError_t LayerNorm(T* input, Tw* gamma, Tw* beta, T* out, uint32_t tokens, uint32_t hidden_dim,
                      float eps = 1e-5, cudaStream_t stream = 0) {
```
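To make the intended mixed-dtype usage concrete, here is a hedged call sketch into this templated launcher for bf16 activations with fp32 gamma/beta. Only the `LayerNorm` signature comes from the diff; the header that declares it, the allocation pattern, and the error handling are illustrative assumptions.

```cuda
#include <cuda_bf16.h>
#include <cuda_runtime.h>
// Assumption: include whichever FlashInfer header declares LayerNorm<T, Tw>
// (the exact path is not shown in this PR excerpt).

int main() {
  const uint32_t tokens = 8, hidden_dim = 4096;
  __nv_bfloat16 *d_in = nullptr, *d_out = nullptr;
  float *d_gamma = nullptr, *d_beta = nullptr;
  cudaMalloc(&d_in, tokens * hidden_dim * sizeof(__nv_bfloat16));
  cudaMalloc(&d_out, tokens * hidden_dim * sizeof(__nv_bfloat16));
  cudaMalloc(&d_gamma, hidden_dim * sizeof(float));
  cudaMalloc(&d_beta, hidden_dim * sizeof(float));
  // ... fill d_in / d_gamma / d_beta with real data ...

  // bf16 input/output, fp32 gamma/beta, default stream.
  cudaError_t status = LayerNorm<__nv_bfloat16, float>(
      d_in, d_gamma, d_beta, d_out, tokens, hidden_dim, /*eps=*/1e-6f, /*stream=*/0);
  if (status != cudaSuccess) {
    // handle launch / kernel error
  }

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_gamma);
  cudaFree(d_beta);
  return 0;
}
```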
Hi @akhilg-nv, how is this PR different from #1914?
The previous PR author is out for the next week and we want to get these changes merged in soon, so I just took over. Shall I go ahead with the …
yzh119 left a comment
Hi @akhilg-nv, thanks for the clarification, LGTM overall.
Sorry, I just noticed Gemini's review and I think the suggested change is fair (…)
Description
Related Issues
Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
Pre-commit Checks

- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

Tests

- All tests are passing (unittest, etc.).

Reviewer Notes
Continued from #1914 with small review fixes included, as the original author will be out for the next week.