[W8A8 Block Linear Refactor][2/N] Make Fp8 block linear Op use kernel abstraction. #33891

Closed
maralbahari wants to merge 15 commits into vllm-project:main from EmbeddedLLM:2n-block-scaled-rfc-pr

Conversation

@maralbahari (Contributor) commented Feb 5, 2026

Purpose

Closing this PR in favor of #33892.

Test Plan

Does not require testing since the code path is not utilized yet.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: maral <maralbahari.98@gmail.com>
@gemini-code-assist (bot) left a comment:

Code Review

This PR introduces a new kernel abstraction for FP8 block-scaled linear layers, which is a great step towards improving code clarity and maintainability. The changes are extensive and well-documented. However, I've found several critical issues in the implementation of the new DynamicMMLinearKernel and its integration, which could lead to runtime errors. These include logical errors in support checks, typos causing NameError, and type incompatibilities in kernel initialization. Please see the detailed comments for each issue.

maralbahari and others added 7 commits February 5, 2026 18:04

…r.py
…kScaledMMLinearKernel.py
…ement for cutlass and fix type error in dynamic deepgemm/flash-infer

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: maral <maralbahari.98@gmail.com>
@maralbahari (Author) commented:

@robertgshaw2-redhat @ProExpertProg @mgoin could you review this PR? Appreciate it.

@mergify (bot) commented Feb 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maralbahari.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 24, 2026
Signed-off-by: maral <maralbahari.98@gmail.com>
@mergify mergify bot removed the needs-rebase label Feb 24, 2026
Signed-off-by: maral <maralbahari.98@gmail.com>
@tjtanaa tjtanaa added the rocm Related to AMD ROCm label Feb 25, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 25, 2026
if (
    self.flashinfer_deepgemm_kernel is not None
    and should_use_flashinfer_for_blockscale_fp8_gemm(
        True, output_dtype, input_2d, weight
@tjtanaa (Collaborator) commented Feb 25, 2026

This is set to True because FlashInferFp8DeepGEMMDynamicBlockScaledKernel's support check, is_flashinfer_fp8_blockscale_gemm_supported(), is evaluated in __init__():

        self.flashinfer_deepgemm_kernel: (
            FlashInferFp8DeepGEMMDynamicBlockScaledKernel | None
        ) = None
        if FlashInferFp8DeepGEMMDynamicBlockScaledKernel.is_supported()[0]:

So the condition self.flashinfer_deepgemm_kernel is not None already tests whether FlashInfer is supported, and we can set the first argument of should_use_flashinfer_for_blockscale_fp8_gemm to True.
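The pattern discussed above can be sketched in isolation. All class and function names below are stand-ins for illustration, not vLLM's actual API: the point is that the platform support check runs once in `__init__`, so the per-call path only needs a cheap `is not None` test before the shape-dependent heuristic.

```python
class _FakeFlashInferKernel:
    """Stand-in for FlashInferFp8DeepGEMMDynamicBlockScaledKernel."""

    @classmethod
    def is_supported(cls):
        # The real check returns (bool, reason); hardcoded for the sketch.
        return (True, None)


def should_use_flashinfer_for_blockscale_fp8_gemm(is_supported, m):
    # Shape-dependent heuristic; the first argument can be passed as True
    # because the caller guaranteed support at construction time.
    return is_supported and m >= 128


class DynamicLinearOp:
    def __init__(self):
        # Support is evaluated exactly once, here.
        self.flashinfer_deepgemm_kernel = None
        if _FakeFlashInferKernel.is_supported()[0]:
            self.flashinfer_deepgemm_kernel = _FakeFlashInferKernel()

    def pick(self, m):
        # `is not None` short-circuits: the heuristic never runs on
        # platforms where the kernel was not constructed.
        if (self.flashinfer_deepgemm_kernel is not None
                and should_use_flashinfer_for_blockscale_fp8_gemm(True, m)):
            return "flashinfer"
        return "triton"
```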

@tjtanaa (Collaborator) commented:

The benefit of checking self.flashinfer_deepgemm_kernel is not None first is that it short-circuits the conditions.

@tjtanaa (Collaborator) commented:

We should try static dispatching, either in this PR or in an upcoming one. Doing it in another PR confines this PR's changes to pure refactoring. Either way works for me.
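As a rough illustration of the static-dispatch idea (hypothetical names only, not the vLLM implementation): the kernel choice is resolved once at construction and stored as a bound method, so the forward path carries no branching at all.

```python
class StaticDispatchOp:
    def __init__(self, flashinfer_available: bool):
        # Resolve the apply function once; the hot path is a single call
        # through self._apply with no per-invocation condition checks.
        if flashinfer_available:
            self._apply = self._apply_flashinfer
        else:
            self._apply = self._apply_triton

    def _apply_flashinfer(self, x):
        # Placeholder for the FlashInfer-backed GEMM path.
        return ("flashinfer", x)

    def _apply_triton(self, x):
        # Placeholder for the Triton fallback path.
        return ("triton", x)

    def apply_weights(self, x):
        return self._apply(x)
```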

    and should_use_flashinfer_for_blockscale_fp8_gemm(
        True, output_dtype, input_2d, weight
    )
    and should_use_deepgemm_for_fp8_linear(output_dtype, weight, True)
@tjtanaa (Collaborator) commented:

The reason that the last argument of should_use_deepgemm_for_fp8_linear can be set to True is the same as in https://github.com/vllm-project/vllm/pull/33891/changes#r2851594385

    return self.flashinfer_deepgemm_kernel.apply_weights(layer, x, bias)

if self.deepgemm_kernel is not None and should_use_deepgemm_for_fp8_linear(
    output_dtype, weight, True
@tjtanaa (Collaborator) commented:

The reason that the last argument of should_use_deepgemm_for_fp8_linear can be set to True is the same as in https://github.com/vllm-project/vllm/pull/33891/changes#r2851594385

self.is_deep_gemm_supported = is_deep_gemm_supported()
self.input_quant_op = QuantFP8(
    static=False,
    group_shape=act_scale_descriptor.group_shape,
@tjtanaa (Collaborator) commented:

Missing tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,

maralbahari (Author) replied:

@tjtanaa added and updated the 3/N PR as well.

act_scale_descriptor = config.activation_quant_key.scale
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.input_quant_op = QuantFP8(
    static=False,
@tjtanaa (Collaborator) commented:

Missing column_major_scales=True,

maralbahari (Author) replied:

@tjtanaa added.

    return [CutlassFp8BlockScaledMMKernel, TritonFp8BlockScaledMMKernel]

@classmethod
def is_supported(cls, compute_capability=None):
@tjtanaa (Collaborator) commented:

It seems DeepGEMM hardcodes the output_dtype of its output tensor to torch.bfloat16; we can assume that is a condition we should add to is_supported.
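A minimal sketch of folding that constraint into is_supported(), assuming the (ok, reason) tuple convention seen elsewhere in this review. The class name and the string-based dtype are illustrative stand-ins only, so the sketch stays self-contained:

```python
class DeepGemmBlockScaledMMKernel:
    # DeepGEMM hardcodes its output buffer to bfloat16 (string stands in
    # for torch.bfloat16 here to keep the sketch dependency-free).
    SUPPORTED_OUT_DTYPE = "bfloat16"

    @classmethod
    def is_supported(cls, output_dtype="bfloat16", compute_capability=None):
        # Any requested output dtype other than bfloat16 means this
        # kernel cannot be selected for the layer.
        if output_dtype != cls.SUPPORTED_OUT_DTYPE:
            return False, f"unsupported output dtype: {output_dtype}"
        return True, None
```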

@tjtanaa (Collaborator) commented:

FlashInfer and DeepGEMM are not following current abstraction. They are wrapping the quant ops in a direct_register_custom_op as shown in

def run_flashinfer_deepgemm_swapAB(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    return flashinfer_fp8_blockscale_gemm(
        input=input,
        weight=weight,
        weight_scale=weight_scale,
        out_dtype=torch.bfloat16,
    )

and

def run_deepgemm(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> torch.Tensor:
    q_input, input_scale = per_token_group_quant_fp8(
        input,
        group_size=group_size,
        column_major_scales=True,
        use_ue8m0=use_deep_gemm_e8m0,
    )
    output = torch.empty(
        (q_input.shape[0], weight.shape[0]),
        dtype=torch.bfloat16,
        device=q_input.device,
    )
    fp8_gemm_nt(
        (q_input, input_scale),
        (weight, weight_scale),
        output,
        is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
    )
    return output

self.input_quant_op = QuantFP8(
    static=act_scale_descriptor.static,
    group_shape=act_scale_descriptor.group_shape,
    num_token_padding=self.get_output_padding(),
@tjtanaa (Collaborator) commented:

use_ue8m0: bool | None = None, # for Torch compile

Following the implementation here, it seems we always explicitly set use_ue8m0:

if use_cutlass:
    return self._run_cutlass, (
        QuantFP8(
            False,
            self.act_quant_group_shape,
            column_major_scales=True,
            use_ue8m0=False,
        )
    )
if use_aiter_and_is_supported:
    return self._run_aiter, QuantFP8(
        False,
        self.act_quant_group_shape,
        column_major_scales=False,
        use_ue8m0=False,
    )
return self._run_triton, (
    QuantFP8(
        False,
        self.act_quant_group_shape,
        column_major_scales=False,
        use_ue8m0=False,
    )
)

def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
    super().__init__(config)
    act_scale_descriptor = config.activation_quant_key.scale
    self.input_quant_op = QuantFP8(
@tjtanaa (Collaborator) commented:

I noticed that FlashInferFp8BlockScaledMMKernel is not using this quant op; can you add a comment explaining why it is needed here?

return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
    A,  # BF16 input
    B,  # FP8 weight
    Bs,  # Weight scales
maralbahari (Author) replied:

@tjtanaa this is just a placeholder; the issue is addressed in the follow-up PR, which independently registers FlashInfer's swap gemm: #33892.

@tjtanaa (Collaborator) commented Feb 25, 2026

Since the abstraction and code introduced in this PR are not yet used and serve only to highlight the core changes of refactoring the FP8 block linear op, we will proceed directly with the 3/N PR #33892, which uses the code introduced here and validates it through CI.

Signed-off-by: maral <maralbahari.98@gmail.com>
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 26, 2026
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Feb 26, 2026

Labels

nvidia, rocm (Related to AMD ROCm)

Projects

AMD: Done
NVIDIA: Done