7 changes: 5 additions & 2 deletions python/sglang/srt/layers/communicator.py
@@ -397,8 +397,11 @@ def _gather_hidden_states_and_residual(
        if hidden_states.shape[0] != 0:
            hidden_states = layernorm(hidden_states)
        else:
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-            hidden_states, residual = layernorm(hidden_states, residual)
+            if hasattr(layernorm, 'forward_with_allreduce_fusion'):
+                hidden_states, residual = layernorm.forward_with_allreduce_fusion(hidden_states, residual)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = layernorm(hidden_states, residual)
Reviewer comment (medium):
Consider using getattr with a default value to avoid checking for the attribute forward_with_allreduce_fusion before accessing it. This can simplify the code and make it more readable.

            hidden_states, residual = getattr(layernorm, 'forward_with_allreduce_fusion', layernorm)(hidden_states, residual)
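
For reference, a minimal sketch of the getattr variant that still performs the all-reduce in the fallback path (illustrative only; fused_fn is a local name introduced here, not part of the PR):

    # Resolve the fused method once, but keep the explicit all-reduce +
    # layernorm fallback so multi-GPU results stay correct.
    fused_fn = getattr(layernorm, 'forward_with_allreduce_fusion', None)
    if fused_fn is not None:
        # Fused allreduce + residual add + RMS norm path
        hidden_states, residual = fused_fn(hidden_states, residual)
    else:
        # Plain path: all-reduce first, then residual add + RMS norm
        hidden_states = tensor_model_parallel_all_reduce(hidden_states)
        hidden_states, residual = layernorm(hidden_states, residual)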

        return hidden_states, residual

    @staticmethod
215 changes: 215 additions & 0 deletions python/sglang/srt/layers/flashinfer_fusion.py
@@ -0,0 +1,215 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FlashInfer fusion operations for allreduce + RMS norm."""

import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
    try:
        import flashinfer.comm as comm
        _flashinfer_comm = comm
    except ImportError:
        logger.warning(
            "flashinfer.comm is not available, falling back to standard "
            "implementation"
        )


class FlashInferWorkspaceManager:
    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(self, world_size: int, rank: int, max_token_num: int,
                   hidden_dim: int, group=None,
                   use_fp32_lamport: bool = False):
        """Initialize workspace"""
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning(
                "FlashInfer comm not available, skipping workspace "
                "initialization"
            )
            return

        # Clean up previous workspace
        self.cleanup()

        # Create new workspace
        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                max_token_num,
                hidden_dim,
                group=group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(
            f"FlashInfer workspace initialized for rank {rank}, "
            f"world_size {world_size}"
        )

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles,
                    group=dist.group.WORLD
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False
Reviewer comment (medium):

Consider adding a check to ensure self.initialized is False before attempting to clean up the workspace. This could prevent potential issues if cleanup is called multiple times without a corresponding initialize call.

        if self.initialized and self.ipc_handles is not None:
            if not self.initialized:
                return
            try:
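
For illustration, a sketch of a cleanup() that makes the early return explicit so repeated calls are safe (same attributes as the class above; not a change proposed in this PR):

    def cleanup(self):
        """Release the IPC workspace; safe to call more than once."""
        if not self.initialized or self.ipc_handles is None:
            return  # nothing to release
        try:
            _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                self.ipc_handles, group=dist.group.WORLD
            )
        except Exception as e:
            logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
        finally:
            self.workspace_tensor = None
            self.ipc_handles = None
            self.initialized = False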



_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(max_token_num: int = 8192,
                                 hidden_dim: int = 4096,
                                 use_fp32_lamport: bool = False):
    """Ensure workspace is initialized"""
    if not is_flashinfer_available() or _flashinfer_comm is None:
        return False

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    if (not _workspace_manager.initialized or
            _workspace_manager.world_size != world_size):
        _workspace_manager.initialize(
Reviewer comment (high):

The condition to check if the workspace needs re-initialization should also include a check for max_token_num and hidden_dim to ensure the workspace is re-initialized if these values change. If these parameters change without re-initialization, it could lead to memory corruption or crashes.

    if (not _workspace_manager.initialized or 
            _workspace_manager.world_size != world_size or
            _workspace_manager.max_token_num != max_token_num or
            _workspace_manager.hidden_dim != hidden_dim):
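
The comparison above assumes the manager records these sizes; a minimal sketch of the extra bookkeeping initialize() would need (max_token_num and hidden_dim are assumed new attributes, not present in the current diff):

        # At the end of a successful initialize(): remember the sizing
        # parameters so the re-initialization check has values to compare.
        self.world_size = world_size
        self.rank = rank
        self.max_token_num = max_token_num  # assumed new attribute
        self.hidden_dim = hidden_dim        # assumed new attribute
        self.initialized = True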

            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport
        )

    return _workspace_manager.initialized

def flashinfer_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 8192,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation.

    Args:
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight
        eps: RMS norm epsilon
        max_token_num: Maximum token number
        use_oneshot: Whether to use oneshot mode
        trigger_completion_at_end: Whether to trigger completion at end
        fp32_acc: Whether to use fp32 accumulation

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
    """
    if not is_flashinfer_available() or _flashinfer_comm is None:
        logger.debug(
            "FlashInfer not available, falling back to standard "
            "implementation"
        )
        return None, None

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if not ensure_workspace_initialized(
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == torch.float32)
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = torch.empty_like(residual)
    norm_out = torch.empty_like(input_tensor)

    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(
            _flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm
        ),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
26 changes: 26 additions & 0 deletions python/sglang/srt/layers/layernorm.py
@@ -27,6 +27,7 @@
    is_cuda,
    is_hip,
    is_npu,
    is_flashinfer_available
)

_is_cuda = is_cuda()
@@ -163,6 +164,31 @@ def forward_cpu(
        else:
            return self.forward_native(x, residual)

    def forward_with_allreduce_fusion(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward method with allreduce fusion, prioritizing FlashInfer fused operations.
        """
        if is_flashinfer_available() and residual is not None:
            from sglang.srt.layers.flashinfer_fusion import flashinfer_allreduce_add_rmsnorm
            from sglang.srt.distributed import get_tensor_model_parallel_world_size

            # Only use the fusion operation in a multi-GPU environment
            if get_tensor_model_parallel_world_size() > 1:
                fused_result = flashinfer_allreduce_add_rmsnorm(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                if fused_result[0] is not None:
                    return fused_result

        return self.forward(x, residual)
Reviewer comment (critical):

The fallback logic in forward_with_allreduce_fusion skips the tensor_model_parallel_all_reduce operation when the fused kernel is not available, leading to incorrect results in multi-GPU environments. The all-reduce operation must be performed before applying the layer norm to maintain correctness.

        if is_flashinfer_available() and residual is not None:
            from sglang.srt.layers.flashinfer_fusion import flashinfer_allreduce_add_rmsnorm
            from sglang.srt.distributed import get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce
            
            # Only use fusion operation in multi-GPU environment
            if get_tensor_model_parallel_world_size() > 1:
                fused_result = flashinfer_allreduce_add_rmsnorm(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                if fused_result[0] is not None:
                    return fused_result
        x = tensor_model_parallel_all_reduce(x)
        return self.forward(x, residual)



class GemmaRMSNorm(CustomOp):
    def __init__(