Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
4. Unified EPLB integration for backends that support it
"""

import copy
from typing import Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -164,21 +165,34 @@ def __init__(
self.apply_router_weight_on_input = apply_router_weight_on_input

# ========== Create MoE Backend (Default: Cutlass) ==========
from tensorrt_llm._torch.modules.fused_moe.create_moe import create_moe_backend, get_moe_cls
from tensorrt_llm._torch.modules.fused_moe.create_moe import (
create_moe_backend,
resolve_moe_cls,
)

# Get MoE backend class based on override_quant_config, routing_method, and model_config
moe_cls = resolve_moe_cls(
model_config,
routing_method,
self.dtype,
override_quant_config=override_quant_config,
)

# Get MoE backend class based on override_quant_config or model_config
moe_cls = get_moe_cls(model_config, override_quant_config=override_quant_config)
backend_model_config = model_config
if override_quant_config is not None:
backend_model_config = copy.deepcopy(model_config)
backend_model_config.quant_config = override_quant_config

# Call create_moe_backend with all necessary parameters
# init_load_balancer=False: Prevents backend from registering itself with load balancer
# without_comm=True: Prevents backend from initializing communication (ConfigurableMoE handles it)
# skip_create_weights_in_init=True: Prevents backend from creating weights in __init__
# because backend uses layer_idx=None and may have different expert assignments
# We will create weights after syncing attributes from ConfigurableMoE
tmp_skip_create_weights_in_init = model_config.skip_create_weights_in_init
model_config._frozen = False
model_config.skip_create_weights_in_init = True
model_config._frozen = True
tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
backend_model_config._frozen = False
backend_model_config.skip_create_weights_in_init = True
backend_model_config._frozen = True
Comment on lines +192 to +195
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Config freeze/skip flags are not safely restored across all paths.

Line 193 and Line 195 force _frozen=True instead of restoring the original frozen state, and if create_moe_backend(...) throws, skip_create_weights_in_init is left mutated. This can leak state into subsequent layer construction.

Suggested fix
-        tmp_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
-        backend_model_config._frozen = False
-        backend_model_config.skip_create_weights_in_init = True
-        backend_model_config._frozen = True
-
-        backend = create_moe_backend(
+        original_skip_create_weights_in_init = backend_model_config.skip_create_weights_in_init
+        original_frozen = backend_model_config._frozen
+        try:
+            backend_model_config._frozen = False
+            backend_model_config.skip_create_weights_in_init = True
+            backend_model_config._frozen = original_frozen
+
+            backend = create_moe_backend(
             moe_cls=moe_cls,
             routing_method=routing_method,
             num_experts=self.num_experts,
             hidden_size=self.hidden_size,
             intermediate_size=self.intermediate_size,
             dtype=self.dtype,
             reduce_results=self.reduce_results,
             model_config=backend_model_config,
             aux_stream_dict=self.aux_stream_dict,
             weight_loading_mode=self.weight_loading_mode,
             bias=kwargs.get("bias", False),
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             layer_idx=None,
             swiglu_alpha=kwargs.get("swiglu_alpha"),
             swiglu_beta=kwargs.get("swiglu_beta"),
             swiglu_limit=kwargs.get("swiglu_limit"),
             init_load_balancer=False,
             without_comm=True,
             activation_type=self.activation_type,
-        )
+            )
+        finally:
+            backend_model_config._frozen = False
+            backend_model_config.skip_create_weights_in_init = original_skip_create_weights_in_init
+            backend_model_config._frozen = original_frozen
@@
-        backend_model_config._frozen = False
-        backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
-        backend_model_config._frozen = True
-        if not backend_model_config.skip_create_weights_in_init:
+        if not original_skip_create_weights_in_init:
             self.backend.create_weights()

Also applies to: 240-244

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py` around lines 192 -
195, The code temporarily mutates
backend_model_config.skip_create_weights_in_init and
backend_model_config._frozen but does not reliably restore the original values
on all paths; capture the original values (e.g., tmp_skip_create_weights_in_init
and tmp_frozen = backend_model_config._frozen) before mutating, set
skip_create_weights_in_init = True and _frozen = False only for the operation,
and restore both original values in a finally block around the call to
create_moe_backend (and the analogous block at the 240-244 site) so exceptions
do not leak mutated state.


backend = create_moe_backend(
moe_cls=moe_cls,
Expand All @@ -188,7 +202,7 @@ def __init__(
intermediate_size=self.intermediate_size,
dtype=self.dtype,
reduce_results=self.reduce_results,
model_config=model_config,
model_config=backend_model_config,
aux_stream_dict=self.aux_stream_dict,
weight_loading_mode=self.weight_loading_mode,
bias=kwargs.get("bias", False),
Expand Down Expand Up @@ -223,10 +237,10 @@ def __init__(
self.backend.expert_size_per_partition = self.expert_size_per_partition

# Create weights here, because the backend needs the layer_load_balancer info to create weights
model_config._frozen = False
model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
model_config._frozen = True
if not model_config.skip_create_weights_in_init:
backend_model_config._frozen = False
backend_model_config.skip_create_weights_in_init = tmp_skip_create_weights_in_init
backend_model_config._frozen = True
if not backend_model_config.skip_create_weights_in_init:
self.backend.create_weights()

# ========== Create Communication Strategy ==========
Expand Down
45 changes: 37 additions & 8 deletions tensorrt_llm/_torch/modules/fused_moe/create_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ def get_moe_cls(
return CutlassFusedMoE
return DenseGEMMFusedMoE
elif moe_backend.upper() == "TRTLLM":
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()
or quant_config.quant_mode.has_w4a16_mxfp4()
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
has_quant = quant_config is not None and quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True)
if has_quant and (quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()
or quant_config.quant_mode.has_w4a16_mxfp4()
or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
return TRTLLMGenFusedMoE
if not has_quant and model_config.pretrained_config is not None and getattr(
model_config.pretrained_config, "torch_dtype",
None) == torch.bfloat16:
if TRTLLMGenFusedMoE._is_flashinfer_fused_moe_available():
return TRTLLMGenFusedMoE
raise RuntimeError(
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE with "
"trtllm_bf16_moe support, but it is not available.")
Comment on lines +75 to +82
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Honor call-site dtype overrides when resolving the BF16 backend.

resolve_moe_cls() takes dtype, but the unquantized BF16 branch still keys off model_config.pretrained_config.torch_dtype only. A caller doing create_moe(..., dtype=torch.bfloat16) will therefore fall back to CutlassFusedMoE whenever the pretrained config is unset or still says float16, so the new FlashInfer-backed TRTLLM path never gets selected.

Also applies to: 97-113

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/create_moe.py` around lines 75 - 82,
The BF16 branch in resolve_moe_cls/create_moe currently only checks
model_config.pretrained_config.torch_dtype and ignores the call-site dtype
argument, so callers passing dtype=torch.bfloat16 won't select
TRTLLMGenFusedMoE; update both BF16 checks (the block referencing
TRTLLMGenFusedMoE around the current 75-82 and the similar branch at 97-113) to
honor the dtype parameter by treating the branch as true when dtype is
torch.bfloat16 OR model_config.pretrained_config.torch_dtype is torch.bfloat16
(while still requiring not has_quant and the FlashInfer availability check),
i.e., use a combined condition like: not has_quant and (dtype is torch.bfloat16
or getattr(model_config.pretrained_config, "torch_dtype", None) is
torch.bfloat16) before returning TRTLLMGenFusedMoE or raising the same
RuntimeError.

else:
logger.warning(
"TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. "
Expand All @@ -85,6 +94,25 @@ def get_moe_cls(
raise ValueError(f"Unsupported moe backend: {moe_backend}")


def resolve_moe_cls(
        model_config: ModelConfig,
        routing_method: BaseMoeRoutingMethod,
        dtype: Optional[torch.dtype],
        override_quant_config: Optional[QuantConfig] = None) -> Type[MoE]:
    """Resolve the concrete MoE backend class for the given configuration.

    Starts from ``get_moe_cls()`` and then demotes ``TRTLLMGenFusedMoE`` to
    ``CutlassFusedMoE`` when the model is unquantized and the routing method
    is not supported by the FlashInfer BF16 fused-MoE path.

    Args:
        model_config: Model configuration carrying the default quant_config.
        routing_method: Routing method instance used to gate the BF16 path.
        dtype: Activation dtype requested by the caller.
            NOTE(review): currently unused in this function — backend
            selection keys off ``get_moe_cls`` only, so a call-site dtype
            override does not influence the result here. TODO confirm whether
            dtype should participate in the unquantized-BF16 branch.
        override_quant_config: Optional quant config that takes precedence
            over ``model_config.quant_config`` when provided.

    Returns:
        The MoE backend class to instantiate.
    """
    moe_cls = get_moe_cls(model_config, override_quant_config)

    # The override, when given, wins over the model-level quant config.
    effective_quant_config = override_quant_config or model_config.quant_config
    has_quant = (effective_quant_config is not None
                 and effective_quant_config.layer_quant_mode.has_any_quant(
                     exclude_kv_cache=True))
    # Unquantized TRTLLMGen relies on the FlashInfer BF16 kernels, which only
    # support a subset of routing methods; fall back to Cutlass otherwise.
    if (moe_cls == TRTLLMGenFusedMoE and not has_quant
            and not TRTLLMGenFusedMoE._supports_flashinfer_bf16_routing_method(
                routing_method)):
        return CutlassFusedMoE

    return moe_cls


def create_moe_backend(
moe_cls: Type[MoE],
routing_method: BaseMoeRoutingMethod,
Expand Down Expand Up @@ -379,7 +407,8 @@ def create_moe(
pretrained_config, 'torch_dtype'):
dtype = pretrained_config.torch_dtype

moe_cls = get_moe_cls(model_config, override_quant_config)
moe_cls = resolve_moe_cls(model_config, routing_method, dtype,
override_quant_config)

enable_configurable_moe = os.environ.get("ENABLE_CONFIGURABLE_MOE",
"1") == "1"
Expand Down
136 changes: 115 additions & 21 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@

# isort: off
from .quantization import (
DeepSeekFP8BlockScalesFusedMoEMethod, NVFP4TRTLLMGenFusedMoEBaseMethod,
NVFP4TRTLLMGenFusedMoEMethod, W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod, W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
W4A16MXFP4TRTLLMGenFusedMoEMethod)
BF16TRTLLMGenFusedMoEMethod, DeepSeekFP8BlockScalesFusedMoEMethod,
NVFP4TRTLLMGenFusedMoEBaseMethod, NVFP4TRTLLMGenFusedMoEMethod,
W4A8MXFP4FP8TRTLLMGenFusedMoEMethod, W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
W4A8NVFP4FP8TRTLLMGenFusedMoEMethod, W4A16MXFP4TRTLLMGenFusedMoEMethod)
# isort: on
from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
DefaultMoeRoutingMethod)
Expand Down Expand Up @@ -115,7 +115,8 @@ def can_implement(
- W4A8_MXFP4_FP8
- W4A8_MXFP4_MXFP8

Does NOT support unquantized mode. Output dtype is hardcoded to bfloat16.
Unquantized BF16 path is supported only with FlashInfer fused MoE backend.
Output dtype is hardcoded to bfloat16.

Args:
quant_algo: The quantization algorithm to check (None for unquantized)
Expand Down Expand Up @@ -143,10 +144,16 @@ def can_implement(
f"TRTLLMGenFusedMoE only supports bfloat16 activation, got {dtype_activation}"
)

# TRTLLMGenFusedMoE does NOT support unquantized mode
if quant_algo is None:
return _warn_and_return(
"TRTLLMGenFusedMoE does not support unquantized mode")
if swiglu_gptoss_style:
return _warn_and_return(
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."
)
if not cls._is_flashinfer_fused_moe_available():
return _warn_and_return(
"TRTLLMGenFusedMoE unquantized BF16 path requires FlashInfer fused MoE "
"with trtllm_bf16_moe support.")
return True, None

# Check if quant_algo is supported
if quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
Expand Down Expand Up @@ -210,7 +217,14 @@ def __init__(

assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."

self.use_flashinfer = self._check_op_backend_support()
self.use_flashinfer = self._check_flashinfer_backend_support()
if (self.quant_config is None
or not self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True)) and not self.use_flashinfer:
raise NotImplementedError(
"TRTLLMGenFusedMoE BF16 path requires FlashInfer fused MoE. "
"Please install a FlashInfer build with trtllm_bf16_moe support."
)
backend_name = "flashinfer" if self.use_flashinfer else "trtllm"
self.op_backend: MoEOpBackend = get_op_backend(backend_name)

Expand Down Expand Up @@ -292,7 +306,43 @@ def _to_trtllm_gen_activation_type(self,
else:
raise ValueError(f"Unsupported activation type: {activation_type}")

def _check_op_backend_support(self) -> bool:
@staticmethod
def _is_flashinfer_fused_moe_available() -> bool:
try:
from flashinfer.fused_moe import core as _core
except (ImportError, ModuleNotFoundError):
return False
return (hasattr(_core, "trtllm_bf16_moe")
and hasattr(_core, "trtllm_bf16_routed_moe"))

def _is_unquantized_path(self) -> bool:
return self.quant_config is None or not self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True)

@staticmethod
def _supports_flashinfer_bf16_routing_method(
routing_method: BaseMoeRoutingMethod, ) -> bool:
# FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
Comment on lines +323 to +325
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix the hanging indent in _supports_flashinfer_bf16_routing_method.

Flake8 is already reporting E125 here, so the lint job will remain red until the continuation line is re-indented or the closing parenthesis is moved.

Minimal fix
     `@staticmethod`
     def _supports_flashinfer_bf16_routing_method(
-        routing_method: BaseMoeRoutingMethod, ) -> bool:
+        routing_method: BaseMoeRoutingMethod,
+    ) -> bool:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _supports_flashinfer_bf16_routing_method(
routing_method: BaseMoeRoutingMethod, ) -> bool:
# FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
`@staticmethod`
def _supports_flashinfer_bf16_routing_method(
routing_method: BaseMoeRoutingMethod,
) -> bool:
# FIXME: ban DeepSeekV3 FlashInfer trtllm_bf16_routed_moe() as it appears to have bug
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 324-324: continuation line with same indent as next logical line

(E125)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py` around lines
323 - 325, The function signature for _supports_flashinfer_bf16_routing_method
has a hanging indent causing an E125 lint error; fix it by reformatting the
parameter continuation so the closing parenthesis aligns with the opening or
move the closing parenthesis to the same line as the last parameter, e.g. adjust
the indentation of the line with "routing_method: BaseMoeRoutingMethod, ) ->
bool:" so it no longer creates a misaligned continuation; update the def for
_supports_flashinfer_bf16_routing_method accordingly.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this will be addressed by flashinfer-ai/flashinfer#2911.

return not isinstance(routing_method, DeepSeekV3MoeRoutingMethod)

def _requires_separated_routing(self) -> bool:
"""Whether this backend instance expects precomputed top-k routing."""
# FIXME: ban FlashInfer BF16 MoE direct routing as it appears to have accuracy bug
return self.use_flashinfer and self._is_unquantized_path()

def _check_flashinfer_backend_support(self) -> bool:
# For BF16 (unquantized) path, we will use FlashInfer regardless whether
# env TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER=1 is set or not as it's the only way.
if self._is_unquantized_path():
if not self._is_flashinfer_fused_moe_available():
return False
if self.activation_type != ActivationType.Swiglu:
return False
if not self._supports_flashinfer_bf16_routing_method(
self.routing_method):
return False
return True

use_flashinfer = os.environ.get("TRTLLM_GEN_FUSED_MOE_USE_FLASHINFER",
"0")
if use_flashinfer != "1":
Expand All @@ -311,8 +361,6 @@ def _check_op_backend_support(self) -> bool:
if type(quant_method) is NVFP4TRTLLMGenFusedMoEBaseMethod:
return True

if self.quant_config is None:
return False
mode = self.quant_config.layer_quant_mode

# These quant modes are never supported via op backend
Expand Down Expand Up @@ -365,7 +413,14 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
return AlltoallMethodType.NVLinkOneSided

def _supports_load_balancer(self) -> bool:
"""TRTLLMGenFusedMoE supports load balancer."""
"""Whether separated routing (top-k outside the kernel) is used.

