Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
c488653
fully working mixed requests version
shaharmor98 Dec 21, 2025
d7ee37c
add working real checkpoint code
shaharmor98 Dec 21, 2025
02c1b02
CUDA graphs work
roikoren755 Nov 26, 2025
c84fec2
Prefix caching works
roikoren755 Nov 27, 2025
8250af1
Fix triton kernel to support varlen, and update call site
roikoren755 Nov 27, 2025
a807460
Full CUDA graphs
roikoren755 Nov 30, 2025
c62c4ff
Prefix caching + refactored code
shaharmor98 Dec 21, 2025
6c08722
added updated load weights
shaharmor98 Dec 21, 2025
3d6a23b
Fix CUDA graphs compat
roikoren755 Dec 9, 2025
4db7e20
Working speculative code with multiple MTP layers
shaharmor98 Dec 21, 2025
d27da73
running mtp for num_speculative > 2
shaharmor98 Dec 21, 2025
2c4f1c7
remove redundant ids handling, add eagle3 support flag
shaharmor98 Dec 21, 2025
5aabfd5
temporarily disable update_block_table
shaharmor98 Dec 21, 2025
f3918a1
remove redundant code, support cuda graph for target excluding drafter
shaharmor98 Dec 23, 2025
c35cf11
move faulty assertion
shaharmor98 Dec 23, 2025
965a716
remove full cg support for mamba
shaharmor98 Dec 30, 2025
43728d8
commenting eagle changes
shaharmor98 Dec 30, 2025
a72852c
fix block size used in EAGLE slot mapping
benchislett Dec 30, 2025
d0f85ad
remove eagle multi layer support
shaharmor98 Jan 1, 2026
4e3ff41
remove multi layer spec step idx from eagle
shaharmor98 Jan 1, 2026
fe2ff13
update code to a single MTP layer, remove speculative enforce eager, …
shaharmor98 Jan 1, 2026
fd28da9
change mtp layer count
shaharmor98 Jan 4, 2026
45fbd1b
final cleanup
shaharmor98 Jan 8, 2026
3b211ea
polished implementation
benchislett Jan 30, 2026
5a6131d
tweaks for perf
benchislett Jan 30, 2026
f213300
tweak
benchislett Jan 30, 2026
f461834
tweaks for non-spec and spec compat
benchislett Jan 31, 2026
6a27a0f
Merge branch 'main' into bchislett/nemotron-h-mtp-old-rebased
benchislett Jan 31, 2026
f56420e
simple patches for rebase
benchislett Jan 31, 2026
cc9b29f
patch
benchislett Feb 2, 2026
896e01f
remove unused files
benchislett Feb 2, 2026
b96184d
update mamba backend with specdec support
benchislett Feb 3, 2026
6d59c03
refactor mamba attn cudagraph logic
benchislett Feb 3, 2026
8819eb7
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 3, 2026
b75e7ac
cleanup
benchislett Feb 3, 2026
bca337b
cleanup
benchislett Feb 3, 2026
7da121d
revert unneeded refactor
benchislett Feb 3, 2026
5db7692
rename prefill_state_indices_tensor and decode_state_indices_tensor
benchislett Feb 6, 2026
b1e927e
rewrite todo comment for known issue
benchislett Feb 6, 2026
03ef64d
change max_query_len overwrite to assert
benchislett Feb 6, 2026
8b7c03f
update cuda graph padding boundaries
benchislett Feb 6, 2026
c67fccb
avoid slicing state indices tensor for 'all' prefix cache mode
benchislett Feb 6, 2026
0739591
remove validation mode leftover from debugging
benchislett Feb 6, 2026
0a5f1a3
revert no-longer-needed diff in layer.py
benchislett Feb 6, 2026
063f54b
update layer.py from main to fix conflict
benchislett Feb 6, 2026
d66babd
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 6, 2026
f7fef88
adjust other mamba-base backends to the new state_indices_tensor repr…
shaharmor98 Feb 8, 2026
e71fb7a
merge main
shaharmor98 Feb 11, 2026
b642f2e
Merge remote-tracking branch 'origin/main' into bchislett/mamba-nemot…
shaharmor98 Feb 15, 2026
a02ac98
[Bugfix] Fix assertion error in _dummy_run for MTP speculative decoding
LucasWilkinson Feb 12, 2026
ce598de
Merge remote-tracking branch 'origin/main' into bchislett/mamba-nemot…
shaharmor98 Feb 16, 2026
79b6e40
fix wrong max num tokens init in GPUModelRunner
shaharmor98 Feb 16, 2026
4a9274f
revert assertion relaxation
shaharmor98 Feb 16, 2026
77d5cc1
fix cr comments
shaharmor98 Feb 17, 2026
dc29214
Merge commit 'dc5fa77a4eb6680339cb77abe713fb22d7795560' into bchislet…
benchislett Feb 19, 2026
6b50bd0
fix comment
benchislett Feb 19, 2026
1c45407
simplify prefix caching slicing
benchislett Feb 19, 2026
05efc54
adjust decode threshold based on specdec existence in the request
shaharmor98 Feb 19, 2026
10cb13e
fix frivolous assert
benchislett Feb 20, 2026
516c7b5
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 20, 2026
4797d1d
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 23, 2026
f77866f
use new vllmconfig getter for num_spec
benchislett Feb 23, 2026
810caf1
use new vllmconfig getter for num_spec
benchislett Feb 23, 2026
d511f19
remove no-longer-necessary bug workaround
benchislett Feb 23, 2026
623c1f7
add placeholder model for NemotronH MTP
benchislett Feb 23, 2026
1a688ba
fix prefix caching for non-MTP case
benchislett Feb 23, 2026
685fe39
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 23, 2026
1a28f23
Merge branch 'main' into bchislett/mamba-nemotron-mtp
benchislett Feb 24, 2026
cb56163
update mamba block table test config
benchislett Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,6 +1200,11 @@ def check_available_online(
},
is_available_online=False,
),
"NemotronHMTPModel": _HfExamplesInfo(
"nvidia/Nemotron-Super-Placeholder",
speculative_model="nvidia/Nemotron-Super-Placeholder",
is_available_online=False,
),
}

