From 3f67f8034d3b3280034f89ae0635e4fae839812d Mon Sep 17 00:00:00 2001
From: Hubert Siuzdak <35269911+hubertsiuzdak@users.noreply.github.com>
Date: Mon, 12 Jun 2023 18:50:07 +0200
Subject: [PATCH] Bark integration (#4)

* Update Vocos API to handle discrete codes input

* Inference mode

* Add example notebook for Bark integration
---
 notebooks/Bark+Vocos.ipynb | 265 +++++++++++++++++++++++++++++++++++++
 vocos/pretrained.py        |  46 +++++++
 2 files changed, 311 insertions(+)
 create mode 100644 notebooks/Bark+Vocos.ipynb

diff --git a/notebooks/Bark+Vocos.ipynb b/notebooks/Bark+Vocos.ipynb
new file mode 100644
index 0000000..6af7365
--- /dev/null
+++ b/notebooks/Bark+Vocos.ipynb
@@ -0,0 +1,265 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "private_outputs": true,
+      "provenance": [],
+      "gpuType": "T4",
+      "authorship_tag": "ABX9TyNuxsqp/FTsmltYeYfMZ6sw",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "\"Open"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Text-to-Audio Synthesis using Bark and Vocos"
+      ],
+      "metadata": {
+        "id": "NuRzVtHDZ_Gl"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "In this notebook, we use the Bark generative model to turn a text prompt into EnCodec audio tokens. These tokens are then decoded by two different decoders, EnCodec and Vocos, to reconstruct the audio waveform. Compare the results to hear the differences in audio quality and characteristics."
+      ],
+      "metadata": {
+        "id": "zJFDte0daDAz"
+      }
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Make sure you have Bark and Vocos installed:"
+      ],
+      "metadata": {
+        "id": "c9omqGDYnajY"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "!pip install git+https://github.com/suno-ai/bark.git\n",
+        "!pip install vocos"
+      ],
+      "metadata": {
+        "id": "voH44g90NvtV"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Download and load Bark models."
+      ],
+      "metadata": {
+        "id": "s3cEjOIuj6tq"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from bark import preload_models\n",
+        "\n",
+        "preload_models()"
+      ],
+      "metadata": {
+        "id": "1H7XtXRMjxUM"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Download and load Vocos."
+      ],
+      "metadata": {
+        "id": "YO1m0dJ1j-F5"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "from vocos import Vocos\n",
+        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+        "vocos = Vocos.from_pretrained(\"charactr/vocos-encodec-24khz\").to(device)"
+      ],
+      "metadata": {
+        "id": "COQYTDDFkBCq"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "We are going to reuse `text_to_semantic` from the Bark API, but to reconstruct the audio waveform with a custom vocoder, we need to slightly redefine the API so that it returns `fine_tokens`."
+      ],
+      "metadata": {
+        "id": "--RjqW0rk5JQ"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "OiUsuN2DNl5S"
+      },
+      "outputs": [],
+      "source": [
+        "from typing import Optional, Union, Dict\n",
+        "\n",
+        "import numpy as np\n",
+        "from bark.generation import generate_coarse, generate_fine\n",
+        "\n",
+        "\n",
+        "def semantic_to_audio_tokens(\n",
+        "    semantic_tokens: np.ndarray,\n",
+        "    history_prompt: Optional[Union[Dict, str]] = None,\n",
+        "    temp: float = 0.7,\n",
+        "    silent: bool = False,\n",
+        "    output_full: bool = False,\n",
+        "):\n",
+        "    coarse_tokens = generate_coarse(\n",
+        "        semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, use_kv_caching=True\n",
+        "    )\n",
+        "    fine_tokens = generate_fine(coarse_tokens, history_prompt=history_prompt, temp=0.5)\n",
+        "\n",
+        "    if output_full:\n",
+        "        full_generation = {\n",
+        "            \"semantic_prompt\": semantic_tokens,\n",
+        "            \"coarse_prompt\": coarse_tokens,\n",
+        "            \"fine_prompt\": fine_tokens,\n",
+        "        }\n",
+        "        return full_generation\n",
+        "    return fine_tokens"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Let's create a text prompt and generate audio tokens:"
+      ],
+      "metadata": {
+        "id": "Cv8KCzXlmoF9"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from bark import text_to_semantic\n",
+        "\n",
+        "history_prompt = None\n",
+        "text_prompt = \"So, you've heard about neural vocoding? [laughs] We've been messing around with this new model called Vocos.\"\n",
+        "semantic_tokens = text_to_semantic(text_prompt, history_prompt=history_prompt, temp=0.7, silent=False,)\n",
+        "audio_tokens = semantic_to_audio_tokens(\n",
+        "    semantic_tokens, history_prompt=history_prompt, temp=0.7, silent=False, output_full=False,\n",
+        ")"
+      ],
+      "metadata": {
+        "id": "pDmSTutoOH_G"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Reconstruct the audio waveform with EnCodec:"
+      ],
+      "metadata": {
+        "id": "UYMzI8svTNqI"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torchaudio\n",
+        "from bark.generation import codec_decode\n",
+        "from IPython.display import Audio\n",
+        "\n",
+        "encodec_output = codec_decode(audio_tokens)\n",
+        "\n",
+        "# Upsample to 44100 Hz for better reproduction on audio hardware\n",
+        "encodec_output = torchaudio.functional.resample(torch.from_numpy(encodec_output), orig_freq=24000, new_freq=44100)\n",
+        "Audio(encodec_output, rate=44100)"
+      ],
+      "metadata": {
+        "id": "PzdytlXFTNQ2"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Reconstruct with Vocos:"
+      ],
+      "metadata": {
+        "id": "BhUxBuP9TTTw"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "\n",
+        "audio_tokens_torch = torch.from_numpy(audio_tokens).to(device)\n",
+        "features = vocos.codes_to_features(audio_tokens_torch)\n",
+        "vocos_output = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))  # 6 kbps\n",
+        "# Upsample to 44100 Hz for better reproduction on audio hardware\n",
+        "vocos_output = torchaudio.functional.resample(vocos_output, orig_freq=24000, new_freq=44100).cpu()\n",
+        "Audio(vocos_output.numpy(), rate=44100)"
+      ],
+      "metadata": {
+        "id": "8hzSWQ5-nBlV"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Optionally, save the results to MP3 files:"
+      ],
+      "metadata": {
+        "id": "RjVXQIZRb1Re"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "torchaudio.save(\"encodec.mp3\", encodec_output[None, :], 44100, compression=128)\n",
+        "torchaudio.save(\"vocos.mp3\", vocos_output, 44100, compression=128)"
+      ],
+      "metadata": {
+        "id": "PLFXpjUKb3WX"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file
diff --git a/vocos/pretrained.py b/vocos/pretrained.py
index 00a5883..3d6b5c8 100644
--- a/vocos/pretrained.py
+++ b/vocos/pretrained.py
@@ -76,20 +76,66 @@ def from_pretrained(self, repo_id: str) -> "Vocos":
         model.eval()
         return model
 
+    @torch.inference_mode()
     def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """
         Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
         which is then passed through the backbone and the head to reconstruct the audio output.
+
+        Args:
+            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
+                where B is the batch size and T is the waveform length.
+
+
+        Returns:
+            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
         """
         features = self.feature_extractor(audio_input, **kwargs)
         audio_output = self.decode(features, **kwargs)
         return audio_output
 
+    @torch.inference_mode()
     def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """
         Method to decode audio waveform from already calculated features. The features input is passed through
         the backbone and the head to reconstruct the audio output.
+
+        Args:
+            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
+                C denotes the feature dimension, and L is the sequence length.
+
+        Returns:
+            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
         """
         x = self.backbone(features_input, **kwargs)
         audio_output = self.head(x)
         return audio_output
+
+    @torch.inference_mode()
+    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
+        """
+        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
+        codebook weights.
+
+        Args:
+            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
+                where K is the number of codebooks, B is the batch size and L is the sequence length.
+
+        Returns:
+            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
+                and L is the sequence length.
+        """
+        assert isinstance(
+            self.feature_extractor, EncodecFeatures
+        ), "Feature extractor should be an instance of EncodecFeatures"
+
+        if codes.dim() == 2:
+            codes = codes.unsqueeze(1)
+
+        n_bins = self.feature_extractor.encodec.quantizer.bins
+        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
+        embeddings_idxs = codes + offsets.view(-1, 1, 1)
+        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
+        features = features.transpose(1, 2)
+
+        return features
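
As a quick, Bark-independent sanity check of the new discrete-code path, a minimal sketch along these lines should work. It is an illustration, not part of the patch: the dummy codes, the codebook size of 1024, the choice of 8 codebooks, and the `bandwidth_id` index (2, the 6 kbps setting used in the notebook) are assumptions about the `charactr/vocos-encodec-24khz` checkpoint rather than anything the diff itself guarantees.

    import torch
    from vocos import Vocos

    # Same checkpoint as in the notebook; runs on CPU as well.
    vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz")

    # Dummy EnCodec-style codes of shape (K, L): K codebooks, L frames.
    # Assumed: 8 codebooks (roughly the 6 kbps setting) and a codebook size of 1024.
    codes = torch.randint(0, 1024, (8, 150), dtype=torch.long)

    # (K, L) is unsqueezed to (K, 1, L) inside codes_to_features, so the batch size is 1.
    features = vocos.codes_to_features(codes)                        # (B, C, L)
    audio = vocos.decode(features, bandwidth_id=torch.tensor([2]))   # (B, T)
    print(features.shape, audio.shape)

If the shapes come out as (1, C, 150) and (1, T) with T a fixed multiple of 150, the embedding lookup in `codes_to_features` and the `bandwidth_id` conditioning in `decode` are wired up as intended.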