python/sglang/srt/models/deepseek_v2.py (35 changes: 17 additions & 18 deletions)
@@ -92,7 +92,6 @@
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
@@ -1619,15 +1618,15 @@ def forward_absorb_prepare(
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
-            if _is_cublas_ge_129:
-                q_nope_val, q_nope_scale = input_to_float8(
-                    q_nope.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            # Fix bmm_fp8 error under cuBLAS 12.9 caused by BumpAllocator; see PR #11612
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
Comment on lines +1622 to +1629 (Contributor; severity: medium):
This quantization block is duplicated in forward_absorb_core (lines 1726-1733). To improve maintainability and avoid the duplication, consider extracting the logic into a helper method on the DeepseekV2AttentionMLA class.

For example, you could create a method like this:

    def _quantize_for_bmm_fp8(self, x: torch.Tensor, zero_allocator: BumpAllocator):
        return per_tensor_quant_mla_fp8(
            x.transpose(0, 1),
            (
                torch.zeros((1,), dtype=torch.float32, device=x.device)
                if _is_cublas_ge_129
                else zero_allocator.allocate(1)
            ),
        )

Then you can replace the duplicated blocks with:
q_nope_val, q_nope_scale = self._quantize_for_bmm_fp8(q_nope, zero_allocator)
and
attn_output_val, attn_output_scale = self._quantize_for_bmm_fp8(attn_output, zero_allocator)
respectively.

             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
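
To make the new scale-buffer choice concrete, here is a minimal, self-contained sketch of the pattern under stated assumptions: it uses plain PyTorch only, per_tensor_quant_fp8_sketch is a stand-in for sglang's per_tensor_quant_mla_fp8, the shapes are illustrative, and the BumpAllocator pool is simulated with an ordinary preallocated tensor.

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

    def per_tensor_quant_fp8_sketch(x: torch.Tensor, scale_buf: torch.Tensor):
        # Write the per-tensor scale into the caller-provided (1,) float32 buffer,
        # mirroring how the diff hands per_tensor_quant_mla_fp8 its scale storage.
        amax = x.abs().amax().to(torch.float32).clamp(min=1e-12)
        scale_buf.copy_(amax / FP8_MAX)
        x_q = (x.to(torch.float32) / scale_buf).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return x_q, scale_buf

    _is_cublas_ge_129 = False           # placeholder for the real module-level flag
    q_nope = torch.randn(16, 8, 128)    # (tokens, heads, head_dim); illustrative shape
    bump_pool = torch.zeros((4,), dtype=torch.float32)  # stand-in for the BumpAllocator pool

    scale_buf = (
        torch.zeros((1,), dtype=torch.float32, device=q_nope.device)  # fresh buffer (cuBLAS >= 12.9 path)
        if _is_cublas_ge_129
        else bump_pool[0:1]             # stand-in for zero_allocator.allocate(1)
    )
    q_nope_val, q_nope_scale = per_tensor_quant_fp8_sketch(q_nope.transpose(0, 1), scale_buf)
    # q_nope_val / q_nope_scale would then feed bmm_fp8(q_nope_val, w_kc, q_nope_scale, w_scale, torch.bfloat16).

As the in-diff comment notes, the only thing the conditional changes is where the (1,) scale lives: a freshly zeroed tensor sidesteps the BumpAllocator-backed buffer that trips bmm_fp8 on cuBLAS 12.9, while older cuBLAS versions keep the cheaper bump-allocated slice.
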
@@ -1768,14 +1767,14 @@ def forward_absorb_core(
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)

         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            if _is_cublas_ge_129:
-                attn_output_val, attn_output_scale = input_to_float8(
-                    attn_output.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
Comment on lines +1770 to +1777 (Contributor; severity: medium):
This block duplicates the quantization logic in forward_absorb_prepare; as suggested in the comment there, it could be refactored into a shared helper method.

             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
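
Both hunks gate on the module-level _is_cublas_ge_129 flag, whose definition is not shown in this diff. As a hedged sketch only, a flag of this kind could be derived from the CUDA toolkit version PyTorch reports; the actual definition in deepseek_v2.py may query the cuBLAS version differently, and cuBLAS can be upgraded independently of the toolkit.

    import torch

    def _cuda_toolkit_at_least(major: int, minor: int) -> bool:
        # torch.version.cuda is a string like "12.9" for CUDA builds, or None for CPU-only builds.
        if torch.version.cuda is None:
            return False
        cu_major, cu_minor = (int(part) for part in torch.version.cuda.split(".")[:2])
        return (cu_major, cu_minor) >= (major, minor)

    # Illustrative stand-in for the flag used in the diff above.
    _is_cublas_ge_129_sketch = _cuda_toolkit_at_least(12, 9)
    print(_is_cublas_ge_129_sketch)
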