diff --git a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py index 397530ca340..cdf0da992fc 100644 --- a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py +++ b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py @@ -54,6 +54,9 @@ def _load_dac_codec( if "generator" in state_dict: state_dict = state_dict["generator"] codec.load_state_dict(state_dict, strict=False) + # Encoder path only uses encoder + quantizer.forward(); prune the + # decoder before moving to device to avoid unnecessary GPU allocation. + codec.decoder = None codec = codec.to(device=device, dtype=dtype) codec.eval() diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py index e121b03371c..ed42aa98c00 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py @@ -141,6 +141,13 @@ def _ensure_codec_loaded(self) -> None: self._bake_weight_norm(codec) self._cache_attention_masks(codec) + # Decode path only uses quantizer.decode() + decoder; prune + # encode-only components before moving to device to avoid + # unnecessary GPU allocation. + codec.encoder = None + codec.quantizer.pre_module = None + codec.quantizer.downsample = None + device = self.vllm_config.device_config.device codec = codec.to(device=device, dtype=torch.float32) codec.eval()