Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion studio/backend/core/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,8 +1642,31 @@ def _generate_dac(
+ text
+ "<|text_end|>\n<|audio_start|><|global_features_start|>\n"
)
import contextlib

with torch.inference_mode():
with torch.amp.autocast("cuda", dtype = model.dtype):
# Derive the autocast device from the loaded model, not from the
# global backend: a CPU-fallback DAC on an XPU/CUDA host must not
# open a GPU autocast context around CPU tensors.
device_type = (
model.device.type
if hasattr(model.device, "type")
else str(model.device).split(":", 1)[0]
)
# Clamp to autocast-supported backends so exotic devices
# (e.g. "meta" during accelerate offloaded loading) do not raise.
# MPS is autocast-supported since torch 2.3, keep it in the set.
if device_type not in ("cuda", "xpu", "mps", "cpu"):
device_type = "cpu"
# CPU and XPU autocast only accept bfloat16/float16. For a
# float32 model, skip autocast entirely to avoid raising or
# producing a warning on every generate call.
autocast_dtype_supported = model.dtype in (torch.bfloat16, torch.float16)
if device_type in ("cpu", "xpu") and not autocast_dtype_supported:
autocast_ctx = contextlib.nullcontext()
else:
autocast_ctx = torch.amp.autocast(device_type, dtype = model.dtype)
with autocast_ctx:
inputs = tokenizer([prompt], return_tensors = "pt").to(model.device)
generated = model.generate(
**inputs,
Expand Down
13 changes: 9 additions & 4 deletions studio/backend/core/inference/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,8 +1438,9 @@ def unload_model(self) -> bool:
LlamaCppBackend._codec_mgr = None
import torch

if torch.cuda.is_available():
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
return True

def _kill_process(self):
Expand Down Expand Up @@ -3016,7 +3017,9 @@ def init_audio_codec(self, audio_type: str) -> None:
if LlamaCppBackend._codec_mgr is None:
LlamaCppBackend._codec_mgr = AudioCodecManager()

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
model_repo_path = None

# BiCodec needs a repo with BiCodec/ weights — download canonical SparkTTS
Expand Down Expand Up @@ -3090,7 +3093,9 @@ def generate_audio_response(

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
return LlamaCppBackend._codec_mgr.decode(
audio_type, device, token_ids = token_ids, text = data.get("content", "")
)
24 changes: 18 additions & 6 deletions studio/backend/core/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,9 @@ def _preprocess_snac_dataset(self, dataset, custom_format_mapping = None):

SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
SNAC_SAMPLE_RATE = 24000
device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()
max_length = self.max_seq_length or 2048
tokenizer = self.tokenizer

Expand Down Expand Up @@ -1708,7 +1710,9 @@ def _preprocess_snac_dataset(self, dataset, custom_format_mapping = None):
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down Expand Up @@ -1736,7 +1740,9 @@ def _preprocess_bicodec_dataset(self, dataset, custom_format_mapping = None):

import subprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()

# The sparktts Python package lives in the SparkAudio/Spark-TTS GitHub repo,
# NOT in the unsloth/Spark-TTS-0.5B HF model repo. Clone it if needed.
Expand Down Expand Up @@ -1936,7 +1942,9 @@ def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor:
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down Expand Up @@ -1971,7 +1979,9 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None):
from datasets import Dataset as HFDataset
from utils.paths import ensure_dir, tmp_root

device = "cuda" if torch.cuda.is_available() else "cpu"
from utils.hardware import get_torch_device_str

device = get_torch_device_str()

# Clone OuteTTS repo (same as audio_codecs._load_dac)
import subprocess
Expand Down Expand Up @@ -2149,7 +2159,9 @@ def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None):
import gc

gc.collect()
torch.cuda.empty_cache()
from utils.hardware import clear_gpu_cache

clear_gpu_cache()
self._cuda_audio_used = True

if not processed_examples:
Expand Down
2 changes: 2 additions & 0 deletions studio/backend/utils/hardware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
estimate_required_model_memory_gb,
auto_select_gpu_ids,
prepare_gpu_selection,
get_torch_device_str,
safe_num_proc,
safe_thread_num_proc,
dataset_map_num_proc,
Expand Down Expand Up @@ -68,6 +69,7 @@
"estimate_required_model_memory_gb",
"auto_select_gpu_ids",
"prepare_gpu_selection",
"get_torch_device_str",
"safe_num_proc",
"safe_thread_num_proc",
"dataset_map_num_proc",
Expand Down
Loading
Loading