16 changes: 9 additions & 7 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -58,10 +58,8 @@
):
_enable_jit_deepgemm = True


logger = logging.getLogger(__name__)


if supports_custom_op():

def deep_gemm_fp8_fp8_bf16_nt(
@@ -897,24 +895,28 @@ def _per_tensor_quant_mla_fp8_stage2(


def per_tensor_quant_mla_fp8(
x: torch.Tensor, eps: float = 1e-12
x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
    This function quantizes input values to float8 values with tensor-wise quantization
    and is specialized for the MLA absorbed case.
"""
assert x.dim() == 3, "`x` is not a 3d-tensor"
assert (
x_s_out.shape == (1,)
and x_s_out.dtype == torch.float32
and x_s_out.device == x.device
)

x_q = x.new_empty(x.size(), dtype=_fp8_type)
x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)

num_head, num_seq, head_size = x.shape
BLOCK_SIZE = triton.next_power_of_2(head_size)
grid = (num_seq, num_head)

_per_tensor_quant_mla_fp8_stage1[grid](
x,
x_s,
x_s_out,
head_size,
x.stride(0),
x.stride(1),
@@ -924,7 +926,7 @@ def per_tensor_quant_mla_fp8(
)
_per_tensor_quant_mla_fp8_stage2[grid](
x,
x_s,
x_s_out,
x_q,
num_seq,
head_size,
@@ -935,7 +937,7 @@ def per_tensor_quant_mla_fp8(
BLOCK_SIZE,
)

return x_q, x_s
return x_q, x_s_out


def scaled_fp8_quant(
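As a quick illustration of the new calling convention (not part of the diff): the caller now supplies a pre-zeroed one-element float32 tensor as `x_s_out`, and the function returns the quantized tensor together with that same scale buffer. The shapes, values, and CUDA device below are illustrative assumptions only; the kernel is a Triton kernel and needs a GPU.

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8

# Illustrative input: (num_head, num_seq, head_size), matching the 3-D layout the kernel asserts.
x = torch.randn(16, 8, 128, dtype=torch.bfloat16, device="cuda")
# Caller-provided scale slot: shape (1,), float32, same device, pre-initialized to zero.
x_s_out = torch.zeros((1,), dtype=torch.float32, device=x.device)

x_q, x_s = per_tensor_quant_mla_fp8(x, x_s_out)
# The returned scale is the buffer that was passed in, not a fresh allocation.
assert x_s.data_ptr() == x_s_out.data_ptr()
```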
10 changes: 8 additions & 2 deletions python/sglang/srt/models/deepseek_nextn.py
@@ -40,7 +40,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -91,6 +91,12 @@ def forward(
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
buffer_size=2,
dtype=torch.float32,
device=input_ids.device,
)

if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
@@ -108,7 +114,7 @@ def forward(

residual = None
hidden_states, residual = self.decoder(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)

if not forward_batch.forward_mode.is_idle():
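(Presumably `buffer_size=2` here because the NextN draft module runs a single `DeepseekV2DecoderLayer`, and each layer's absorbed-MLA path performs at most two fp8 quantizations per forward; the main model below sizes its buffer as `len(self.layers) * 2` for the same reason.)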
41 changes: 32 additions & 9 deletions python/sglang/srt/models/deepseek_v2.py
@@ -76,7 +76,7 @@
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip

_is_hip = is_hip()
_is_cuda = is_cuda()
@@ -97,7 +97,6 @@


class AttnForwardMethod(IntEnum):

# Use multi-head attention
MHA = auto()

@@ -590,6 +589,7 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
@@ -615,9 +615,13 @@
positions, hidden_states, forward_batch
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
return self.forward_absorb(
positions, hidden_states, forward_batch, zero_allocator
)

def forward_normal(
self,
@@ -666,6 +670,7 @@ def forward_absorb(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
@@ -690,6 +695,7 @@
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -721,6 +727,7 @@
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -741,6 +748,7 @@ def forward_absorb_fused_mla_rope(
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
) -> torch.Tensor:
enable_rope_fusion = (
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
@@ -767,7 +775,9 @@
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -863,7 +873,9 @@ def forward_absorb_fused_mla_rope(
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
attn_output.transpose(0, 1),
zero_allocator.allocate(1),
dtype=torch.float8_e4m3fn,
)
attn_bmm_output = bmm_fp8(
attn_output_val,
@@ -1115,14 +1127,15 @@ def forward(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
else:
raise NotImplementedError
@@ -1133,6 +1146,7 @@ def forward_ffn_with_full_input(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:

if hidden_states.shape[0] == 0:
@@ -1153,6 +1167,7 @@
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)

# Gather
@@ -1200,6 +1215,7 @@ def forward_ffn_with_scattered_input(
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:

if hidden_states.shape[0] == 0:
@@ -1225,6 +1241,7 @@
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
zero_allocator=zero_allocator,
)

if self.attn_tp_size != 1:
@@ -1312,6 +1329,12 @@ def forward(
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
zero_allocator = BumpAllocator(
# TODO for two-batch-overlap, we need a larger buffer size
buffer_size=len(self.layers) * 2,
dtype=torch.float32,
device=input_ids.device,
)

if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
@@ -1323,7 +1346,7 @@
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is None:
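For context (not part of the diff), the buffer sizing above follows from the allocation pattern in the attention layers: each decoder layer's `forward_absorb` / `forward_absorb_fused_mla_rope` path calls `zero_allocator.allocate(1)` at most twice per forward (once for `q_nope`, once for `attn_output`), so `len(self.layers) * 2` zero-initialized float32 slots cover one model forward. A minimal sketch of that budget, with an illustrative layer count and device:

```python
import torch
from sglang.srt.utils import BumpAllocator

num_layers = 4  # illustrative; the model uses len(self.layers)
zero_allocator = BumpAllocator(
    buffer_size=num_layers * 2,  # two fp8 MLA quantizations per layer
    dtype=torch.float32,
    device="cpu",  # the model passes input_ids.device
)

for _ in range(num_layers):
    q_scale = zero_allocator.allocate(1)     # would be passed as x_s_out for q_nope
    attn_scale = zero_allocator.allocate(1)  # would be passed as x_s_out for attn_output
```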
13 changes: 13 additions & 0 deletions python/sglang/srt/utils.py
@@ -1932,3 +1932,16 @@ def is_fa3_default_architecture(hf_config):
"MistralForCausalLM",
}
return architectures[0] in default_archs


# Can be more general if it is used in multiple places (keep it simple and thus not general now)
class BumpAllocator:
def __init__(self, buffer_size: int, dtype, device):
self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
self._pointer = 0

def allocate(self, size: int):
assert self._pointer + size <= len(self._buffer)
output = self._buffer[self._pointer : self._pointer + size]
self._pointer += size
return output
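A minimal usage sketch of the bump-pointer behavior (assumptions: CPU tensors and toy sizes, not code from the PR): `allocate` hands out consecutive views into one pre-zeroed buffer, the pointer only advances, and over-allocation trips the bounds assert.

```python
import torch
from sglang.srt.utils import BumpAllocator

alloc = BumpAllocator(buffer_size=4, dtype=torch.float32, device="cpu")

a = alloc.allocate(1)  # view over buffer[0:1], already zero-initialized
b = alloc.allocate(2)  # view over buffer[1:3]
a.fill_(0.5)           # writes through to the shared underlying buffer

# Only one slot remains; alloc.allocate(2) at this point would fail the
# assert inside allocate().
```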