Skip to content
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion studio/backend/core/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,15 @@ def _generate_dac(
+ "<|text_end|>\n<|audio_start|><|global_features_start|>\n"
)
with torch.inference_mode():
with torch.amp.autocast("cuda", dtype = model.dtype):
# Derive the autocast device from the loaded model, not from the
# global backend: a CPU-fallback DAC on an XPU/CUDA host must not
# open a GPU autocast context around CPU tensors.
device_type = (
model.device.type
if hasattr(model.device, "type")
else str(model.device).split(":", 1)[0]
)
with torch.amp.autocast(device_type, dtype = model.dtype):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Guard CPU autocast dtype before entering DAC generation

When `model.device` resolves to CPU (the exact fallback this change is trying to support), this now calls `torch.amp.autocast("cpu", dtype=model.dtype)`. CPU autocast does not accept arbitrary dtypes (notably float32, which is common for CPU-loaded models), so entering this context can raise immediately and abort audio generation before `model.generate()` runs. This creates a regression for CPU-fallback DAC inference paths on non-CUDA setups; on CPU, the code should either skip autocast or explicitly use a dtype that CPU autocast supports.

Useful? React with 👍 / 👎.

inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)
generated = model.generate(
**inputs,
Expand Down
13 changes: 9 additions & 4 deletions studio/backend/core/inference/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,8 +1438,9 @@ def unload_model(self) -> bool:
LlamaCppBackend._codec_mgr = None
import torch

if torch.cuda.is_available():
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
return True

def _kill_process(self):
Expand Down Expand Up @@ -3016,7 +3017,9 @@ def init_audio_codec(self, audio_type: str) -> None:
if LlamaCppBackend._codec_mgr is None:
LlamaCppBackend._codec_mgr = AudioCodecManager()

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
model_repo_path = None

# BiCodec needs a repo with BiCodec/ weights — download canonical SparkTTS
Expand Down Expand Up @@ -3090,7 +3093,9 @@ def generate_audio_response(

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
return LlamaCppBackend._codec_mgr.decode(
audio_type, device, token_ids = token_ids, text = data.get("content", "")
)
24 changes: 18 additions & 6 deletions studio/backend/core/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,9 @@ def _preprocess_snac_dataset(self, dataset, custom_format_mapping = None):

SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
SNAC_SAMPLE_RATE = 24000
device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
max_length = self.max_seq_length or 2048
tokenizer = self.tokenizer

Expand Down Expand Up @@ -1708,7 +1710,9 @@ def _preprocess_snac_dataset(self, dataset, custom_format_mapping = None):
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down Expand Up @@ -1736,7 +1740,9 @@ def _preprocess_bicodec_dataset(self, dataset, custom_format_mapping = None):

import subprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()

# The sparktts Python package lives in the SparkAudio/Spark-TTS GitHub repo,
# NOT in the unsloth/Spark-TTS-0.5B HF model repo. Clone it if needed.
Expand Down Expand Up @@ -1936,7 +1942,9 @@ def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor:
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down Expand Up @@ -1971,7 +1979,9 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None):
from datasets import Dataset as HFDataset
from utils.paths import ensure_dir, tmp_root

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()

# Clone OuteTTS repo (same as audio_codecs._load_dac)
import subprocess
Expand Down Expand Up @@ -2149,7 +2159,9 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None):
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down
2 changes: 2 additions & 0 deletions studio/backend/utils/hardware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
estimate_required_model_memory_gb,
auto_select_gpu_ids,
prepare_gpu_selection,
get_torch_device_str,
safe_num_proc,
safe_thread_num_proc,
dataset_map_num_proc,
Expand Down Expand Up @@ -68,6 +69,7 @@
"estimate_required_model_memory_gb",
"auto_select_gpu_ids",
"prepare_gpu_selection",
"get_torch_device_str",
"safe_num_proc",
"safe_thread_num_proc",
"dataset_map_num_proc",
Expand Down
Loading
Loading