_TRANSFORMERS_BACKEND_MODELS = {
Expand Down
8 changes: 7 additions & 1 deletion tests/v1/attention/test_mamba_update_block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def _make_vllm_config(block_size, max_model_len, max_num_seqs):
cudagraph_mode=CUDAGraphMode.FULL,
max_cudagraph_capture_size=None,
),
speculative_config=None,
num_speculative_tokens=0,
parallel_config=SimpleNamespace(decode_context_parallel_size=1),
scheduler_config=SimpleNamespace(max_num_seqs=max_num_seqs),
model_config=SimpleNamespace(max_model_len=max_model_len),
)
Expand Down Expand Up @@ -92,7 +95,10 @@ def test_update_block_table_copies_block_idx_to_persistent_buffers():
has_initial_states_p=None,
query_start_loc_p=None,
num_computed_tokens_p=None,
state_indices_tensor=builder_a.state_indices_tensor[:num_reqs],
state_indices_tensor_p=None,
query_start_loc_d=None,
num_accepted_tokens=None,
state_indices_tensor_d=builder_a.state_indices_tensor_d[:num_reqs],
block_idx_last_scheduled_token=(
builder_a.block_idx_last_scheduled_token[:num_reqs]
),
Expand Down
19 changes: 17 additions & 2 deletions vllm/config/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"glm4_moe_lite_mtp",
"glm_ocr_mtp",
"ernie_mtp",
"nemotron_h_mtp",
"exaone_moe_mtp",
"qwen3_next_mtp",
"qwen3_5_mtp",
Expand Down Expand Up @@ -255,6 +256,19 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
)

if (
hf_config.model_type == "nemotron_h"
and hasattr(hf_config, "num_nextn_predict_layers")
and hf_config.num_nextn_predict_layers > 0
):
# Check if this is an MTP variant
hf_config.model_type = "nemotron_h_mtp"
if hf_config.model_type == "nemotron_h_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update(
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
)

if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
Expand Down Expand Up @@ -325,7 +339,7 @@ def __post_init__(self):
if self.target_model_config is None:
raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed.
self.enforce_eager = True
# use the draft model from the same model:
Expand Down Expand Up @@ -427,7 +441,7 @@ def __post_init__(self):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"Enabling num_speculative_tokens > 1 will run"
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
",which may result in lower acceptance rate"
)
Expand Down Expand Up @@ -712,6 +726,7 @@ def _verify_args(self) -> Self:
"hunyuan_vl",
"hunyuan_v1_dense",
"afmoe",
"nemotron_h",
]
if (
self.method == "eagle3"
Expand Down
9 changes: 9 additions & 0 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,15 @@ def compute_hash(self) -> str:
]
return hash_str

@property
def num_speculative_tokens(self) -> int:
    """Number of speculative tokens configured, or 0 when speculative
    decoding is disabled (no speculative_config, or no token count set)."""
    spec_config = self.speculative_config
    if spec_config is None or spec_config.num_speculative_tokens is None:
        return 0
    return spec_config.num_speculative_tokens

