|
94 | 94 | block_quant_dequant, |
95 | 95 | block_quant_to_tensor_quant, |
96 | 96 | channel_quant_to_tensor_quant, |
| 97 | + input_to_float8, |
97 | 98 | normalize_e4m3fn_to_e4m3fnuz, |
98 | 99 | requant_weight_ue8m0_inplace, |
99 | 100 | ) |
|
131 | 132 | is_hip, |
132 | 133 | is_non_idle_and_non_empty, |
133 | 134 | is_npu, |
| 135 | + is_nvidia_cublas_cu12_version_ge_12_9, |
134 | 136 | is_sm100_supported, |
135 | 137 | log_info_on_rank0, |
136 | 138 | make_layers, |
|
189 | 191 |
|
190 | 192 | _is_flashinfer_available = is_flashinfer_available() |
191 | 193 | _is_sm100_supported = is_cuda() and is_sm100_supported() |
| 194 | +_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9() |
192 | 195 |
|
193 | 196 |
|
194 | 197 | logger = logging.getLogger(__name__) |
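The new import and module-level flag above gate the fp8 fallback on the installed cuBLAS version. The actual `is_nvidia_cublas_cu12_version_ge_12_9` helper is not shown in this diff; a minimal sketch of such a check, assuming the cuBLAS runtime comes from the `nvidia-cublas-cu12` wheel and is queried via `importlib.metadata`, could look like this:

```python
# Hypothetical sketch, not the real sglang helper: reads the version of the
# nvidia-cublas-cu12 wheel and compares it against 12.9.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def is_nvidia_cublas_cu12_version_ge_12_9() -> bool:
    try:
        return Version(version("nvidia-cublas-cu12")) >= Version("12.9")
    except PackageNotFoundError:
        # cuBLAS not installed as a pip wheel (e.g. a system CUDA install);
        # report False so callers keep the pre-12.9 code path.
        return False
```

Since the result is cached once at import time in `_is_cublas_ge_129`, the per-token hot paths below only pay for a boolean check.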
@@ -1572,10 +1575,15 @@ def forward_absorb_prepare( |
1572 | 1575 | self.w_kc.to(torch.bfloat16) * self.w_scale, |
1573 | 1576 | ) |
1574 | 1577 | elif self.w_kc.dtype == torch.float8_e4m3fn: |
1575 | | - q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( |
1576 | | - q_nope.transpose(0, 1), |
1577 | | - zero_allocator.allocate(1), |
1578 | | - ) |
| 1578 | + # TODO: fix per_tensor_quant_mla_fp8 for cuBLAS >= 12.9
| 1579 | + if _is_cublas_ge_129: |
| 1580 | + q_nope_val, q_nope_scale = input_to_float8( |
| 1581 | + q_nope.transpose(0, 1), torch.float8_e4m3fn |
| 1582 | + ) |
| 1583 | + else: |
| 1584 | + q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( |
| 1585 | + q_nope.transpose(0, 1), zero_allocator.allocate(1) |
| 1586 | + ) |
1579 | 1587 | q_nope_out = bmm_fp8( |
1580 | 1588 | q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 |
1581 | 1589 | ) |
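In `forward_absorb_prepare`, the fp8 path now avoids `per_tensor_quant_mla_fp8` on cuBLAS >= 12.9 and quantizes `q_nope` with `input_to_float8` instead. The helper itself is outside this diff; as a rough sketch (an assumption about its behavior, not the exact implementation), a per-tensor float8 quantizer of this kind computes one amax-based scale, saturates, casts, and returns the dequantization scale consumed by the subsequent `bmm_fp8`:

```python
# Hypothetical per-tensor float8 quantizer in the spirit of input_to_float8;
# the real helper may differ in details (epsilon, contiguity, scale dtype).
import torch


def per_tensor_to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Single scale for the whole tensor: map |x|_max onto the fp8 dynamic range.
    amax = x.abs().max().clamp(min=1e-12).float()
    scale = finfo.max / amax
    # Quantize with saturation; return the fp8 tensor plus the dequant scale
    # (the reciprocal), which the fp8 GEMM multiplies back in.
    x_fp8 = (x.float() * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.reciprocal()
```

Unlike `per_tensor_quant_mla_fp8`, this path needs no workspace, which is why the new branch drops the `zero_allocator.allocate(1)` argument.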
@@ -1716,10 +1724,14 @@ def forward_absorb_core( |
1716 | 1724 | attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) |
1717 | 1725 |
|
1718 | 1726 | elif self.w_vc.dtype == torch.float8_e4m3fn: |
1719 | | - attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( |
1720 | | - attn_output.transpose(0, 1), |
1721 | | - zero_allocator.allocate(1), |
1722 | | - ) |
| 1727 | + if _is_cublas_ge_129: |
| 1728 | + attn_output_val, attn_output_scale = input_to_float8( |
| 1729 | + attn_output.transpose(0, 1), torch.float8_e4m3fn |
| 1730 | + ) |
| 1731 | + else: |
| 1732 | + attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8( |
| 1733 | + attn_output.transpose(0, 1), zero_allocator.allocate(1) |
| 1734 | + ) |
1723 | 1735 | attn_bmm_output = bmm_fp8( |
1724 | 1736 | attn_output_val, |
1725 | 1737 | self.w_vc, |
|
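`forward_absorb_core` gets the same treatment: on cuBLAS >= 12.9, `attn_output` is quantized with `input_to_float8` before the fp8 batched matmul. Since the branch is now duplicated at both call sites, one possible follow-up (purely illustrative; `quant_mla_fp8_dispatch` is not a symbol in this commit) would be to fold the dispatch into a small module-level helper, relying on the imports and the `_is_cublas_ge_129` flag added above:

```python
# Hypothetical refactor sketch: wraps the branch this commit adds twice, so the
# TODO about per_tensor_quant_mla_fp8 on cuBLAS >= 12.9 has a single home.
def quant_mla_fp8_dispatch(x, zero_allocator):
    if _is_cublas_ge_129:
        # Plain PyTorch per-tensor quantization path.
        return input_to_float8(x, torch.float8_e4m3fn)
    # Fused kernel path, unchanged for cuBLAS < 12.9.
    return per_tensor_quant_mla_fp8(x, zero_allocator.allocate(1))


# The two call sites would then reduce to:
#   q_nope_val, q_nope_scale = quant_mla_fp8_dispatch(
#       q_nope.transpose(0, 1), zero_allocator
#   )
#   attn_output_val, attn_output_scale = quant_mla_fp8_dispatch(
#       attn_output.transpose(0, 1), zero_allocator
#   )
```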