Skip to content

Commit 1fe92ee

Browse files
author
Martin Yuan
committed
Export Mimi model to ExecuTorch
1 parent afcec1d commit 1fe92ee

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# Trace every command so the install log shows which step is running.
set -x

# Python packages needed by the Mimi example.
pip install -U moshi
pip install bitsandbytes

# Reuse the llama example's requirements script for the torchao deps.
# Resolve this script's own directory first so the relative path works
# regardless of the caller's working directory.
SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
bash "${SCRIPT_DIR}/../llama/install_requirements.sh"
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright (c) Kyutai, all rights reserved.
2+
# This source code is licensed under the license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
import argparse
6+
import random
7+
import time
8+
9+
from huggingface_hub import hf_hub_download
10+
import numpy as np
11+
import sphn
12+
import torch
13+
from torch.profiler import profile, ProfilerActivity
14+
15+
from moshi.models import loaders
16+
17+
import torch.nn as nn
18+
19+
from executorch.examples.models.llama.llama_transformer import Transformer
20+
21+
from executorch.examples.models.llama.model_args import ModelArgs
22+
23+
from torch.export import export, export_for_training, ExportedProgram
24+
25+
from executorch.exir import (
26+
EdgeCompileConfig,
27+
ExecutorchBackendConfig,
28+
to_edge_transform_and_lower,
29+
)
30+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
31+
32+
import torchaudio
33+
import requests
34+
import io
35+
36+
from pydub import AudioSegment
37+
import numpy as np
38+
39+
import torchaudio
40+
import requests
41+
import io
42+
43+
44+
def read_mp3_from_url(url, timeout=30.0):
    """Download an MP3 over HTTP and decode it into a NumPy waveform.

    Args:
        url: Address of the MP3 file to fetch.
        timeout: Seconds to wait on the server, forwarded to
            ``requests.get``. Pass ``None`` to wait indefinitely, which is
            what this function did before the parameter existed.

    Returns:
        A ``(waveform, sample_rate)`` tuple: the audio returned by
        ``torchaudio.load`` converted to a NumPy array, and the sample rate
        that ``torchaudio.load`` reported.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Without a timeout a stalled connection would hang the script forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Ensure request is successful

    # torchaudio reads from a file-like object, so wrap the downloaded bytes.
    with io.BytesIO(response.content) as audio_stream:
        waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")

    return waveform.numpy(), sample_rate
55+
56+
# Read the MP3 file
57+
58+
# Command-line options. ``--mimi-weight`` has no default (None) and the
# default device is CUDA whenever torch reports at least one CUDA device.
_default_device = "cuda" if torch.cuda.device_count() else "cpu"

parser = argparse.ArgumentParser()
parser.add_argument("--mimi-weight", type=str)
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
parser.add_argument("--device", type=str, default=_default_device)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
66+
67+
68+
def seed_all(seed):
    """Seed the Python, NumPy and torch RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # CUDA generators are only touched when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups

    # Turn on deterministic mode and turn off benchmark mode in cuDNN.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
77+
78+
79+
# Fix every RNG seed before the model is loaded.
seed_all(42424242)

print("loading mimi")
# No explicit --mimi-weight: download the checkpoint from the Hugging Face
# repo given by --hf-repo.
if args.mimi_weight is None:
    args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
