add Python interface #817

Draft · wants to merge 6 commits into main

8 changes: 6 additions & 2 deletions fish_speech/inference_engine/__init__.py

@@ -48,7 +48,7 @@ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:

        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []
-        # Load the reference audio and text based on id or hash
+        # Load the reference audio and text based on id, hash, or preprocessed references
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)

@@ -57,6 +57,10 @@ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
                req.references, req.use_memory_cache
            )

+        elif req.preprocessed_references:
+            prompt_tokens = [ref.tokens for ref in req.preprocessed_references]
+            prompt_texts = [ref.text for ref in req.preprocessed_references]
+
        # Set the random seed if provided
        if req.seed is not None:
            set_seed(req.seed)

@@ -106,7 +110,7 @@ def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
            if result.action != "next":
                segment = self.get_audio_segment(result)

-                if req.streaming:  # Used only by the API server
+                if req.streaming:
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
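
For reviewers, a minimal sketch of how the new preprocessed_references branch is reached. It assumes an already-constructed TTSInferenceEngine named `engine` and a Reference object `ref` produced by the Pipeline.make_reference helper introduced later in this PR; the remaining ServeTTSRequest fields are assumed to keep their defaults.

from fish_speech.utils.schema import ServeTTSRequest

# `engine` is an existing TTSInferenceEngine; `ref` is a Reference holding
# pre-encoded prompt tokens plus their transcript (see Pipeline.make_reference below).
request = ServeTTSRequest(
    text="Hello world.",
    preprocessed_references=[ref],  # tokens/text are used directly, no audio decoding
)
results = list(engine.inference(request))
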
258 changes: 258 additions & 0 deletions fish_speech/lib/__init__.py
@@ -0,0 +1,258 @@
import warnings
from queue import Queue
from typing import Generator, List, Literal, Optional, Union

import numpy as np
import torch

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.models.vqgan.inference import load_model as load_vqgan_model
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.utils.file import audio_to_bytes
from fish_speech.utils.schema import Reference, ServeTTSRequest

Device = Literal["cuda", "mps", "cpu"]

warnings.simplefilter(action="ignore", category=FutureWarning)


class Pipeline:

    def __init__(
        self,
        llama_path: str,
        vqgan_path: str,
        vqgan_config: str = "firefly_gan_vq",
        device: Device = "cpu",
        half: bool = False,
        compile: bool = False,
    ) -> None:
        """
        Initialize the TTS pipeline.

        Args:
            llama_path (str): Path to the LLAMA model.
            vqgan_path (str): Path to the VQ-GAN model.
            vqgan_config (str, optional): VQ-GAN model configuration name. Defaults to base configuration.
            device (Device, optional): Device to run the pipeline on. Defaults to "cpu".
            half (bool, optional): Use half precision. Defaults to False.
            compile (bool, optional): Compile the models. Defaults to False.
        """

        # Validate input
        assert isinstance(llama_path, str), "llama_path must be a string."
        assert isinstance(vqgan_path, str), "vqgan_path must be a string."
        assert isinstance(vqgan_config, str), "vqgan_config must be a string."
        assert isinstance(half, bool), "half must be a boolean."
        assert isinstance(compile, bool), "compile must be a boolean."

        device = self.check_device(device)
        precision = torch.half if half else torch.bfloat16

        llama = self.load_llama(llama_path, device, precision, compile)
        vqgan = self.load_vqgan(vqgan_config, vqgan_path, device)

        self.inference_engine = TTSInferenceEngine(
            llama_queue=llama,
            decoder_model=vqgan,
            precision=precision,
            compile=compile,
        )

        self.warmup(self.inference_engine)

    def check_device(self, device: str) -> Device:
        """Check if the device is available."""
        device = device.lower()

        # If CUDA or MPS chosen, check if available
        match device:
            case "cuda":
                if not torch.cuda.is_available():
                    warnings.warn("CUDA is not available, running on CPU.")
                    device = "cpu"
            case "mps":
                if not torch.backends.mps.is_available():
                    warnings.warn("MPS is not available, running on CPU.")
                    device = "cpu"
            case "cpu":
                pass
            case _:
                raise ValueError("Invalid device, choose from 'cuda', 'mps', 'cpu'.")

        return device

    def load_llama(
        self, llama_path: str, device: str, precision: torch.dtype, compile: bool
    ) -> Queue:
        """Load the LLAMA model."""
        try:
            return launch_thread_safe_queue(
                checkpoint_path=llama_path,
                device=device,
                precision=precision,
                compile=compile,
            )
        except Exception as e:
            raise ValueError(f"Failed to load LLAMA model: {e}")

    def load_vqgan(
        self, vqgan_config: str, vqgan_path: str, device: str
    ) -> FireflyArchitecture:
        """Load the VQ-GAN model."""
        try:
            return load_vqgan_model(
                config_name=vqgan_config,
                checkpoint_path=vqgan_path,
                device=device,
            )
        except Exception as e:
            raise ValueError(f"Failed to load VQ-GAN model: {e}")

    def warmup(self, inference_engine: TTSInferenceEngine) -> None:
        """Warm up the inference engine."""
        try:
            list(
                inference_engine.inference(
                    ServeTTSRequest(
                        text="Hello world.",
                        references=[],
                        reference_id=None,
                        max_new_tokens=1024,
                        chunk_length=200,
                        top_p=0.7,
                        repetition_penalty=1.5,
                        temperature=0.7,
                        format="wav",
                    )
                )
            )
        except Exception as e:
            raise ValueError(f"Failed to warm up the inference engine: {e}")

    @property
    def sample_rate(self) -> int:
        """Get the sample rate of the audio."""
        return self.inference_engine.decoder_model.spec_transform.sample_rate

    def make_reference(self, audio_path: str, text: str) -> Reference:
        """Create a reference object from audio and text."""
        audio_bytes = audio_to_bytes(audio_path)
        if audio_bytes is None:
            raise ValueError("Failed to load audio file.")

        tokens = self.inference_engine.encode_reference(audio_bytes, True)
        return Reference(tokens=tokens, text=text)

    def generate_streaming(
        self,
        text: str,
        references: Union[List[Reference], Reference] = [],
        seed: Optional[int] = None,
        streaming: bool = False,
        max_new_tokens: int = 0,
        chunk_length: int = 200,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> Generator:
        """
        Generate audio from text.

        Args:
            text (str): Text to generate audio from.
            references (Union[List[Reference], Reference], optional): List of reference audios. Defaults to [].
            seed (Optional[int], optional): Random seed. Defaults to None.
            streaming (bool, optional): Stream the audio. Defaults to False.
            max_new_tokens (int, optional): Maximum number of tokens. Defaults to 0 (no limit).
            chunk_length (int, optional): Chunk length for streaming. Defaults to 200.
            top_p (Optional[float], optional): Top-p sampling. Defaults to None.
            repetition_penalty (Optional[float], optional): Repetition penalty. Defaults to None.
            temperature (Optional[float], optional): Sampling temperature. Defaults to None.
        """
        references = [references] if isinstance(references, Reference) else references

        request = ServeTTSRequest(
            text=text,
            preprocessed_references=references,
            seed=seed,
            streaming=streaming,
            max_new_tokens=max_new_tokens,
            chunk_length=chunk_length,
            top_p=top_p or 0.7,
            repetition_penalty=repetition_penalty or 1.2,
            temperature=temperature or 0.7,
        )

        count = 0
        for result in self.inference_engine.inference(request):
            match result.code:
                case "header":
                    pass  # In this case, we only want to yield the audio (amplitude)
                    # User can save with a library like soundfile if needed

                case "error":
                    if isinstance(result.error, Exception):
                        raise result.error
                    else:
                        raise RuntimeError("Unknown error")

                case "segment":
                    count += 1
                    if isinstance(result.audio, tuple) and streaming:
                        yield result.audio[1]

                case "final":
                    count += 1
                    if isinstance(result.audio, tuple) and not streaming:
                        yield result.audio[1]

        if count == 0:
            raise RuntimeError("No audio generated, please check the input text.")

    def generate(
        self,
        text: str,
        references: Union[List[Reference], Reference] = [],
        seed: Optional[int] = None,
        streaming: bool = False,
        max_new_tokens: int = 0,
        chunk_length: int = 200,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
    ) -> Union[Generator, np.ndarray]:
        """
        Wrapper for the generate_streaming method.
        Returns either a generator or directly the final audio.

        Args:
            text (str): Text to generate audio from.
            references (Union[List[Reference], Reference], optional): List of reference audios. Defaults to [].
            seed (Optional[int], optional): Random seed. Defaults to None.
            streaming (bool, optional): Stream the audio. Defaults to False.
            max_new_tokens (int, optional): Maximum number of tokens. Defaults to 0 (no limit).
            chunk_length (int, optional): Chunk length for streaming. Defaults to 200.
            top_p (Optional[float], optional): Top-p sampling. Defaults to None.
            repetition_penalty (Optional[float], optional): Repetition penalty. Defaults to None.
            temperature (Optional[float], optional): Sampling temperature. Defaults to None.
        """

        generator = self.generate_streaming(
            text=text,
            references=references,
            seed=seed,
            streaming=streaming,
            max_new_tokens=max_new_tokens,
            chunk_length=chunk_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
        )

        if streaming:
            return generator
        else:
            audio = np.concatenate(list(generator))
            return audio
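
For reviewers, a usage sketch of the new Pipeline API. The checkpoint paths, the reference clip, and the output filenames are placeholders rather than files added by this PR, and soundfile is just one way to persist the generated audio.

import numpy as np
import soundfile as sf

from fish_speech.lib import Pipeline

pipe = Pipeline(
    llama_path="path/to/llama_checkpoint_dir",  # placeholder checkpoint paths
    vqgan_path="path/to/vqgan_generator.pth",
    device="cuda",
)

# Optional voice cloning: encode a reference clip once and reuse it across calls.
ref = pipe.make_reference("voice_sample.wav", "Transcript of the reference clip.")

# Non-streaming: returns the full waveform as a NumPy array.
audio = pipe.generate("Hello from the Python interface.", references=ref)
sf.write("output.wav", audio, pipe.sample_rate)

# Streaming: yields audio segments as they are generated.
segments = list(pipe.generate("A longer sentence to stream.", references=ref, streaming=True))
sf.write("streamed.wav", np.concatenate(segments), pipe.sample_rate)
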
1 change: 0 additions & 1 deletion fish_speech/models/vqgan/inference.py

@@ -21,7 +21,6 @@


def load_model(config_name, checkpoint_path, device="cuda"):
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../configs"):
        cfg = compose(config_name=config_name)
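
A hedged note on this removal: if the process has already initialized Hydra elsewhere (for example through another entry point), the initialize() call inside load_model may now raise an "already initialized" error. In that case the caller can clear the global state itself before loading; a minimal sketch:

from hydra.core.global_hydra import GlobalHydra

# Clear any previously initialized Hydra instance before calling load_model.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
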
12 changes: 11 additions & 1 deletion fish_speech/utils/schema.py

@@ -2,7 +2,7 @@
import os
import queue
from dataclasses import dataclass
-from typing import Literal
+from typing import Any, Literal

import torch
from pydantic import BaseModel, Field, conint, conlist, model_validator

@@ -158,13 +158,23 @@ def __repr__(self) -> str:
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"


+class Reference(BaseModel):
+    tokens: torch.Tensor
+    text: str
+
+    # Allow arbitrary types for pytorch related types
+    class Config:
+        arbitrary_types_allowed = True
+
+
class ServeTTSRequest(BaseModel):
    text: str
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # References audios for in-context learning
    references: list[ServeReferenceAudio] = []
+    preprocessed_references: list[Reference] = []
    # Reference id
    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
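
A short sketch of how the new Reference model is exercised. The tensor is a dummy stand-in for real VQ prompt codes (in practice produced by TTSInferenceEngine.encode_reference), and the remaining ServeTTSRequest fields are assumed to keep their defaults.

import torch

from fish_speech.utils.schema import Reference, ServeTTSRequest

# arbitrary_types_allowed lets pydantic accept the raw torch.Tensor field.
ref = Reference(tokens=torch.zeros(8, 16, dtype=torch.long), text="Reference transcript.")
req = ServeTTSRequest(text="Hello.", preprocessed_references=[ref])
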
2 changes: 1 addition & 1 deletion tools/download_models.py

@@ -25,7 +25,7 @@ def check_and_download_files(repo_id, file_list, local_dir):
repo_id_1 = "fishaudio/fish-speech-1.5"
local_dir_1 = "./checkpoints/fish-speech-1.5"
files_1 = [
-    "gitattributes",
+    ".gitattributes",
    "model.pth",
    "README.md",
    "special_tokens.json",