Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
109 commits
Select commit Hold shift + click to select a range
ca114bd
a2a
Jan 19, 2026
d51a1a6
base device communicator
Jan 19, 2026
4a37fbb
rest of the device communicators
Jan 19, 2026
6ab0934
add moe prepare and finalize naive ep
Jan 19, 2026
5de3d38
get naive dispatch/combine into a good spot
Jan 19, 2026
bc30ec4
update maybe_make_prepare_finalize and select_gemm_impl
Jan 19, 2026
1666529
update make_nvfp4 oracle
Jan 19, 2026
8f9e6cd
make the prepare finalize in the oracle
Jan 19, 2026
6121050
remove some fi functions
Jan 19, 2026
2d16269
things are 'working' end to end
Jan 19, 2026
8e60d88
fix topk indices dtype
Jan 19, 2026
c3ee917
Support nccl fp8 communication (#32760)
amirkl94 Jan 21, 2026
be84e3b
merge
Jan 21, 2026
71aa335
Merge branch 'naive-pf-separation' of https://github.com/vllm-project…
Jan 21, 2026
21f3c10
pre-comimt remediation
Jan 21, 2026
6ac652a
pre-comimt remediation
Jan 21, 2026
fece963
fun with pre-commit
Jan 21, 2026
8d0cc52
remove FlashInferAG
Jan 21, 2026
4fc2917
updated
Jan 21, 2026
60a15b0
small nit
Jan 21, 2026
0a84042
move mkm
Jan 21, 2026
c5f4734
rename file
Jan 21, 2026
5010925
support flashinfer cutlass fp8 with DeepEPHT
Jan 21, 2026
1d09143
make deepepht prepare-finalize work with cutlass fp8 block
Jan 21, 2026
62efd36
stash changes to make nvfp4 unfused
Jan 21, 2026
92fd9b0
do the fused swizzling
Jan 21, 2026
2b406e3
stash
Jan 21, 2026
01f004e
updated to remove the explicit gather call
Jan 21, 2026
a3462dd
make mnnvl work properly
Jan 21, 2026
70b7909
make mnnvl work properly
Jan 21, 2026
9f9557e
remove bad import
Jan 21, 2026
859fd35
add compatibility for flashinfer mnnvl
Jan 21, 2026
15c0112
add todo
Jan 21, 2026
530b463
merge main
Jan 23, 2026
ddc5b2a
pre-commit
Jan 23, 2026
03fe6ec
update type signatures
Jan 23, 2026
10b4922
update type signatures
Jan 23, 2026
8e6783f
make pre-commit happy
Jan 24, 2026
ceb5ec0
get tests passing
Jan 24, 2026
e726362
self.kernel -> self.mk_moe; move shared items up to the base class
Jan 24, 2026
dbd27e6
self.kernel -> self.mk_moe; move shared items up to the base class
Jan 24, 2026
a29abcc
nit
Jan 24, 2026
dfcadb8
updated
Jan 24, 2026
f2531a7
reduce LOC change
Jan 24, 2026
be22a24
update comments
Jan 24, 2026
c466250
revert removal of extra tesnors from dispatch/combine
Jan 24, 2026
ba6903f
reduce loc change
Jan 24, 2026
1e24906
remove extratensors for naive A2a
Jan 24, 2026
45bc73e
moe_mk -> mk_moe
Jan 24, 2026
66c9388
moe_mk -> mk_moe
Jan 24, 2026
14be251
fix flashinfer 0.5.3 -> 0.6.1 update
Jan 24, 2026
e76eb02
remove support for defer input quant to batched P/Fes
Jan 24, 2026
a5b25d9
Merge remote-tracking branch 'origin/main' into naive-pf-separation
Jan 24, 2026
19f8470
fixed bug with ag/rs and nvfp4 quantization related the the swizzling
Jan 24, 2026
ad2a758
updated
Jan 24, 2026
b456504
fix up cutlass + pplx
Jan 24, 2026
d9438b5
fix up pre-commit
Jan 24, 2026
c4d86e0
fix up pre-commit
Jan 24, 2026
1d95eab
make pre-commit support other routing strategy
Jan 24, 2026
dd71aee
remove debug cruft
Jan 24, 2026
f5d44bc
reduce LOC change
Jan 24, 2026
038af58
reduce LOC changed
Jan 24, 2026
226ad44
reduce LOC changed
Jan 24, 2026
24ce77b
fix handling of monolithic kernels
Jan 25, 2026
b66c07c
Merge remote-tracking branch 'origin/main' into naive-pf-separation
Jan 25, 2026
732abd5
fix AG/RS dtype
Jan 25, 2026
a5cf835
fix incorrect arg in test
Jan 25, 2026
4a5fce9
fix missing args
Jan 25, 2026
7cbeae8
fix mk
Jan 25, 2026
af0c268
fix pre-commit
Jan 25, 2026
b5474b7
fix docs build
Jan 25, 2026
9105ead
confirmed shared experts working
Jan 25, 2026
0da54f1
update, remove to cuda
Jan 25, 2026
ef152ad
verify
Jan 25, 2026
b366ad4
fix DP/TP deployment
Jan 25, 2026
135cc53
update defer input quant
Jan 26, 2026
a9faf58
merge main
Jan 26, 2026
609298d
remove empty file
Jan 26, 2026
53cfc42
reduce LOC changed
Jan 26, 2026
08c442a
reduce LOC changed
Jan 26, 2026
2ca7950
reduce LOC changed
Jan 26, 2026
f184c48
reduce LOC changed
Jan 26, 2026
881f436
reduce LOC changed
Jan 26, 2026
9e38aca
reduce LOC changed
Jan 26, 2026
98220b0
reduce LOC changed
Jan 26, 2026
7bd2a44
reduce LOC changed
Jan 26, 2026
b6d5107
reduce LOC changed
Jan 26, 2026
678c34a
reduce LOC changed
Jan 26, 2026
f9b5b92
reduce LOC changed
Jan 26, 2026
e2b9969
reduce LOC changed
Jan 26, 2026
3e3f17c
reduce LOC changed
Jan 26, 2026
4f058f4
reduce LOC changed
Jan 26, 2026
41c1264
reduce LOC changed
Jan 26, 2026
0ad06db
reduce LOC changed
Jan 26, 2026
bf4d327
reduce LOC changed
Jan 26, 2026
578c183
reduce LOC changed
Jan 26, 2026
bf0acb1
reduce LOC changed
Jan 26, 2026
fe8fb7f
reduce LOC changed
Jan 26, 2026
8f46f93
reduce LOC changed
Jan 26, 2026
52124ea
reduce LOC changed
Jan 26, 2026
59bab2d
reduce LOC changed
Jan 26, 2026
ab22193
update comment
Jan 26, 2026
eeca003
Merge branch 'main' into naive-pf-separation
robertgshaw2-redhat Jan 26, 2026
9217c64
Merge branch 'main' into naive-pf-separation
robertgshaw2-redhat Jan 26, 2026
ed6d5fd
Merge branch 'main' into naive-pf-separation
robertgshaw2-redhat Jan 26, 2026
7e57390
fix docstring
Jan 26, 2026
0e79444
add comemnta bout commBackned
Jan 26, 2026
ec4edbd
add comemnta bout commBackned
Jan 26, 2026
acd7150
Merge branch 'main' into naive-pf-separation
robertgshaw2-redhat Jan 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/test-amd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
Expand Down
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
Expand Down
2 changes: 1 addition & 1 deletion .buildkite/test_areas/kernels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ steps:
- csrc/quantization/cutlass_w8a8/moe/
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/v1/attention/backends/mla/cutlass_mla.py
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_cutlass_moe_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def run_cutlass_moe_fp4(
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
Expand Down Expand Up @@ -242,7 +242,7 @@ def run_cutlass_from_graph(
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
make_dummy_moe_config(),
quant_config=quant_config,
Expand Down
3 changes: 1 addition & 2 deletions docs/design/moe_kernel_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ th {
| pplx | batched | fp8,int8 | G,A,T | Y | Y | [`PplxPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.pplx_prepare_finalize.PplxPrepareAndFinalize] |
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize.DeepEPLLPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize.DeepEPHTPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferAllToAllMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferAllToAllMoEPrepareAndFinalize] |
| flashinfer<sup>4</sup> | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferCutlassMoEPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize.FlashInferCutlassMoEPrepareAndFinalize] |
| flashinfer_all2allv | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferA2APrepareAndFinalize`][vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize.FlashInferA2APrepareAndFinalize] |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be an entry for MoEPrepareAndFinalizeNaiveEP?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

| MoEPrepareAndFinalizeNoEP<sup>5</sup> | standard | fp8,int8 | G,A,T | N | Y | [`MoEPrepareAndFinalizeNoEP`][vllm.model_executor.layers.fused_moe.prepare_finalize.MoEPrepareAndFinalizeNoEP] |
| BatchedPrepareAndFinalize<sup>5</sup> | batched | fp8,int8 | G,A,T | N | Y | [`BatchedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedPrepareAndFinalize] |

Expand Down
12 changes: 8 additions & 4 deletions tests/kernels/moe/modular_kernel_tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
Expand All @@ -40,7 +43,6 @@
TestMoEQuantConfig,
expert_info,
make_fused_experts,
make_prepare_finalize,
prepare_finalize_info,
)
from .parallel_utils import ProcessGroupInfo
Expand Down Expand Up @@ -603,10 +605,12 @@ def next_power_of_2(x):
routing_method=RoutingMethodType.DeepSeekV3,
)

# make modular kernel
prepare_finalize = make_prepare_finalize(
config.prepare_finalize_type, config.all2all_backend(), moe, quant_config
prepare_finalize = maybe_make_prepare_finalize(
moe=moe,
quant_config=quant_config,
allow_new_interface=True,
)
assert prepare_finalize is not None

fused_experts = make_fused_experts(
config.fused_experts_type,
Expand Down
28 changes: 3 additions & 25 deletions tests/kernels/moe/modular_kernel_tools/mk_objects.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test/registration for MoEPrepareAndFinalizeNaiveEP?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
Expand Down Expand Up @@ -255,13 +252,12 @@ def expert_info(kind) -> ExpertInfo:
)

if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize,
)

register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
Expand Down Expand Up @@ -429,24 +425,6 @@ def expert_info(kind) -> ExpertInfo:
]


def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: str | None,
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = maybe_make_prepare_finalize(moe, quant_config)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1
)
else:
return MoEPrepareAndFinalizeNoEP()


def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
Expand Down
7 changes: 1 addition & 6 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,7 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
Expand Down
7 changes: 1 addition & 6 deletions tests/kernels/moe/test_flashinfer_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,7 @@ def test_flashinfer_fp4_moe_no_graph(
)

flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
defer_input_quant=FlashInferExperts.expects_unquantized_inputs(
moe_config=moe_config,
quant_config=quant_config,
)
),
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_nvfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_cutlass_fp4_moe_no_graph(
)

kernel = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
moe_config=make_dummy_moe_config(),
quant_config=quant_config,
Expand Down
102 changes: 98 additions & 4 deletions vllm/distributed/device_communicators/all2all.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def naive_multicast(

return buffer

def dispatch(
def dispatch_router_logits(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going to deprecate dispatch_router_logits eventually or keep it around for different cases?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will still need it for monolithic kernels (trtllm)

self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -84,6 +84,34 @@ def dispatch(

return hidden_states, router_logits

def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

hidden_states = self.naive_multicast(
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_weights = self.naive_multicast(
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
)
topk_ids = self.naive_multicast(
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, topk_weights, topk_ids

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
Expand Down Expand Up @@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)

def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand Down Expand Up @@ -148,6 +176,46 @@ def dispatch(
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]

def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]

tensors_to_gather = [hidden_states, topk_weights, topk_ids]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)

gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)

hidden_states = gathered_tensors[0]
topk_weights = gathered_tensors[1]
topk_ids = gathered_tensors[2]

if extra_tensors is None:
return hidden_states, topk_weights, topk_ids

return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
Expand Down Expand Up @@ -216,7 +284,7 @@ def get_handle(self, kwargs):
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)

def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -225,6 +293,19 @@ def dispatch(
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
Expand Down Expand Up @@ -264,7 +345,7 @@ def __init__(self, cpu_group):
def get_handle(self, kwargs):
raise NotImplementedError

def dispatch(
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
Expand All @@ -273,6 +354,19 @@ def dispatch(
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError

def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
Expand Down
Loading