Parametrising fast inference so that finetuned models can be used #113

Merged 1 commit on Mar 26, 2024
README.md (11 changes: 8 additions & 3 deletions)
@@ -6,7 +6,7 @@
<a target="_blank" style="display: inline-block; vertical-align: middle" href="https://colab.research.google.com/github/metavoiceio/metavoice-src/blob/main/colab_demo.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
[![](https://dcbadge.vercel.app/api/server/Cpy6U3na8Z?style=flat&compact=True)](https://discord.gg/tbTbkGEgJM)
[![Twitter](https://img.shields.io/twitter/url/https/twitter.com/OnusFM.svg?style=social&label=@metavoiceio)](https://twitter.com/metavoiceio)


@@ -69,7 +69,7 @@ poetry install && poetry run pip install torch==2.2.1 torchaudio==2.2.1
## Usage
1. Download it and use it anywhere (including locally) with our [reference implementation](/fam/llm/fast_inference.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
poetry run python -i fam/llm/fast_inference.py
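# e.g., for experimental int4 quantisation (a sketch combining the note above with this command):
#   poetry run python -i fam/llm/fast_inference.py --quantisation_mode int4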

@@ -82,7 +82,7 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s

2. Deploy it on any cloud (AWS/GCP/Azure), using our [inference server](serving.py) or [web UI](app.py)
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.
poetry run python serving.py
poetry run python app.py
@@ -108,6 +108,11 @@ Try it out using our sample datasets via:
poetry run finetune --train ./datasets/sample_dataset.csv --val ./datasets/sample_val_dataset.csv
```

Once you've trained your model, you can use it for inference via:
```bash
poetry run python -i fam/llm/fast_inference.py --first_stage_path ./my-finetuned_model.pt
```
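For programmatic use, the same override goes straight into the `TTS` constructor changed in this diff. A minimal sketch (the checkpoint path is a placeholder, and the `spk_ref_path` argument follows the repo's existing `synthesise` example):

```python
from fam.llm.fast_inference import TTS

# Load a locally fine-tuned first-stage checkpoint; the second stage and
# speaker encoder still come from the `model_name` download.
tts = TTS(first_stage_path="./my-finetuned_model.pt")  # placeholder path
tts.synthesise(
    text="Testing a fine-tuned first stage.",
    spk_ref_path="assets/bria.mp3",  # reference clip for the target voice
)
```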

### Configuration

In order to set hyperparameters such as learning rate, what to freeze, etc., you
fam/llm/fast_inference.py (7 changes: 6 additions & 1 deletion)
@@ -41,6 +41,7 @@ def __init__(
        seed: int = 1337,
        output_dir: str = "outputs",
        quantisation_mode: Optional[Literal["int4", "int8"]] = None,
        first_stage_path: Optional[str] = None,
    ):
        """
        Initialise the TTS model.
@@ -54,6 +55,7 @@ def __init__(
            - None for no quantisation (bf16 or fp16 based on device),
            - int4 for int4 weight-only quantisation,
            - int8 for int8 weight-only quantisation.
            first_stage_path: path to the first-stage LLM checkpoint. If provided, it overrides the checkpoint downloaded from Hugging Face via `model_name`.
        """

        # NOTE: this needs to come first so that we don't change global state when we want to use
@@ -64,6 +66,9 @@ def __init__(
        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
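        # Prefer an explicitly supplied fine-tuned checkpoint; otherwise fall
        # back to the default first stage downloaded via `model_name`.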
        if first_stage_path:
            print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
        first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"

        second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
        config_second_stage = InferenceConfig(
@@ -85,7 +90,7 @@ def __init__(
        self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
        self.model, self.tokenizer, self.smodel, self.model_size = build_model(
            precision=self.precision,
-            checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
+            checkpoint_path=Path(first_stage_ckpt),
            spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
            device=self._device,
            compile=True,
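Net effect: `build_model` now receives whichever path `first_stage_ckpt` resolved to, so a fine-tuned first stage can be swapped in without touching the second-stage or speaker-encoder checkpoints.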