Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, vllm_config: Any):
self.get_req_chunk: dict[str, int] = defaultdict(int)
self.finished_requests: set[str] = set()
self.request_payload = {}
self.code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list)
self.code_prompt_token_ids: dict[str, list[torch.Tensor]] = defaultdict(list)
self.request_ids_mapping: dict[str, str] = {}

self.waiting_for_chunk_waiting_requests: deque[Any] = deque()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch
import torch.nn as nn
from torch.nn.utils.parametrize import remove_parametrizations
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
Expand Down Expand Up @@ -58,6 +59,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._hop_length: int = DAC_HOP_LENGTH
self._logged_codec_stats = False

def _bake_weight_norm(self, codec: nn.Module) -> None:
baked = 0
for module in codec.modules():
parametrizations = getattr(module, "parametrizations", None)
if not parametrizations:
continue
for name in list(parametrizations.keys()):
remove_parametrizations(module, name, leave_parametrized=True)
baked += 1
if baked > 0:
logger.info("Baked %d DAC parametrized weights for inference", baked)

def _cache_attention_masks(self, codec: nn.Module) -> None:
for module in codec.modules():
if not hasattr(module, "make_mask") or not hasattr(module, "make_window_limited_mask"):
continue

base_make_mask = module.make_mask
base_make_window_mask = module.make_window_limited_mask
mask_cache: dict[int, torch.Tensor] = {}
window_mask_cache: dict[int, torch.Tensor] = {}

def make_mask_cached(max_length: int, x_lens: torch.Tensor | None = None, *, _orig=base_make_mask):
if x_lens is not None:
return _orig(max_length, x_lens)
key = int(max_length)
cached = mask_cache.get(key)
if cached is None:
cached = _orig(max_length, x_lens)
mask_cache[key] = cached
return cached

def make_window_mask_cached(
max_length: int,
x_lens: torch.Tensor | None = None,
*,
_orig=base_make_window_mask,
):
if x_lens is not None:
return _orig(max_length, x_lens)
key = int(max_length)
cached = window_mask_cache.get(key)
if cached is None:
cached = _orig(max_length, x_lens)
window_mask_cache[key] = cached
return cached

module.make_mask = make_mask_cached
module.make_window_limited_mask = make_window_mask_cached

def _ensure_codec_loaded(self) -> None:
if self._codec is not None:
return
Expand Down Expand Up @@ -87,6 +138,8 @@ def _ensure_codec_loaded(self) -> None:
if "generator" in state_dict:
state_dict = state_dict["generator"]
codec.load_state_dict(state_dict, strict=False)
self._bake_weight_norm(codec)
self._cache_attention_masks(codec)

