Merged
2 changes: 2 additions & 0 deletions vllm/config/kernel.py
@@ -115,6 +115,7 @@ def with_default(
"flashinfer_cutlass",
"flashinfer_cutedsl",
"marlin",
"humming",
"aiter",
"emulation",
]
@@ -145,6 +146,7 @@ class KernelConfig:
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
- "marlin": Use Marlin kernels (weight-only quantization)
- "humming": Use Humming Mixed Precision kernels
- "aiter": Use AMD AITer kernels (ROCm only)
- "emulation": use BF16/FP16 GEMM, dequantizing weights and
running QDQ on activations.
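For context, a minimal sketch (an editor's illustration, not code from this PR; the alias name is hypothetical) of what the extended backend Literal amounts to:

```python
from typing import Literal, get_args

# Hypothetical alias mirroring the backend Literal in vllm/config/kernel.py.
MoEQuantBackend = Literal[
    "flashinfer_cutlass",
    "flashinfer_cutedsl",
    "marlin",
    "humming",
    "aiter",
    "emulation",
]

# The new "humming" value is now a legal backend choice.
assert "humming" in get_args(MoEQuantBackend)
```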
4 changes: 2 additions & 2 deletions vllm/envs.py
@@ -1222,8 +1222,8 @@ def _get_or_set_default() -> str:
# force the gemm type: "indexed", "grouped", or "grouped_contiguous"
# if unset, choose the better gemm type automatically
"VLLM_HUMMING_MOE_GEMM_TYPE": lambda: maybe_convert_bool(
os.environ.get("VLLM_HUMMING_MOE_GEMM_TYPE", None)
"VLLM_HUMMING_MOE_GEMM_TYPE": lambda: os.environ.get(
"VLLM_HUMMING_MOE_GEMM_TYPE", None
),
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# only supported on Blackwell GPUs and with
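A short usage sketch (an editor's illustration): with the new lambda, the variable is read as a plain string instead of being coerced through maybe_convert_bool, so the accepted values are the gemm-type names themselves:

```python
import os

# Set before vllm.envs is consulted. Values recognized by
# get_humming_moe_gemm_type() in this PR: "indexed", "grouped",
# "grouped_contiguous"; leaving it unset picks a type automatically.
os.environ["VLLM_HUMMING_MOE_GEMM_TYPE"] = "grouped_contiguous"
```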
119 changes: 70 additions & 49 deletions vllm/model_executor/layers/fused_moe/fused_humming_moe.py
@@ -4,7 +4,7 @@

import json
import math
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
from humming import dtypes
@@ -16,7 +16,11 @@
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
@@ -34,21 +38,16 @@
from vllm.platforms import current_platform
from vllm.v1.worker.workspace import current_workspace_manager

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.humming import HummingMoEMethod


logger = init_logger(__name__)


def get_humming_moe_gemm_type() -> str:
env_gemm_type: str = envs.VLLM_HUMMING_MOE_GEMM_TYPE or ""
env_gemm_type = env_gemm_type.lower()
if env_gemm_type in ["indexed", "grouped"]:
if env_gemm_type == "indexed":
gemm_type = env_gemm_type
elif current_platform.has_device_capability(90):
# for devices that support TMA, use grouped gemm
gemm_type = "grouped"
elif env_gemm_type in ["grouped_contiguous", "grouped"]:
gemm_type = "grouped_contiguous"
else:
gemm_type = "indexed"

@@ -60,49 +59,44 @@ class HummingExpertsBase(mk.FusedMoEExpertsModular):
def __init__(
self,
layer: torch.nn.Module,
Collaborator (@bnellnm):

We should try to avoid passing the layer here if at all possible. It contains the modular kernels. If we ever construct the modular kernels at __init__ time of the layer (which we are considering), this will lead to all sorts of problems.

Contributor Author (@jinzhen-lin):

Since Humming supports a wide variety of quantization combinations, the corresponding weight combinations are also quite numerous. To reduce complexity on the caller side, I prefer a layer-based approach. If directly passing the FusedMoE layer would cause issues, do you think it would be a good choice to extract all the required weights and reconstruct a temporary layer inside the modular kernels?

Contributor Author (@jinzhen-lin):

I don't quite understand what "construct the modular kernels at __init__ time of the layer" means. Since the modular kernels currently require a FusedMoEQuantConfig, and that config can only be fully defined after process_weights_after_loading, how are we supposed to construct the modular kernels at the __init__ stage? Do you plan to pass these in as runtime variables?

Collaborator (@bnellnm, May 5, 2026):

> I don't quite understand what "construct the modular kernels at __init__ time of the layer" means. [...] Do you plan to pass these in as runtime variables?

Even though the modular kernels require a FusedMoEQuantConfig at construction time, they don't really need much information from it (if any). We've been discussing removing this as a construction requirement so that modular kernels can be instantiated at the same time as the quant methods that own them. This is to address other subtle order-of-initialization issues related to the FusedMoE layer, quant methods, SharedExperts, MoERunner, etc.

Contributor Author (@jinzhen-lin):

So, are you planning to pass model parameters or layers as arguments to the apply function? (Many quantization methods have additional parameters besides weight and scale.) I can do the relevant refactoring work for Humming in advance.

Collaborator (@bnellnm):

> So, are you planning to pass model parameters or layers as arguments to the apply function?

No, the layer will still be passed as a runtime arg to apply. It's only a problem when used as an argument to __init__ of any modular kernel object.

Contributor Author (@jinzhen-lin, May 10, 2026):

Hi @bnellnm, since layer is not yet a parameter for methods like apply, moe_problem_size, or workspace_shapes, I can't remove it from the __init__ arguments for now. Is there a timeline for the refactoring you mentioned?

Collaborator (@bnellnm, May 11, 2026):

> Hi @bnellnm, since layer is not yet a parameter for methods like apply, moe_problem_size, or workspace_shapes, I can't remove it from the __init__ arguments for now. Is there a timeline for the refactoring you mentioned?

It looks like a number of the values you need are members of FusedMoEConfig or FusedMoEQuantConfig, which each MK has as attributes, e.g. hidden_size, number of experts, etc.

The layer is not going to be passed down to the MK apply function. It is passed to the quant_method apply functions, and it is the quant_method's responsibility to unpack any data needed from the layer and pass it along to the MK.

Contributor Author (@jinzhen-lin):

Currently, the parameters for the MK apply function are fixed. Is there a way to pass specific parameters to the kernel? For example, the Humming MoE requires a layer object, and similarly, the Marlin MoE requires variables like w13_g_idx.

quant_method: "HummingMoEMethod",
prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular | None = None,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int | None = None,
num_dispatchers: int | None = None,
):
self.layer = layer
self.num_experts = self.layer.num_experts
self.global_num_experts = self.layer.global_num_experts
self.init_humming_moe()

if prepare_finalize is not None:
max_num_tokens: int | None = None
num_dispatchers: int | None = None
if self.is_batched:
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
num_dispatchers = prepare_finalize.num_dispatchers()

assert quant_method.moe_quant_config is not None
super().__init__(
moe_config=quant_method.moe,
quant_config=quant_method.moe_quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
else:
assert not self.is_batched
if self.is_batched():
assert max_num_tokens is not None and num_dispatchers is not None

super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)

def init_humming_moe(self):
self.compute_config = {
"use_batch_invariant": envs.VLLM_BATCH_INVARIANT,
"use_f16_accum": envs.VLLM_HUMMING_USE_F16_ACCUM,
"gemm_type": self.humming_gemm_type.value,
"gemm_type": self.humming_gemm_type().value,
}
self.w13_tuning_config = HummingMethod.get_default_tuning_configs(
layer=self.layer,
use_f16_accum=envs.VLLM_HUMMING_USE_F16_ACCUM,
use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
gemm_type=self.humming_gemm_type,
gemm_type=self.humming_gemm_type(),
sublayer_name="w13",
)
self.w2_tuning_config = HummingMethod.get_default_tuning_configs(
layer=self.layer,
use_f16_accum=envs.VLLM_HUMMING_USE_F16_ACCUM,
use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
gemm_type=self.humming_gemm_type,
gemm_type=self.humming_gemm_type(),
sublayer_name="w2",
)
self.compute_config_str = json.dumps(self.compute_config)
@@ -124,13 +118,13 @@ def estimate_local_valid_shape_m(self, topk_ids: torch.Tensor):
global_num_experts = self.global_num_experts
return math.ceil(global_valid_shape_m * num_experts / global_num_experts)

@property
def humming_gemm_type(self) -> HummingGemmType:
@staticmethod
def humming_gemm_type() -> HummingGemmType:
raise NotImplementedError

@property
def is_batched(self) -> bool:
return self.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts
@classmethod
def is_batched(cls) -> bool:
return cls.activation_format() == mk.FusedMoEActivationFormat.BatchedExperts

@staticmethod
def _supports_quant_scheme(
@@ -189,7 +183,7 @@ def moe_problem_size(
assert w1.size(0) == num_experts
assert w2.size(0) == num_experts

if not self.is_batched:
if not self.is_batched():
num_tokens = a1.size(0)
assert topk_ids.size(0) == num_tokens
else:
@@ -201,7 +195,7 @@

def get_buffer_metas(self, M: int, topk: int, activation: MoEActivation):
num_experts = self.num_experts
N = self.layer.intermediate_size
N = self.layer.intermediate_size_per_partition
K = self.layer.hidden_size
assert isinstance(num_experts, int)
assert isinstance(N, int)
@@ -218,7 +212,7 @@ def get_buffer_metas(self, M: int, topk: int, activation: MoEActivation):
# The output must be derived from workspace1.

output_shape: tuple[int, ...]
if self.is_batched:
if self.is_batched():
max_num_tokens = self.max_num_tokens
num_dispatchers = self.num_dispatchers
assert max_num_tokens is not None and num_dispatchers is not None
@@ -227,7 +221,7 @@ def get_buffer_metas(self, M: int, topk: int, activation: MoEActivation):
output_shape = (num_experts, max_num_tokens * num_dispatchers, K)
else:
input_shape_m = M
if self.humming_gemm_type != HummingGemmType.INDEXED:
if self.humming_gemm_type() != HummingGemmType.INDEXED:
input_shape_m = M * topk
real_shape_m = M * topk
output_shape = (M, K)
@@ -262,7 +256,7 @@ def get_buffer_metas(self, M: int, topk: int, activation: MoEActivation):
"dtype": torch_dtype_map[a_dtype],
},
"down_output": {
"shape": output_shape if self.is_batched else (real_shape_m, K),
"shape": output_shape if self.is_batched() else (real_shape_m, K),
"dtype": torch_dtype_map[c_dtype],
},
"output": {
@@ -288,7 +282,7 @@ def get_buffer_metas(self, M: int, topk: int, activation: MoEActivation):
]

# batched moe uses down_output as output
if not self.is_batched:
if not self.is_batched():
required_buffers.append("output")

return buffer_metas, required_buffers
@@ -308,7 +302,7 @@ def _workspace_shapes(self, M: int, topk: int, activation: MoEActivation):
else:
workspace2_nbytes = max(workspace2_nbytes, nbytes)

output_key = "down_output" if self.is_batched else "output"
output_key = "down_output" if self.is_batched() else "output"
output_shape = buffer_metas[output_key]["shape"]

return (workspace1_nbytes // 2,), (workspace2_nbytes // 2,), output_shape
@@ -395,6 +389,33 @@ def main_apply(
):
raise NotImplementedError

@staticmethod
def is_supported_config(
cls: type[mk.FusedMoEExperts],
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
if activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
supported = cls.activation_format() == activation_format
reason = "activation_format mismatched"
elif activation_format == mk.FusedMoEActivationFormat.Standard:
if cls.activation_format() != mk.FusedMoEActivationFormat.Standard:
supported = False
reason = "activation_format mismatched"
else:
assert hasattr(cls, "humming_gemm_type")
gemm_type = cls.humming_gemm_type().value.lower()
preferred_gemm_type = get_humming_moe_gemm_type().lower()
supported = preferred_gemm_type == gemm_type
reason = "preferred gemm type mismatched"
else:
supported = False
reason = "unsupported activation_format"

return supported, None if supported else reason


class HummingIndexedExperts(HummingExpertsBase):
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
@@ -404,8 +425,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard

@property
def humming_gemm_type(self) -> HummingGemmType:
@staticmethod
def humming_gemm_type() -> HummingGemmType:
return HummingGemmType.INDEXED

def prepare_humming_moe_kwargs(
@@ -526,8 +547,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard

@property
def humming_gemm_type(self) -> HummingGemmType:
@staticmethod
def humming_gemm_type() -> HummingGemmType:
return HummingGemmType.GROUPED_CONTIGUOUS

def main_apply(
Expand Down Expand Up @@ -619,8 +640,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts

@property
def humming_gemm_type(self) -> HummingGemmType:
@staticmethod
def humming_gemm_type() -> HummingGemmType:
return HummingGemmType.GROUPED_MASKED

def main_apply(
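One recurring change in this file is that humming_gemm_type and is_batched move from instance properties to @staticmethod/@classmethod, which lets is_supported_config interrogate a class before any instance exists. A minimal sketch of the pattern (an editor's illustration; the enum string values are assumptions):

```python
from enum import Enum


class HummingGemmType(Enum):
    # Member names mirror the diff; the string values are assumptions.
    INDEXED = "indexed"
    GROUPED_CONTIGUOUS = "grouped_contiguous"
    GROUPED_MASKED = "grouped_masked"


class ExpertsBase:
    @staticmethod
    def humming_gemm_type() -> HummingGemmType:
        raise NotImplementedError


class IndexedExperts(ExpertsBase):
    @staticmethod
    def humming_gemm_type() -> HummingGemmType:
        return HummingGemmType.INDEXED


# No instance is needed, which is exactly what a class-level
# is_supported_config(cls, ...) check relies on.
assert IndexedExperts.humming_gemm_type() is HummingGemmType.INDEXED
```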
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1103,9 +1103,6 @@ def weight_loader(
return_success: bool = False,
) -> bool | None:
quant_config_name = self.quant_config and self.quant_config.get_name()
if quant_config_name == "humming":
assert hasattr(self.quant_method, "weight_schema")
quant_config_name = self.quant_method.weight_schema.quant_method
if quant_config_name == "gpt_oss_mxfp4":
# (FIXME) for gpt-oss all experts are combined
if "bias" in weight_name:
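The review thread above sketches where the layer argument is headed: it stays a runtime argument to the quant method's apply, and the quant method unpacks whatever the modular kernel needs. A minimal sketch of that ownership pattern (an editor's illustration; class and attribute names here are hypothetical, not the vLLM API):

```python
import torch


class ModularKernel:
    # The modular kernel sees plain tensors, never the nn.Module that owns them.
    def apply(
        self, x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError  # kernel dispatch elided in this sketch


class QuantMethod:
    def __init__(self, kernel: ModularKernel):
        # The kernel is constructed without any reference to the layer.
        self.kernel = kernel

    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        # The layer arrives at runtime; the quant method unpacks the data
        # the kernel needs and passes it along as plain tensors.
        return self.kernel.apply(x, layer.w13_weight, layer.w2_weight)
```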