mimi = loaders.get_mimi(args.mimi_weight, args.device)
print("mimi loaded")
87+
# emb = torch.load('emb.pt')
88+
89+
def mimi_test(mimi, max_duration_sec=10.0):
    """Round-trip a sample clip through Mimi, then export decode and encode.

    Steps:
      1. Download a reference MP3 and keep at most ``max_duration_sec``
         seconds of it.
      2. Encode it chunk by chunk, each chunk holding
         ``int(mimi.sample_rate / mimi.frame_rate)`` samples. The loop runs
         under the torch profiler when ``--profile`` was passed.
      3. Decode the codes one step at a time inside ``mimi.streaming(1)``
         and assert the result matches a single decode of all codes.
      4. Plot the element-wise difference between the two decodes.
      5. Export a decode wrapper with ``torch.export`` and lower it with the
         XNNPACK partitioner; run and export an encode wrapper.

    Reads the module-level ``args`` for ``device`` and ``profile``.

    Args:
        mimi: Model exposing ``sample_rate``, ``frame_rate``, ``encode``,
            ``decode`` and ``streaming``.
        max_duration_sec: Upper bound, in seconds, on the audio processed.

    Raises:
        AssertionError: If streaming decode and one-shot decode differ by
            more than ``atol=1e-5``.
    """
    url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
    sample_pcm, sample_sr = read_mp3_from_url(url)
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    sample_rate = mimi.sample_rate
    # The clip is not resampled here, so surface a rate mismatch instead of
    # silently feeding audio at the wrong rate.
    if sample_sr != sample_rate:
        print(
            f"warning: sample rate {sample_sr} differs from model rate "
            f"{sample_rate}; audio is not resampled"
        )

    sample_pcm = torch.tensor(sample_pcm, device=args.device)
    max_duration_len = int(sample_rate * max_duration_sec)
    if sample_pcm.shape[-1] > max_duration_len:
        sample_pcm = sample_pcm[..., :max_duration_len]
    # Add a leading dimension (presumably the batch axis - confirm against
    # the shape mimi.encode expects).
    sample_pcm = sample_pcm[None].to(device=args.device)

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []

    def run_loop():
        # Feed the audio one chunk at a time and keep every non-empty
        # batch of codes the encoder returns.
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = mimi.encode(chunk)
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    if args.profile:
        # Only request CUDA activity when a CUDA device is actually present.
        activities = [ProfilerActivity.CPU]
        if torch.cuda.is_available():
            activities.append(ProfilerActivity.CUDA)
        with profile(activities=activities) as prof:
            run_loop()
        prof.export_chrome_trace("trace.json")
    else:
        run_loop()
    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")

    print("streaming decoding...")
    all_pcms = []
    with mimi.streaming(1):
        for i in range(all_codes_th.shape[-1]):
            codes = all_codes_th[..., i : i + 1]
            pcm = mimi.decode(codes)
            print(i, pcm.shape, end="\r")
            all_pcms.append(pcm)
    all_pcms = torch.cat(all_pcms, dim=-1)

    # Reference: decode every code in one call, outside streaming mode.
    pcm_ref = mimi.decode(all_codes_th)
    assert torch.allclose(pcm_ref, all_pcms, atol=1e-5)

    import matplotlib.pyplot as plt

    # Tensor.numpy() only works on CPU tensors, so move the difference off
    # the accelerator before handing it to matplotlib.
    diff = (all_pcms - pcm_ref).squeeze().detach().cpu()
    length = diff.size(-1)
    x = torch.arange(0, length)
    plt.plot(x.numpy(), diff.numpy(), label="diff")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("Line Plot of Tensor")
    plt.legend()
    plt.show(block=True)  # execution pauses here until the window is closed

    class MimiDecode(nn.Module):
        """Wrapper whose ``forward`` is ``mimi.decode``, for torch.export."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    mimi_decode = MimiDecode(mimi)

    # Export the decoder on the full code tensor and lower it with XNNPACK.
    # The lowered program is not used further inside this function.
    ep: ExportedProgram = torch.export.export(mimi_decode, (all_codes_th,), strict=False)
    edge_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
    )

    class MimiEncode(nn.Module):
        """Wrapper whose ``forward`` is ``mimi.encode``, for torch.export."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode = MimiEncode(mimi)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    # Run the wrapper eagerly once on the first chunk before exporting it.
    mimi_encode(chunk)
    exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module()
186+
187+
# Script entry point: run the round-trip and export with autograd disabled.
with torch.no_grad():
    mimi_test(mimi)
189+

0 commit comments

Comments
 (0)