device = self.vllm_config.device_config.device
codec = codec.to(device=device, dtype=torch.float32)
Expand Down Expand Up @@ -160,7 +213,7 @@ def forward(
ids = input_ids.reshape(-1).to(dtype=torch.long)
request_ids_list = self._split_request_ids(ids, kwargs.get("seq_token_counts"))

parsed: list[tuple[int, int]] = []
parsed_ctx_frames: list[int] = []
valid_codes_qf: list[torch.Tensor] = []
valid_indices: list[int] = []
left_context_size = [0] * len(request_ids_list)
Expand All @@ -173,7 +226,7 @@ def forward(

for i, req_ids in enumerate(request_ids_list):
if req_ids.numel() < 1:
parsed.append((0, 0))
parsed_ctx_frames.append(0)
continue
ctx_frames = left_context_size[i]
flat = req_ids
Expand All @@ -185,11 +238,11 @@ def forward(
n,
q,
)
parsed.append((0, 0))
parsed_ctx_frames.append(0)
continue
frames = n // q
codes_qf = flat.reshape(q, frames)
parsed.append((ctx_frames, frames))
parsed_ctx_frames.append(ctx_frames)
valid_codes_qf.append(codes_qf)
valid_indices.append(i)

Expand Down Expand Up @@ -219,23 +272,33 @@ def forward(
except Exception:
pass

# Decode each request individually.
wav_tensors: list[torch.Tensor] = []
for codes_qf in valid_codes_qf:
codes_bqf = codes_qf.unsqueeze(0) # [1, num_codebooks, num_frames]
num_frames = codes_qf.shape[1]
feature_lengths = torch.tensor([num_frames], device=codes_bqf.device)
with torch.cuda.amp.autocast(enabled=False):
wav, audio_lengths = self._codec.decode(codes_bqf, feature_lengths)
# wav shape: [1, 1, wav_len]
wav_tensors.append(wav.squeeze(0).squeeze(0)) # [wav_len]
feature_lengths = torch.tensor(
[codes_qf.shape[1] for codes_qf in valid_codes_qf],
device=valid_codes_qf[0].device,
dtype=torch.long,
)
max_frames = int(feature_lengths.max().item())
batch_size = len(valid_codes_qf)

codes_bqf = torch.zeros(
(batch_size, q, max_frames),
device=valid_codes_qf[0].device,
dtype=torch.long,
)
for i, codes_qf in enumerate(valid_codes_qf):
frame_count = int(feature_lengths[i].item())
codes_bqf[i, :, :frame_count] = codes_qf

with torch.amp.autocast("cuda", enabled=False):
wav_batch, audio_lengths = self._codec.decode(codes_bqf, feature_lengths)

audios: list[torch.Tensor] = [empty] * num_req
srs = [sr_tensor] * num_req

for j, idx in enumerate(valid_indices):
ctx_frames, actual_frames = parsed[idx]
wav = wav_tensors[j]
ctx_frames = parsed_ctx_frames[idx]
audio_len = int(audio_lengths[j].item()) if audio_lengths.numel() > j else int(wav_batch.shape[-1])
wav = wav_batch[j, 0, :audio_len]
# Trim context frames (left overlap for streaming).
if ctx_frames > 0:
cut = ctx_frames * self._hop_length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ def __init__(
self._embed_buf: torch.Tensor | None = None
self._pos_ids: torch.Tensor | None = None
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False

def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
Expand All @@ -322,11 +324,61 @@ def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) ->
self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device)

def _setup_compile(self) -> None:
if self._compiled_model_fwd is not None:
if self._compile_attempted:
return
# TODO: Enable torch.compile for performance. Eager for now to avoid
# potential graph-break issues during initial bring-up.
self._compiled_model_fwd = self.model.forward
self._compile_attempted = True
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
# Keep the helper compiler separate from vLLM's outer
# cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
)
except Exception as exc:
self._compile_failed = True
logger.warning("Failed to enable torch.compile for Fish Speech Fast AR: %s", exc)
self._compiled_model_fwd = self.model.forward
else:
logger.info("Enabled torch.compile for Fish Speech Fast AR forward (mode=default)")

@torch.inference_mode()
def warmup_compile(
self,
device: torch.device,
dtype: torch.dtype,
batch_sizes: tuple[int, ...] = (1,),
) -> None:
self._setup_compile()
if self._compiled_model_fwd is self.model.forward or self._compile_failed:
return
for batch_size in batch_sizes:
hidden = torch.zeros((batch_size, self.slow_ar_config.hidden_size), device=device, dtype=dtype)
semantic = torch.full(
(batch_size,),
self.slow_ar_config.semantic_begin_id,
device=device,
dtype=torch.long,
)
self(hidden, semantic, do_sample=False)
torch.cuda.synchronize(device)

@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
# Default-on compile only pays off for single-request decode. For
# batched decode, eager preserves loaded throughput and avoids the
# regression seen with batch>1 compiled execution.
model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
if model_fwd is self.model.forward or self._compile_failed:
raise
self._compile_failed = True
self._compiled_model_fwd = self.model.forward
logger.warning("Fish Speech Fast AR torch.compile fallback to eager after runtime failure: %s", exc)
return self.model.forward(step_input, step_pos_ids)

@torch.inference_mode()
def forward(
Expand Down Expand Up @@ -369,7 +421,6 @@ def forward(

embed_buf = self._embed_buf
pos_ids = self._pos_ids
model_fwd = self._compiled_model_fwd

# Position 0: projected Slow AR hidden state.
projected = self.fast_project_in(slow_ar_hidden.reshape(bsz, -1))
Expand All @@ -390,9 +441,11 @@ def forward(
for step in range(1, num_cb):
seq_len = step + 1
step_input = embed_buf[:bsz, :seq_len, :]
step_pos_ids = pos_ids[:seq_len] if bsz == 1 else pos_ids[:seq_len].repeat(bsz)
# Use a dense 2D position tensor for every batch size; stride-0
# views from expand() were fragile under compiled execution.
step_pos_ids = pos_ids[:seq_len].unsqueeze(0).repeat(bsz, 1)

hidden_out = model_fwd(step_input, step_pos_ids)
hidden_out = self._run_model(step_input, step_pos_ids, bsz)
logits = self.fast_output(self.fast_norm(hidden_out[:, -1, :]))

# Residual codebooks (step >= 1) only have 1024 entries.
Expand Down
Loading
Loading