Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
bsz, seq_len, _ = hidden_states.shape

Expand All @@ -115,7 +116,8 @@ def forward(
k,
v,
scale=self.scaling,
is_causal=True,
attn_mask=attn_mask,
is_causal=(attn_mask is None),
enable_gqa=self._use_gqa,
)

Expand Down Expand Up @@ -176,10 +178,11 @@ def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, position_ids)
hidden_states = self.self_attn(hidden_states, position_ids, attn_mask)
hidden_states = residual + hidden_states

residual = hidden_states
Expand Down Expand Up @@ -226,10 +229,11 @@ def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
for layer in self.layers:
hidden_states = layer(hidden_states, position_ids)
hidden_states = layer(hidden_states, position_ids, attn_mask)
hidden_states = self.norm(hidden_states)
return hidden_states

Expand Down Expand Up @@ -348,8 +352,14 @@ def __init__(
self._proj_buf: torch.Tensor | None = None
self._pos_ids: torch.Tensor | None = None

# torch.compile: fuse small kernels in the 5-layer transformer.
self._compiled_model_fwd: object | None = None
# torch.compile state (set in _setup_compile).
self._compiled_model_fwd = None
self._bucket_sizes: list[int] = []
self._bucket_pos_ids: dict[int, torch.Tensor] = {}

# Cached module list references.
self._lm_heads_list: list[nn.Module] | None = None
self._codec_embeds_list: list[nn.Module] | None = None

def get_input_embeddings(self) -> nn.ModuleList:
return self.model.get_input_embeddings()
Expand All @@ -374,41 +384,33 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
default_weight_loader(params[name], w)
loaded.add(name)

return loaded

# ------------------------------------------------------------------
# Pre-allocated buffer management
# ------------------------------------------------------------------

def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_groups + 1
if (
self._proj_buf is not None
and self._proj_buf.shape[0] >= bsz
and self._proj_buf.device == device
and self._proj_buf.dtype == dtype
):
if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype:
return
# Allocate for max batch size so all bucket sizes fit without realloc.
max_bsz = self._vllm_config.scheduler_config.max_num_seqs
self._proj_buf = torch.zeros(
bsz,
max_bsz,
max_seq,
self._cp_hidden,
dtype=dtype,
device=device,
)
self._pos_ids = torch.arange(
max_seq,
dtype=torch.long,
device=device,
)

def _setup_compile(self) -> None:
"""Lazily set up torch.compiled model forward for kernel fusion.
"""Lazily set up torch.compiled model forward with CUDA graph capture.

Uses ``mode="default"`` so Inductor performs operator fusion without
capturing its own CUDA graphs. This avoids conflicts with vLLM's
``CUDAGraphWrapper`` which manages CUDA graphs for the main Talker
model on the default stream.
Uses ``mode="reduce-overhead"`` with ``dynamic=False`` so Inductor
captures internal CUDA graphs for fixed shapes, eliminating kernel
launch overhead entirely.
"""
if self._compiled_model_fwd is not None:
return
Expand All @@ -418,10 +420,50 @@ def _setup_compile(self) -> None:
return
self._compiled_model_fwd = torch.compile(
self.model.forward,
mode="default",
dynamic=True,
mode="reduce-overhead",
dynamic=False,
)
logger.info("code_predictor: torch.compile enabled (mode=default)")
self._lm_heads_list = list(self.lm_head)
self._codec_embeds_list = list(self.model.codec_embedding)
logger.info("code_predictor: torch.compile enabled (mode=reduce-overhead, dynamic=False)")
# Warmup: trigger Inductor compilation for common batch sizes.
self._warmup_compile()

def _padded_bsz(self, bsz: int) -> int:
"""Map actual batch size to the next pre-warmed bucket size."""
for bucket in self._bucket_sizes:
if bsz <= bucket:
return bucket
# Fallback: if bsz exceeds all buckets (shouldn't happen), use as-is.
return bsz

def _warmup_compile(self) -> None:
"""Pre-trigger Inductor compilation for bucket batch sizes.

Uses power-of-2 buckets up to max_num_seqs. Pre-caches position_ids
per bucket to avoid allocation in the hot loop.
"""
max_bsz = self._vllm_config.scheduler_config.max_num_seqs
bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
if max_bsz not in bucket_sizes:
bucket_sizes.append(max_bsz)
self._bucket_sizes = sorted(bucket_sizes)

max_seq = self._num_groups + 1
device = next(self.model.parameters()).device
base_pos = torch.arange(max_seq, device=device, dtype=torch.long)

# Use slices of the pre-allocated proj_buf so Dynamo guards
# (storage().nbytes(), storage_offset) match runtime exactly.
proj_buf = self._proj_buf
for bsz in self._bucket_sizes:
pos_ids = base_pos if bsz == 1 else base_pos.repeat(bsz)
self._bucket_pos_ids[bsz] = pos_ids
# 3 iterations: reduce-overhead needs multiple passes to
# compile, capture, and stabilize its internal CUDA graphs.
for _ in range(3):
self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
logger.info("code_predictor: warmup done for bucket sizes %s", self._bucket_sizes)

# ------------------------------------------------------------------
# Optimized forward: re-prefill + torch.compile + projection cache
Expand Down Expand Up @@ -454,16 +496,16 @@ def forward(
all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
all_codes[:, 0] = layer0_code.reshape(bsz)

self._ensure_buffers(bsz, device, dtype)
self._ensure_buffers(device, dtype)
self._setup_compile()

proj_buf = self._proj_buf
pos_ids = self._pos_ids
max_seq = self._num_groups + 1

projection = self.small_to_mtp_projection
model_fwd = self._compiled_model_fwd
lm_heads = list(self.lm_head)
codec_embeds = list(self.model.codec_embedding)
lm_heads = self._lm_heads_list
codec_embeds = self._codec_embeds_list
Comment on lines +488 to +489

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Initialize cached heads in no-compile fallback

When supports_torch_inductor() is false, _setup_compile() now returns early after setting only _compiled_model_fwd, but forward() unconditionally reads _lm_heads_list and _codec_embeds_list from these cached fields. In that environment (e.g., CPU or unsupported GPU), those fields stay None, so the first decode step fails with a NoneType subscript error instead of using the previous working path. Please populate these caches in the fallback branch (or avoid relying on them when compile is disabled).

Useful? React with 👍 / 👎.


proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1)
proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
Expand All @@ -475,15 +517,21 @@ def forward(
"top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
)

for step in range(1, num_groups):
seq_len = step + 1
# Pad batch to next pre-warmed bucket
padded_bsz = self._padded_bsz(bsz)
full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
if full_pos_ids is None:
# Fallback for unexpected batch size (should not happen in practice).
base_pos = torch.arange(max_seq, device=device, dtype=torch.long)
full_pos_ids = base_pos if padded_bsz == 1 else base_pos.repeat(padded_bsz)

projected = proj_buf[:bsz, :seq_len, :]
step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz)
for step in range(1, num_groups):
projected = proj_buf[:padded_bsz, :max_seq, :]

hidden_out = model_fwd(projected, step_pos_ids)
hidden_out = model_fwd(projected, full_pos_ids)

logits = lm_heads[step - 1](hidden_out[:, -1, :])
# Slice back to actual batch size — padding rows are discarded.
logits = lm_heads[step - 1](hidden_out[:bsz, step, :])

if use_sampling:
scaled = logits * inv_temperature
Expand Down
Loading