ConfigurableMoE uses this flag to decide whether routing is separated
(top-k ids/scales computed outside backend) or fused inside the kernel.
BF16 FlashInfer path always requires separated routing.
"""
if self._requires_separated_routing():
return True
return self.use_dp and self.parallel_size > 1

@cached_property
Expand All @@ -375,9 +430,17 @@ def enable_alltoall(self):
return self.alltoall_method_type != AlltoallMethodType.NotEnabled

def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
assert not self.has_any_quant \
or self.has_deepseek_fp8_block_scales \
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."
or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, \
"TRTLLMGenFusedMoE only supports bf16 (FlashInfer), fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

if not self.has_any_quant:
assert self.activation_type == ActivationType.Swiglu, \
"TRTLLMGenFusedMoE BF16 path only supports Swiglu activation."
assert not self.bias and self.swiglu_alpha is None and self.swiglu_beta is None and self.swiglu_limit is None, \
"TRTLLMGenFusedMoE BF16 path does not support bias/swiglu custom parameters."

if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:
assert self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE supports bias/swiglu only for nvfp4 and mxfp4 variants."
Expand Down Expand Up @@ -405,8 +468,7 @@ def _get_quant_method(self):
f"Unsupported quantization method by TRTLLMGenFusedMoE: {self.quant_config.quant_mode}"
)
else:
raise NotImplementedError(
"TRTLLMGenFusedMoE doesn't support fp16/bf16/fp32 MoE.")
return BF16TRTLLMGenFusedMoEMethod()

def create_weights(self):
if self._weights_created:
Expand Down Expand Up @@ -467,6 +529,8 @@ def quantize_input(self, x, post_quant_comm: bool = True):
- scaling_vector_size is typically the group size for block-wise quantization
"""
x_sf = None
if not self.has_any_quant:
return x, x_sf
if self.has_w4a8_mxfp4_fp8:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
Expand Down Expand Up @@ -526,7 +590,7 @@ def quantize_input(self, x, post_quant_comm: bool = True):
return x, x_sf

