[BugFix] replace the input_to_float8 used in dsv2 #11612
Changes from all commits
523214a
74733a5
926fe55
81fd5bc
930d9f5
1cdf859
c2f3988
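
As a rough sketch of what the diff below does (the names `per_tensor_quant_mla_fp8`, `zero_allocator`, and `_is_cublas_ge_129` come from the diff itself; the wrapper function here is illustrative and not part of the PR): instead of switching to `input_to_float8` when cuBLAS >= 12.9 is detected, both call sites keep `per_tensor_quant_mla_fp8` and only change where the one-element float32 scale buffer comes from.

```python
import torch


def pick_fp8_scale_buffer(
    x: torch.Tensor, zero_allocator, is_cublas_ge_129: bool
) -> torch.Tensor:
    """Illustrative helper (not in the PR): choose the float32 scale buffer
    that is passed to the per-tensor MLA FP8 quantization.

    Under cuBLAS >= 12.9, a slice handed out by the shared bump allocator
    triggers a bmm_fp8 error, so a freshly allocated zero tensor is used
    instead (see the diff below and the in-code comment referencing this PR).
    """
    if is_cublas_ge_129:
        return torch.zeros((1,), dtype=torch.float32, device=x.device)
    return zero_allocator.allocate(1)
```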
```diff
@@ -92,7 +92,6 @@
     block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
-    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
     requant_weight_ue8m0_inplace,
 )
```
```diff
@@ -1619,15 +1618,15 @@ def forward_absorb_prepare(
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
-            # TODO fix the per_tensor_quant_mla_fp8 for cublas 12.9
-            if _is_cublas_ge_129:
-                q_nope_val, q_nope_scale = input_to_float8(
-                    q_nope.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                    q_nope.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
+            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                q_nope.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
```
|
Comment on lines
+1622
to
+1629
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block of code for quantization is duplicated in For example, you could create a method like this: def _quantize_for_bmm_fp8(self, x: torch.Tensor, zero_allocator: BumpAllocator):
return per_tensor_quant_mla_fp8(
x.transpose(0, 1),
(
torch.zeros((1,), dtype=torch.float32, device=x.device)
if _is_cublas_ge_129
else zero_allocator.allocate(1)
),
)Then you can replace the duplicated blocks with: |
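Presumably something along these lines, with variable names taken from the two hunks (a sketch based on the suggested helper, not part of the original comment):

```python
# in forward_absorb_prepare (hypothetical call site, assuming the helper above)
q_nope_val, q_nope_scale = self._quantize_for_bmm_fp8(q_nope, zero_allocator)

# in forward_absorb_core
attn_output_val, attn_output_scale = self._quantize_for_bmm_fp8(attn_output, zero_allocator)
```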
```diff
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
             )
```
```diff
@@ -1768,14 +1767,14 @@ def forward_absorb_core(
             attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)

         elif self.w_vc.dtype == torch.float8_e4m3fn:
-            if _is_cublas_ge_129:
-                attn_output_val, attn_output_scale = input_to_float8(
-                    attn_output.transpose(0, 1), torch.float8_e4m3fn
-                )
-            else:
-                attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                    attn_output.transpose(0, 1), zero_allocator.allocate(1)
-                )
+            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
+                attn_output.transpose(0, 1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
+            )
```
Comment on lines +1770 to +1777 (Contributor): see the refactoring suggestion on lines +1622 to +1629 above; this block is the duplicate it refers to.
```diff
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
                 self.w_vc,
```