|
| 1 | +# Copyright (c) Kyutai, all rights reserved. |
| 2 | +# This source code is licensed under the license found in the |
| 3 | +# LICENSE file in the root directory of this source tree. |
| 4 | + |
| 5 | +import argparse |
| 6 | +import random |
| 7 | +import time |
| 8 | + |
| 9 | +from huggingface_hub import hf_hub_download |
| 10 | +import numpy as np |
| 11 | +import sphn |
| 12 | +import torch |
| 13 | +from torch.profiler import profile, ProfilerActivity |
| 14 | + |
| 15 | +from moshi.models import loaders |
| 16 | + |
| 17 | +import torch.nn as nn |
| 18 | + |
| 19 | +from executorch.examples.models.llama.llama_transformer import Transformer |
| 20 | + |
| 21 | +from executorch.examples.models.llama.model_args import ModelArgs |
| 22 | + |
| 23 | +from torch.export import export, export_for_training, ExportedProgram |
| 24 | + |
| 25 | +from executorch.exir import ( |
| 26 | + EdgeCompileConfig, |
| 27 | + ExecutorchBackendConfig, |
| 28 | + to_edge_transform_and_lower, |
| 29 | +) |
| 30 | +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner |
| 31 | + |
| 32 | +import torchaudio |
| 33 | +import requests |
| 34 | +import io |
| 35 | + |
| 36 | +from pydub import AudioSegment |
| 37 | +import numpy as np |
| 38 | + |
| 39 | +import torchaudio |
| 40 | +import requests |
| 41 | +import io |
| 42 | + |
| 43 | + |
| 44 | +def read_mp3_from_url(url): |
| 45 | + response = requests.get(url) |
| 46 | + response.raise_for_status() # Ensure request is successful |
| 47 | + |
| 48 | + # Convert to a file-like object |
| 49 | + audio_stream = io.BytesIO(response.content) |
| 50 | + |
| 51 | + # Load audio using torchaudio |
| 52 | + waveform, sample_rate = torchaudio.load(audio_stream, format="mp3") |
| 53 | + |
| 54 | + return waveform.numpy(), sample_rate |
| 55 | + |
| 56 | +# Read the MP3 file |
| 57 | + |
| 58 | +parser = argparse.ArgumentParser() |
| 59 | +parser.add_argument("--mimi-weight", type=str) |
| 60 | +parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) |
| 61 | +parser.add_argument( |
| 62 | + "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu" |
| 63 | +) |
| 64 | +parser.add_argument("--profile", action="store_true") |
| 65 | +args = parser.parse_args() |
| 66 | + |
| 67 | + |
| 68 | +def seed_all(seed): |
| 69 | + torch.manual_seed(seed) |
| 70 | + if torch.cuda.is_available(): |
| 71 | + torch.cuda.manual_seed(seed) |
| 72 | + torch.cuda.manual_seed_all(seed) # for multi-GPU setups |
| 73 | + random.seed(seed) |
| 74 | + np.random.seed(seed) |
| 75 | + torch.backends.cudnn.deterministic = True |
| 76 | + torch.backends.cudnn.benchmark = False |
| 77 | + |
| 78 | + |
| 79 | +seed_all(42424242) |
| 80 | + |
| 81 | + |
| 82 | +print("loading mimi") |
| 83 | +if args.mimi_weight is None: |
| 84 | + args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) |
| 85 | +mimi = loaders.get_mimi(args.mimi_weight, args.device) |
| 86 | +print("mimi loaded") |
| 87 | +# emb = torch.load('emb.pt') |
| 88 | + |
| 89 | +def mimi_test(mimi, max_duration_sec=10.0): |
| 90 | + url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" |
| 91 | + sample_pcm, sample_sr = read_mp3_from_url(url) |
| 92 | + pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) |
| 93 | + sample_rate = mimi.sample_rate |
| 94 | + # Uncomment below to get real audio |
| 95 | + # # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3 |
| 96 | + # sample_pcm, sample_sr = sphn.read("/Users/myuan/src/moshi0/src/moshi/data/bria-24khz.mp3") |
| 97 | + # print("loaded pcm", sample_pcm.shape, sample_sr) |
| 98 | + # sample_pcm = sphn.resample( |
| 99 | + # sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate |
| 100 | + # ) |
| 101 | + sample_pcm = torch.tensor(sample_pcm, device=args.device) |
| 102 | + max_duration_len = int(sample_rate * max_duration_sec) |
| 103 | + if sample_pcm.shape[-1] > max_duration_len: |
| 104 | + sample_pcm = sample_pcm[..., :max_duration_len] |
| 105 | + # print("resampled pcm", sample_pcm.shape, sample_sr) |
| 106 | + sample_pcm = sample_pcm[None].to(device=args.device) |
| 107 | + # |
| 108 | + # sample_pcm = torch.ones(1,1,240000) |
| 109 | + |
| 110 | + print("streaming encoding...") |
| 111 | + start_time = time.time() |
| 112 | + all_codes = [] |
| 113 | + def run_loop(): |
| 114 | + for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): |
| 115 | + end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) |
| 116 | + chunk = sample_pcm[..., start_idx:end_idx] |
| 117 | + codes = mimi.encode(chunk) |
| 118 | + if codes.shape[-1]: |
| 119 | + print(start_idx, codes.shape, end="\r") |
| 120 | + all_codes.append(codes) |
| 121 | + |
| 122 | + if args.profile: |
| 123 | + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: |
| 124 | + run_loop() |
| 125 | + prof.export_chrome_trace("trace.json") |
| 126 | + else: |
| 127 | + run_loop() |
| 128 | + all_codes_th = torch.cat(all_codes, dim=-1) |
| 129 | + print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s") |
| 130 | + print("streaming decoding...") |
| 131 | + all_pcms = [] |
| 132 | + with mimi.streaming(1): |
| 133 | + for i in range(all_codes_th.shape[-1]): |
| 134 | + codes = all_codes_th[..., i : i + 1] |
| 135 | + pcm = mimi.decode(codes) |
| 136 | + print(i, pcm.shape, end="\r") |
| 137 | + all_pcms.append(pcm) |
| 138 | + all_pcms = torch.cat(all_pcms, dim=-1) |
| 139 | + # print("pcm", all_pcms.shape, all_pcms.dtype) |
| 140 | + # sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate) |
| 141 | + pcm_ref = mimi.decode(all_codes_th) |
| 142 | + assert(torch.allclose(pcm_ref,all_pcms, atol=1e-5)) |
| 143 | + |
| 144 | + import matplotlib.pyplot as plt |
| 145 | + |
| 146 | + # Create a 1D tensor |
| 147 | + |
| 148 | + # Convert tensor to numpy for plotting |
| 149 | + diff = all_pcms - pcm_ref |
| 150 | + diff = diff.squeeze() |
| 151 | + length = diff.size(-1) |
| 152 | + x = torch.arange(0, length) |
| 153 | + plt.plot(x.numpy(), diff.numpy(), label="diff") |
| 154 | + plt.xlabel("x") |
| 155 | + plt.ylabel("y") |
| 156 | + plt.title("Line Plot of Tensor") |
| 157 | + plt.legend() |
| 158 | + plt.show(block=True) |
| 159 | + class MimiDecode(nn.Module): |
| 160 | + def __init__(self, mimi: nn.Module): |
| 161 | + super().__init__() |
| 162 | + self.mimi_model = mimi |
| 163 | + |
| 164 | + def forward(self, x): |
| 165 | + return self.mimi_model.decode(x) |
| 166 | + |
| 167 | + mimi_decode = MimiDecode(mimi) |
| 168 | + |
| 169 | + ep: ExportedProgram = torch.export.export(mimi_decode, (all_codes_th,), strict=False) |
| 170 | + edge_prog = to_edge_transform_and_lower( |
| 171 | + ep, |
| 172 | + partitioner=[XnnpackPartitioner()], |
| 173 | + ) |
| 174 | + class MimiEncode(nn.Module): |
| 175 | + def __init__(self, mimi: nn.Module): |
| 176 | + super().__init__() |
| 177 | + self.mimi_model = mimi |
| 178 | + |
| 179 | + def forward(self, x): |
| 180 | + return self.mimi_model.encode(x) |
| 181 | + |
| 182 | + mimi_encode = MimiEncode(mimi) |
| 183 | + chunk = sample_pcm[..., 0:pcm_chunk_size] |
| 184 | + out = mimi_encode(chunk) |
| 185 | + exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module() |
| 186 | + |
| 187 | +with torch.no_grad(): |
| 188 | + mimi_test(mimi) |
| 189 | + |
0 commit comments