def supports_moe_output_in_alltoall_workspace(self):
    """Alltoall-workspace output is supported only on quantized, non-FlashInfer paths."""
    if self.use_flashinfer:
        return False
    return self.has_any_quant

def run_moe(
self,
Expand All @@ -542,8 +606,8 @@ def run_moe(
Run MoE computation with TRTLLMGen backend.

This method encapsulates the core MoE computation logic, handling different
quantization schemes (fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8,
w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).
quantization schemes (bf16, fp8_block_scales, nvfp4, w4a16_mxfp4,
w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, w4a8_mxfp4_mxfp8).

Args:
# Standard MoE interface parameters:
Expand Down Expand Up @@ -592,7 +656,37 @@ def run_moe(
) == 2, f"x_sf should be 2D tensor, got shape {x_sf.shape}"
x_sf = x_sf.flatten()

if self.has_deepseek_fp8_block_scales:
if not self.has_any_quant:
result = self.op_backend.run_bf16_moe(
router_logits=router_logits,
routing_bias=routing_bias,
hidden_states=x,
gemm1_weights=self.w3_w1_weight,
gemm2_weights=self.w2_weight,
num_experts=self.num_slots,
top_k=top_k,
n_group=n_group,
topk_group=topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.slot_start,
local_num_experts=self.expert_size_per_partition,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=self.routing_method.routing_method_type,
topk_weights=token_final_scales,
topk_ids=token_selected_experts,
gated_act_type=self._to_trtllm_gen_activation_type(
self.activation_type),
output=moe_output,
use_shuffled_weight=getattr(self.quant_method,
"use_shuffled_weight", False),
weight_layout=getattr(self.quant_method, "weight_layout", 0),
do_finalize=do_finalize,
)
if not do_finalize:
assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
return result
final_hidden_states = result
elif self.has_deepseek_fp8_block_scales:
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
# fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+
if x_sf is None:
Expand Down
Loading
Loading