@property
def needs_dp_coordinator(self) -> bool:
"""
Expand Down
8 changes: 0 additions & 8 deletions vllm/model_executor/layers/mamba/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ def get_state_dtype(self) -> tuple[torch.dtype, ...]:
pass

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
if (
vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type
not in ["qwen3_next", "qwen3_5", "qwen3_5_moe"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = vllm_config.cache_config.mamba_block_size
page_size_padded = vllm_config.cache_config.mamba_page_size_padded
return MambaSpec(
Expand Down
20 changes: 2 additions & 18 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor = attn_metadata.state_indices_tensor
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
Expand Down Expand Up @@ -295,17 +296,13 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
gate,
state_indices_tensor,
num_prefill_tokens,
num_prefills,
num_decode_tokens,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
Expand Down Expand Up @@ -477,16 +474,12 @@ class PrefillDecodeSplit(NamedTuple):
hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor


def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int,
num_prefills: int,
num_decode_tokens: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_decode_tokens
Expand All @@ -501,20 +494,11 @@ def split_batch_to_prefill_and_decode(
gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
)

# num_decode_tokens accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[: num_decode_tokens + num_prefills],
[num_decode_tokens, num_prefills],
dim=0,
)

return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p,
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
)


Expand Down
46 changes: 27 additions & 19 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,8 @@ def __init__(
dim=-1,
)

compilation_config = get_current_vllm_config().compilation_config
vllm_config = get_current_vllm_config()
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
Expand All @@ -488,6 +489,8 @@ def __init__(
self.cache_config = cache_config
self.prefix = prefix

self.num_spec = vllm_config.num_speculative_tokens

# Pre-compute sizes for forward pass
self.tped_intermediate_size = self.intermediate_size // self.tp_size
self.tped_conv_size = self.conv_dim // self.tp_size
Expand Down Expand Up @@ -576,14 +579,19 @@ def conv_ssm_forward(
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
query_start_loc_p = attn_metadata.query_start_loc_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
num_accepted_tokens = attn_metadata.num_accepted_tokens
query_start_loc_d = attn_metadata.query_start_loc_d
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens

if attn_metadata is None:
# profile run
Expand All @@ -593,29 +601,21 @@ def conv_ssm_forward(
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
return hidden_states

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
num_actual_tokens = num_prefill_tokens + num_decode_tokens

# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_actual_tokens],
[num_decodes, num_prefills],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)

Expand All @@ -642,16 +642,16 @@ def conv_ssm_forward(
)
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_computed_token_d = None
num_computed_tokens_p = None

preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
output[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
[num_decode_tokens, num_prefill_tokens],
dim=0,
)

Expand Down Expand Up @@ -709,6 +709,7 @@ def conv_ssm_forward(
)

# NOTE: final output is an in-place update of out tensor
assert preallocated_ssm_out_p is not None
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
Expand Down Expand Up @@ -840,6 +841,9 @@ def conv_ssm_forward(
conv_state_indices=state_indices_tensor_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
num_accepted_tokens=num_accepted_tokens,
query_start_loc=query_start_loc_d,
max_query_len=state_indices_tensor_d.size(-1),
)

hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
Expand All @@ -862,6 +866,7 @@ def conv_ssm_forward(
-1, self.num_heads // self.tp_size, self.head_dim
)

assert preallocated_ssm_out_d is not None
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
Expand All @@ -879,7 +884,9 @@ def conv_ssm_forward(
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
out=preallocated_ssm_out_d.view(num_decode_tokens, -1, self.head_dim),
num_accepted_tokens=num_accepted_tokens,
cu_seqlens=query_start_loc_d,
is_blackwell=self.is_blackwell,
)

Expand All @@ -901,6 +908,7 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
head_dim=self.head_dim,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
num_spec=self.num_spec,
)

@property
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/mamba/mamba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def mamba2_state_shape(
head_dim: int,
state_size: int,
conv_kernel: int,
num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
Expand All @@ -141,7 +142,7 @@ def mamba2_state_shape(
conv_dim = intermediate_size + 2 * n_groups * state_size

# contiguous along 'dim' axis
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
conv_state_shape = (conv_kernel - 1 + num_spec, divide(conv_dim, tp_world_size))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shaharmor98 Could you explain the motivation for this change? I'm not sure of its purpose

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can recall, in GDN’s original PR they modified the causal_conv1d_update so that the underlying Triton kernel caches the state for every speculated token, and then during verification restores the correct cache according to the accepted token count.
One of the things they had to change, and so did we, was to account for the num_spec tokens as part of the conv_state shape declaration.
I made this change quite a long time ago, but IIRC it caused the outputs to be garbage when that value wasn’t added.
Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py#L185-L188


# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,9 @@ def causal_conv1d_update(
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch,) == conv_state_indices.shape
assert batch == conv_state_indices.shape[0], (
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
)

assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
Expand Down
10 changes: 2 additions & 8 deletions vllm/model_executor/layers/mamba/short_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def forward_cuda(
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor = attn_metadata.state_indices_tensor
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
has_initial_states_p = attn_metadata.has_initial_states_p
query_start_loc_p = attn_metadata.query_start_loc_p

Expand Down Expand Up @@ -163,13 +164,6 @@ def forward_cuda(
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)

conv_output_list = []

if has_prefill:
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/models/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def get_mamba_state_shape_from_config(
head_dim=hf_config.head_dim,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
num_spec=vllm_config.num_speculative_tokens,
)

@classmethod
Expand Down
Loading