Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,13 @@ def __init__(

# Pre-allocated buffers (lazily initialized on first forward).
self._proj_buf: torch.Tensor | None = None
self._pos_ids: torch.Tensor | None = None

# torch.compile: fuse small kernels in the 5-layer transformer.
self._compiled_model_fwd: object | None = None
# torch.compile + warmup state (lazily initialized in _setup_compile).
self._compiled_model_fwd = None
self._bucket_sizes: list[int] = []
self._bucket_pos_ids: dict[int, torch.Tensor] = {}
self._lm_heads_list: list[nn.Module] | None = None
self._codec_embeds_list: list[nn.Module] | None = None

def get_input_embeddings(self) -> nn.ModuleList:
    """Return the input embeddings reported by the wrapped ``self.model``."""
    inner_model = self.model
    return inner_model.get_input_embeddings()
Expand All @@ -374,54 +377,74 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
continue
default_weight_loader(params[name], w)
loaded.add(name)

return loaded

# ------------------------------------------------------------------
# Pre-allocated buffer management
# ------------------------------------------------------------------

# Allocate the projection buffer (`_proj_buf`) and the position-id tensor
# (`_pos_ids`); the guard below returns early so an existing buffer is reused
# while its device and dtype still match the request.
# NOTE(review): this span is a rendered diff, so pre- and post-change lines both
# appear (two `def` lines, two guard conditions, `bsz,` next to `max_bsz,`).
# Confirm which side is current before relying on the batch dimension.
def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None:
# Sequence capacity: one slot more than the number of code groups.
max_seq = self._num_groups + 1
# Variant that also requires the existing buffer to have at least `bsz` rows.
if (
self._proj_buf is not None
and self._proj_buf.shape[0] >= bsz
and self._proj_buf.device == device
and self._proj_buf.dtype == dtype
):
# Variant that only compares device and dtype.
if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype:
return
# Leading dimension candidate read from the scheduler configuration.
max_bsz = self._vllm_config.scheduler_config.max_num_seqs
# Zero-filled buffer; trailing dimensions are (max_seq, self._cp_hidden).
self._proj_buf = torch.zeros(
bsz,
max_bsz,
max_seq,
self._cp_hidden,
dtype=dtype,
device=device,
)
# Position ids 0..max_seq-1 as int64 on the same device as the buffer.
self._pos_ids = torch.arange(
max_seq,
dtype=torch.long,
device=device,
)

# One-time setup of `_compiled_model_fwd`, the callable used for the inner model
# forward. Later calls return immediately once it has been set.
# NOTE(review): this span is a rendered diff, so the docstring, the
# `torch.compile` keyword arguments and the info log each show both the pre- and
# post-change text. Confirm which side is current.
def _setup_compile(self) -> None:
"""Lazily set up torch.compiled model forward for kernel fusion.

Uses ``mode="default"`` so Inductor performs operator fusion without
capturing its own CUDA graphs. This avoids conflicts with vLLM's
``CUDAGraphWrapper`` which manages CUDA graphs for the main Talker
model on the default stream.
Uses ``mode="reduce-overhead"`` with ``dynamic=False`` so Inductor
captures internal CUDA graphs for fixed shapes, eliminating kernel
launch overhead entirely.
"""
# Already initialised on an earlier call.
if self._compiled_model_fwd is not None:
return
# Snapshot the head and codec-embedding modules as plain lists. This happens
# before the platform check, so the non-compiled fallback below has them too.
self._lm_heads_list = list(self.lm_head)
self._codec_embeds_list = list(self.model.codec_embedding)
# Without Inductor support, use the model's own forward uncompiled.
if not current_omni_platform.supports_torch_inductor():
logger.warning_once("code_predictor: torch.compile disabled")
self._compiled_model_fwd = self.model.forward
return
self._compiled_model_fwd = torch.compile(
self.model.forward,
# NOTE(review): two mode/dynamic pairs are shown; the second log line below
# reports the reduce-overhead / dynamic=False pair — confirm.
mode="default",
dynamic=True,
mode="reduce-overhead",
dynamic=False,
)
logger.info("code_predictor: torch.compile enabled (mode=default)")
logger.info("code_predictor: torch.compile enabled (mode=reduce-overhead, dynamic=False)")
# Run the compiled forward once per batch-size bucket right away.
self._warmup_compile()

def _padded_bsz(self, bsz: int) -> int:
for bucket in self._bucket_sizes:
if bsz <= bucket:
return bucket
return bsz

def _warmup_compile(self) -> None:
"""Warmup power-of-2 batch-size buckets to front-load compilation."""
max_bsz = self._vllm_config.scheduler_config.max_num_seqs
bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
if max_bsz not in bucket_sizes:
bucket_sizes.append(max_bsz)
self._bucket_sizes = sorted(bucket_sizes)

max_seq = self._num_groups + 1
device = next(self.model.parameters()).device
base_pos = torch.arange(max_seq, device=device, dtype=torch.long)

proj_buf = self._proj_buf
for bsz in self._bucket_sizes:
pos_ids = base_pos if bsz == 1 else base_pos.repeat(bsz)
self._bucket_pos_ids[bsz] = pos_ids
for _ in range(3):
self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
logger.info("code_predictor: warmup done for bucket sizes %s", self._bucket_sizes)

# ------------------------------------------------------------------
# Optimized forward: re-prefill + torch.compile + projection cache
Expand Down Expand Up @@ -454,19 +477,16 @@ def forward(
all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
all_codes[:, 0] = layer0_code.reshape(bsz)

self._ensure_buffers(bsz, device, dtype)
self._ensure_buffers(device, dtype)
self._setup_compile()

proj_buf = self._proj_buf
pos_ids = self._pos_ids
max_seq = self._num_groups + 1

projection = self.small_to_mtp_projection
model_fwd = self._compiled_model_fwd
lm_heads = list(self.lm_head)
codec_embeds = list(self.model.codec_embedding)

proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1)
proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
lm_heads = self._lm_heads_list
codec_embeds = self._codec_embeds_list
Comment on lines +488 to +489
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Initialize cached heads in no-compile fallback

When supports_torch_inductor() is false, _setup_compile() now returns early after setting only _compiled_model_fwd, but forward() unconditionally reads _lm_heads_list and _codec_embeds_list from these cached fields. In that environment (e.g., CPU or unsupported GPU), those fields stay None, so the first decode step fails with a NoneType subscript error instead of using the previous working path. Please populate these caches in the fallback branch (or avoid relying on them when compile is disabled).

Useful? React with 👍 / 👎.


use_sampling = do_sample and temperature > 0
inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0
Expand All @@ -475,15 +495,21 @@ def forward(
"top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
)

for step in range(1, num_groups):
seq_len = step + 1
padded_bsz = self._padded_bsz(bsz)
proj_buf[:padded_bsz].zero_()

projected = proj_buf[:bsz, :seq_len, :]
step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz)
proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1)
proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
if full_pos_ids is None:
base_pos = torch.arange(max_seq, device=device, dtype=torch.long)
full_pos_ids = base_pos if padded_bsz == 1 else base_pos.repeat(padded_bsz)

hidden_out = model_fwd(projected, step_pos_ids)
for step in range(1, num_groups):
projected = proj_buf[:padded_bsz, :max_seq, :]

logits = lm_heads[step - 1](hidden_out[:, -1, :])
hidden_out = model_fwd(projected, full_pos_ids)
logits = lm_heads[step - 1](hidden_out[:bsz, step, :])

if use_sampling:
scaled = logits * inv_temperature
Expand Down
Loading