[W8A8 Block Linear Refactor][2/N] Make Fp8 block linear Op use kernel abstraction #33891
maralbahari wants to merge 15 commits into vllm-project:main
Conversation
Signed-off-by: maral <maralbahari.98@gmail.com>
Code Review
This PR introduces a new kernel abstraction for FP8 block-scaled linear layers, which is a great step towards improving code clarity and maintainability. The changes are extensive and well-documented. However, I've found several critical issues in the implementation of the new DynamicMMLinearKernel and its integration, which could lead to runtime errors. These include logical errors in support checks, typos causing NameError, and type incompatibilities in kernel initialization. Please see the detailed comments for each issue.
vllm/model_executor/layers/quantization/kernels/scaled_mm/cuda.py (outdated)
vllm/model_executor/layers/quantization/kernels/scaled_mm/BlockScaledMMLinearKernel.py
@robertgshaw2-redhat @ProExpertProg @mgoin could you review this PR? Appreciate it.

This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: maral <maralbahari.98@gmail.com>
```python
if (
    self.flashinfer_deepgemm_kernel is not None
    and should_use_flashinfer_for_blockscale_fp8_gemm(
        True, output_dtype, input_2d, weight
```
This is set to `True` because `FlashInferFp8DeepGEMMDynamicBlockScaledKernel`'s `is_flashinfer_fp8_blockscale_gemm_supported()` is evaluated in `__init__()`:

```python
self.flashinfer_deepgemm_kernel: (
    FlashInferFp8DeepGEMMDynamicBlockScaledKernel | None
) = None
if FlashInferFp8DeepGEMMDynamicBlockScaledKernel.is_supported()[0]:
```

So the condition `self.flashinfer_deepgemm_kernel is not None` already tests whether FlashInfer is supported, and we can set the first argument of `should_use_flashinfer_for_blockscale_fp8_gemm` to `True`.
A benefit of checking `self.flashinfer_deepgemm_kernel is not None` first is that it short-circuits the condition chain.
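A tiny illustration of that short-circuit benefit (with a hypothetical stand-in for the heuristic): when the kernel attribute is `None`, the more expensive runtime check is never evaluated at all.

```python
calls = []

def should_use_flashinfer(*args) -> bool:
    # Stand-in for the real heuristic; records that it actually ran.
    calls.append("heuristic")
    return True

kernel = None
if kernel is not None and should_use_flashinfer():
    pass

# The heuristic was skipped entirely because the left operand was False.
assert calls == []
```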
We should try static dispatching, either in this PR or an upcoming one. Doing it in a separate PR would confine this PR's changes to pure refactoring. Either way works for me.
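One possible shape for the static dispatching idea floated above (illustrative only; class and method names are assumptions): instead of re-running the `should_use_*` heuristics on every forward call, resolve the decision once at init time and bind a single callable.

```python
class Op:
    def __init__(self, flashinfer_ok: bool, deepgemm_ok: bool):
        # Resolve the dispatch decision once, up front.
        if flashinfer_ok:
            self._apply = self._apply_flashinfer
        elif deepgemm_ok:
            self._apply = self._apply_deepgemm
        else:
            self._apply = self._apply_triton

    def _apply_flashinfer(self, x): return f"flashinfer({x})"
    def _apply_deepgemm(self, x):   return f"deepgemm({x})"
    def _apply_triton(self, x):     return f"triton({x})"

    def apply(self, x):
        # No per-call branching: one indirect call.
        return self._apply(x)

print(Op(False, True).apply("x"))  # deepgemm(x)
```

A caveat that may justify deferring this: heuristics that depend on runtime shapes (e.g. the number of tokens) cannot be fully hoisted to `__init__()`, so some per-call checks may have to remain.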
```python
and should_use_flashinfer_for_blockscale_fp8_gemm(
    True, output_dtype, input_2d, weight
)
and should_use_deepgemm_for_fp8_linear(output_dtype, weight, True)
```
The reason the last argument of `should_use_deepgemm_for_fp8_linear` can be set to `True` is the same as in https://github.com/vllm-project/vllm/pull/33891/changes#r2851594385.
```python
return self.flashinfer_deepgemm_kernel.apply_weights(layer, x, bias)
```

```python
if self.deepgemm_kernel is not None and should_use_deepgemm_for_fp8_linear(
    output_dtype, weight, True
```
The reason the last argument of `should_use_deepgemm_for_fp8_linear` can be set to `True` is the same as in https://github.com/vllm-project/vllm/pull/33891/changes#r2851594385.
```python
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.input_quant_op = QuantFP8(
    static=False,
    group_shape=act_scale_descriptor.group_shape,
```
Missing `tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,`

@tjtanaa Added, and updated the 3/N PR as well.
```python
act_scale_descriptor = config.activation_quant_key.scale
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.input_quant_op = QuantFP8(
    static=False,
```
Missing `column_major_scales=True,`
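For context on why `column_major_scales=True` matters, a hedged sketch: DeepGEMM-style block-scale kernels typically expect the per-block activation scales laid out column-major rather than row-major. The layout change itself is just a transpose of the flattened order, illustrated here in pure Python (no claim about vLLM's internal tensor shapes):

```python
def to_column_major(scales):
    """Flatten a row-major matrix (list of rows) in column-major order."""
    rows, cols = len(scales), len(scales[0])
    return [scales[r][c] for c in range(cols) for r in range(rows)]

scales = [[1.0, 2.0],
          [3.0, 4.0]]
print(to_column_major(scales))  # [1.0, 3.0, 2.0, 4.0]
```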
```python
return [CutlassFp8BlockScaledMMKernel, TritonFp8BlockScaledMMKernel]
```

```python
@classmethod
def is_supported(cls, compute_capability=None):
```
It seems the output dtype of the DeepGEMM output tensor is hardcoded to `torch.bfloat16`; we can treat that as a condition that should be added to `is_supported()`.
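A sketch of what folding that constraint into the support check could look like (illustrative signature and class name; the real kernel classes differ, and dtypes are modeled as strings here to keep the example self-contained): if the DeepGEMM path can only emit bfloat16 outputs, `is_supported()` can reject other dtypes up front instead of failing later in `apply_weights()`.

```python
class DeepGemmLikeKernel:
    @classmethod
    def is_supported(cls, compute_capability=None, output_dtype="bfloat16"):
        # Hardware gate (assumed threshold for illustration).
        if compute_capability is not None and compute_capability < (9, 0):
            return False, "requires SM90+"
        # Dtype gate: the kernel hardcodes a bfloat16 output tensor.
        if output_dtype != "bfloat16":
            return False, "DeepGEMM path only emits bfloat16 outputs"
        return True, None

print(DeepGemmLikeKernel.is_supported((9, 0), "float16"))
# (False, 'DeepGEMM path only emits bfloat16 outputs')
```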
FlashInfer and DeepGEMM do not follow the current abstraction. They wrap the quant ops in a `direct_register_custom_op`, as shown in vllm/model_executor/layers/quantization/utils/fp8_utils.py, lines 272–282 and lines 284–306 (at 675ec59).
```python
self.input_quant_op = QuantFP8(
    static=act_scale_descriptor.static,
    group_shape=act_scale_descriptor.group_shape,
    num_token_padding=self.get_output_padding(),
```
Following the implementation here, it seems we always explicitly set `use_ue8m0` (see vllm/model_executor/layers/quantization/utils/fp8_utils.py, lines 562–584 at 675ec59).
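For readers unfamiliar with the flag: UE8M0 is an exponent-only 8-bit scale encoding (no sign or mantissa bits), so every scale must be a power of two. A hedged sketch of the rounding this implies, assuming scales are rounded up so quantized values stay in range:

```python
import math

def round_scale_ue8m0(scale: float) -> float:
    """Round a positive scale up to the nearest power of two."""
    return 2.0 ** math.ceil(math.log2(scale))

print(round_scale_ue8m0(0.3))  # 0.5
print(round_scale_ue8m0(3.0))  # 4.0
print(round_scale_ue8m0(4.0))  # 4.0
```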
```python
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
    super().__init__(config)
    act_scale_descriptor = config.activation_quant_key.scale
    self.input_quant_op = QuantFP8(
```
I noticed that `FlashInferFp8BlockScaledMMKernel` does not use this quant op; can you add a comment explaining why it is needed here?
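One possible answer, sketched with hypothetical stand-in functions (not vLLM's actual call signatures): the FlashInfer block-scale GEMM consumes the BF16 activation directly and quantizes internally, while other kernels in the hierarchy need the activation pre-quantized by `input_quant_op`. If the base class builds the quant op unconditionally, the FlashInfer subclass simply never calls it.

```python
def quant_fp8(x):                      # stand-in for an input quant op
    return ("fp8", x)

def flashinfer_gemm(a_bf16, b, bs):    # takes the BF16 input directly
    return f"flashinfer({a_bf16})"

def cutlass_gemm(a_fp8, b, bs):        # expects a pre-quantized input
    return f"cutlass({a_fp8})"

x = "bf16_act"
print(flashinfer_gemm(x, "w", "ws"))          # no quant step needed
print(cutlass_gemm(quant_fp8(x), "w", "ws"))  # quant op used on this path
```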
```python
return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
    A,   # BF16 input
    B,   # FP8 weight
    Bs,  # weight scales
```
Since the abstraction and code introduced in this PR are not yet exercised and serve only to highlight the core refactoring of the FP8 block linear op, we will proceed directly with the 3/N PR #33892, which uses the code introduced here and validates it through CI.
Purpose
Closing this PR in favor of #33892.
Test Plan
Does not require testing, since the code path is not utilized yet.
Test Result
Essential Elements of an Effective PR Description Checklist
Update supported_models.md and examples for a new model.