
Conversation

@akhilg-nv (Contributor) commented Oct 14, 2025

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Continued from #1914 with small review fixes included, as the original author will be out for the next week.

@gemini-code-assist (Contributor)

Summary of Changes

Hello @akhilg-nv, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a highly optimized Layer Normalization operation to the FlashInfer library, specifically designed to handle mixed precision inputs (bfloat16 for input, float32 for parameters). The core of this feature is a new CUDA kernel that leverages advanced GPU programming techniques and integrates utility functions from TensorRT-LLM for efficient data handling and reductions. This enhancement aims to improve performance and flexibility for models utilizing Layer Normalization with mixed-precision training or inference.

Highlights

  • New LayerNorm Operation: Introduced a new layernorm operation with a dedicated CUDA kernel for efficient computation.
  • Mixed Dtype Support: The layernorm implementation supports mixed data types, specifically bfloat16 for the input tensor and float32 for the gamma and beta parameters.
  • CUDA Kernel Implementation: A generalLayerNorm CUDA kernel has been added, leveraging shared memory and various reduction utilities for optimized performance.
  • Python API Integration: The new layernorm operation is exposed through the Python API using register_custom_op and includes a fake op for testing.
  • New CUDA Utilities: Integrated common CUDA utility headers from tensorrt_llm for type packing, element counting, and reduction operations, enabling more flexible and performant kernel development.
  • Comprehensive Testing: Added a new test case to validate the layernorm implementation against torch.nn.functional.layer_norm across various batch and hidden sizes with bfloat16 inputs.
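
Editor's note: below is a minimal usage sketch of the op described above (not taken from the PR). It assumes the op is exposed as flashinfer.layernorm per the Python API bullet; the tolerances are illustrative. Arguments are passed positionally because, in this PR, the scale parameter is still named gemma (see the review below).

import torch
import flashinfer

batch_size, hidden_size = 16, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
gamma = torch.randn(hidden_size, dtype=torch.float32, device="cuda")
beta = torch.randn(hidden_size, dtype=torch.float32, device="cuda")

# New op: bfloat16 input, float32 gamma/beta, bfloat16 output.
out = flashinfer.layernorm(x, gamma, beta, eps=1e-6)

# Reference computed in float32, mirroring the validation against torch.nn.functional.layer_norm.
ref = torch.nn.functional.layer_norm(
    x.float(), (hidden_size,), weight=gamma, bias=beta, eps=1e-6
).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
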
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a layernorm operation designed for mixed-precision inputs, specifically bfloat16 for the input tensor and float32 for the weight and bias. The implementation appears to be adapted from TensorRT-LLM, providing a generic kernel that includes currently disabled quantization features. The changes are well-structured across CUDA kernels, C++ bindings, and the Python API, and include appropriate tests. My primary feedback focuses on resolving a naming inconsistency where the scaling parameter is referred to as gemma instead of the standard gamma, which could cause confusion. I have also included a minor suggestion to improve code maintainability.

Comment on lines 92 to 96
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

medium

For better readability and maintainability, please keep the imports from the same module sorted alphabetically.

Suggested change

Before:
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import layernorm as layernorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm

After:
from .norm import fused_add_rmsnorm as fused_add_rmsnorm
from .norm import gemma_fused_add_rmsnorm as gemma_fused_add_rmsnorm
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import layernorm as layernorm
from .norm import rmsnorm as rmsnorm

Comment on lines +250 to +275
def layernorm(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gemma: torch.Tensor
        Gemma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.
    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gemma, beta, eps)
    return out

medium

The parameter gemma is confusingly named. In the context of layer normalization, this parameter is standardly referred to as gamma. Using gemma could be misleading, especially since "Gemma" is also the name of a popular model family. The C++ binding in csrc/norm.cu already uses gamma. For consistency and clarity, please rename gemma to gamma in the function signature, docstring, and the call to the backend module.

def layernorm(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    r"""Layer normalization.
    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Need to be bfloat16.
    gamma: torch.Tensor
        Gamma tensor, shape (hidden_size,). Need to be float32.
    beta: torch.Tensor
        Beta tensor, shape (hidden_size,). Need to be float32.
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    output: torch.Tensor
        Layer Normalized tensor, shape (batch_size, hidden_size). Same dtype as input.
    """
    out = torch.empty_like(input)
    get_norm_module().layernorm(out, input, gamma, beta, eps)
    return out

Comment on lines +279 to +284
def _layernorm_fake(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:

medium

For consistency with the proposed change in layernorm, please also rename gemma to gamma in this fake operator implementation.

Suggested change

Before:
def _layernorm_fake(
    input: torch.Tensor,
    gemma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:

After:
def _layernorm_fake(
    input: torch.Tensor,
    gamma: torch.Tensor,
    beta: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:

Comment on lines +483 to +489
__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, T const* gemma,
                                           T const* beta, int i) {
  Tf ret = (val - s_mean) * s_variance * cuda_cast<Tf>(gemma[i]);
  if (beta != nullptr) {
    ret = ret + cuda_cast<Tf>(beta[i]);
  }
  return ret;

medium

The parameter gemma should be renamed to gamma for consistency and clarity. The standard term for this parameter in layer normalization is gamma.

__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, T const* gamma,
                                           T const* beta, int i) {
  Tf ret = (val - s_mean) * s_variance * cuda_cast<Tf>(gamma[i]);
  if (beta != nullptr) {
    ret = ret + cuda_cast<Tf>(beta[i]);
  }
  return ret;
}
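
Editor's note: as a sanity reference for the device function above (not code from this PR), the equivalent per-row math in PyTorch is sketched below; note that s_variance in the kernel is the reciprocal standard deviation.

import torch

def layernorm_reference(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate in float32, mirroring the kernel's float mean/variance accumulators.
    xf = x.float()
    mean = xf.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + eps)
    # Matches compute_layernorm: (val - s_mean) * s_variance * gamma[i] + beta[i].
    return ((xf - mean) * rstd * gamma + beta).to(x.dtype)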

Comment on lines +492 to +498
template <typename T, typename Tw, typename QuantT, bool USE_SHMEM,
          bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(T const* input, Tw const* gemma, Tw const* beta, T* normed_output,
                                 float const eps, int tokens, int hidden_dim,
                                 float const* clamp_ptr, float const* scale_orig_quant_per_tensor,
                                 float* scale_orig_quant_per_token, float* sum_per_token,
                                 QuantT* normed_output_quant, bool has_fp8_min_scaling) {

medium

The parameter gemma should be renamed to gamma for consistency and clarity. Please update it in the function signature and its usage within the function body.

template <typename T, typename Tw, typename QuantT, bool USE_SHMEM,
          bool USE_DIFF_OF_SQUARES = false>
__global__ void generalLayerNorm(T const* input, Tw const* gamma, Tw const* beta, T* normed_output,
                                 float const eps, int tokens, int hidden_dim,
                                 float const* clamp_ptr, float const* scale_orig_quant_per_tensor,
                                 float* scale_orig_quant_per_token, float* sum_per_token,
                                 QuantT* normed_output_quant, bool has_fp8_min_scaling) {
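
Editor's note on the USE_DIFF_OF_SQUARES template flag: it presumably selects between the usual two-pass variance and the single-pass Var(x) = E[x^2] - E[x]^2 formulation. A minimal sketch of the two approaches, in Python for clarity (illustrative only, not code from this PR):

import torch

def variance_two_pass(x: torch.Tensor) -> torch.Tensor:
    # First pass computes the mean, second pass the mean of squared deviations.
    mean = x.mean(dim=-1, keepdim=True)
    return ((x - mean) ** 2).mean(dim=-1, keepdim=True)

def variance_diff_of_squares(x: torch.Tensor) -> torch.Tensor:
    # Single pass: Var(x) = E[x^2] - E[x]^2; fewer reductions but less numerically robust.
    mean = x.mean(dim=-1, keepdim=True)
    mean_sq = (x * x).mean(dim=-1, keepdim=True)
    return mean_sq - mean * mean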

Comment on lines +691 to +692
cudaError_t LayerNorm(T* input, Tw* gemma, Tw* beta, T* out, uint32_t tokens, uint32_t hidden_dim,
                      float eps = 1e-5, cudaStream_t stream = 0) {

medium

The parameter gemma should be renamed to gamma for consistency and clarity. Please update it in the function signature and in the calls within the function body.

cudaError_t LayerNorm(T* input, Tw* gamma, Tw* beta, T* out, uint32_t tokens, uint32_t hidden_dim,
                      float eps = 1e-5, cudaStream_t stream = 0) {

@yzh119 (Collaborator) commented Oct 14, 2025

Hi @akhilg-nv how is this PR different to #1914?

@akhilg-nv (Contributor, Author) commented

> Hi @akhilg-nv how is this PR different to #1914?

The previous PR author is out for the next week and we want to get these changes merged in soon, so I just took over. Shall I go ahead with the gemma -> gamma suggested change?
The other comment on #1914 that I haven't addressed is the refactoring of reduceKernelUtils.cuh and norm.cuh, which we can address in a separate PR.

@yzh119 (Collaborator) left a comment

Hi @akhilg-nv, thanks for the clarification. LGTM overall.

@yzh119 merged commit 9f25eee into flashinfer-ai:main on Oct 15, 2025 (2 checks passed).
@yzh119 (Collaborator) commented Oct 15, 2025

Sorry, I just noticed Gemini's review. I think the suggested change is fair (gemma -> gamma; Gemma is the name of a model, while gamma is the standard variable name), but we can leave it for another PR anyway.
