Commit 2657110

Author: Martin Yuan
Export Mimi model to ExecuTorch
1 parent fae734f

File tree: 4 files changed (+333, −1 lines)

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -x

pip install -U moshi
pip install bitsandbytes
# Run the llama example's install_requirements.sh for torchao deps
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

bash "$SCRIPT_DIR"/../llama/install_requirements.sh
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import argparse
import io
import random
import time

import numpy as np
import requests
import torch
import torch.nn as nn
import torchaudio

from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.export import export, ExportedProgram


def read_mp3_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # Ensure the request was successful

    # Wrap the payload in a file-like object
    audio_stream = io.BytesIO(response.content)

    # Load the audio with torchaudio
    waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")

    return waveform.numpy(), sample_rate


parser = argparse.ArgumentParser()
parser.add_argument("--mimi-weight", type=str)
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
parser.add_argument(
    "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu"
)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()


def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(42424242)

print("loading mimi")
if args.mimi_weight is None:
    args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
mimi = loaders.get_mimi(args.mimi_weight, args.device)
print("mimi loaded")


def mimi_test(mimi, max_duration_sec=10.0):
    url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
    sample_pcm, sample_sr = read_mp3_from_url(url)
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    sample_rate = mimi.sample_rate
    sample_pcm = torch.tensor(sample_pcm, device=args.device)
    max_duration_len = int(sample_rate * max_duration_sec)
    if sample_pcm.shape[-1] > max_duration_len:
        sample_pcm = sample_pcm[..., :max_duration_len]
    sample_pcm = sample_pcm[None].to(device=args.device)

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []

    def run_loop():
        # Encode the waveform one frame-sized chunk at a time.
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = mimi.encode(chunk)
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    run_loop()

    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
    print("streaming decoding...")
    all_pcms = []
    with mimi.streaming(1):
        for i in range(all_codes_th.shape[-1]):
            codes = all_codes_th[..., i : i + 1]
            pcm = mimi.decode(codes)
            print(i, pcm.shape, end="\r")
            all_pcms.append(pcm)
    all_pcms = torch.cat(all_pcms, dim=-1)

    # Streaming decode must match a one-shot decode of all the codes.
    pcm_ref = mimi.decode(all_codes_th)
    assert torch.allclose(pcm_ref, all_pcms, atol=1e-5)

    class MimiDecode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    mimi_decode = MimiDecode(mimi)

    decode_input = all_codes_th[..., 0:1]
    ref_decode_output = mimi_decode(decode_input)

    exported_decode: ExportedProgram = export(
        mimi_decode, (decode_input,), strict=False
    )
    print(f"Exported program: {exported_decode}")
    ep_decode_output = exported_decode.module()(decode_input)
    print(f"ep output: {ep_decode_output}")
    assert torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6)

    # Next step, not wired up yet: lower the exported program to ExecuTorch, e.g.
    # edge_prog = to_edge_transform_and_lower(
    #     exported_decode,
    #     partitioner=[XnnpackPartitioner()],
    # )
    # print(f"Edge program: {edge_prog}")

    class MimiEncode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode = MimiEncode(mimi)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    ref_encode_out = mimi_encode(chunk)
    exported_encode = export(mimi_encode, (chunk,), strict=False)
    print(f"Exported encode program: {exported_encode}")
    ep_encode_output = exported_encode.module()(chunk)
    print(f"ep output: {ep_encode_output}")
    assert torch.allclose(ep_encode_output, ref_encode_out, atol=1e-6)


with torch.no_grad():
    mimi_test(mimi)
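The commented-out block in the script sketches the remaining lowering step. Below is a hedged completion of that sketch, assuming the ExecuTorch XNNPACK flow: to_edge_transform_and_lower and XnnpackPartitioner are real ExecuTorch APIs (the commit's comment names them), but the output file name mimi_decode.pte and this exact wiring are illustrative, not part of the commit:

from executorch.backends.xnnpack.partitioner.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import to_edge_transform_and_lower

# Lower the exported decoder, delegating supported subgraphs to XNNPACK.
edge_prog = to_edge_transform_and_lower(
    exported_decode,
    partitioner=[XnnpackPartitioner()],
)
exec_prog = edge_prog.to_executorch()

# Serialize a .pte file loadable by the ExecuTorch runtime
# (the file name is illustrative).
with open("mimi_decode.pte", "wb") as f:
    f.write(exec_prog.buffer)

The same flow would apply to exported_encode.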
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import io
import os
import random
import unittest

import numpy as np
import requests
import torch
import torch.nn as nn
import torchaudio

from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.export import export, ExportedProgram


def read_mp3_from_url(url):
    response = requests.get(url)
    response.raise_for_status()  # Ensure the request was successful
    audio_stream = io.BytesIO(response.content)
    waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")
    return waveform.numpy(), sample_rate


class TestMimiModel(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Set up once for all tests: load the model and prepare test data."""

        # Use environment variables if set, otherwise fall back to defaults.
        mimi_weight = os.getenv("MIMI_WEIGHT", None)
        hf_repo = os.getenv("HF_REPO", loaders.DEFAULT_REPO)
        device = "cuda" if torch.cuda.device_count() else "cpu"

        def seed_all(seed):
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
            random.seed(seed)
            np.random.seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_all(42424242)

        if mimi_weight is None:
            mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)
        cls.mimi = loaders.get_mimi(mimi_weight, device)
        cls.device = device
        cls.sample_pcm, cls.sample_sr = read_mp3_from_url(
            "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
        )

    def test_mp3_loading(self):
        """Ensure the MP3 file loads correctly."""
        self.assertIsInstance(self.sample_pcm, np.ndarray)
        self.assertGreater(self.sample_sr, 0)

    def test_encoding(self):
        """Ensure encoding produces a tensor with the expected shape."""
        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)
        sample_pcm = sample_pcm[None]
        chunk = sample_pcm[..., 0:pcm_chunk_size]
        encoded = self.mimi.encode(chunk)
        self.assertIsInstance(encoded, torch.Tensor)
        self.assertGreater(encoded.shape[-1], 0)

    def test_decoding(self):
        """Ensure decoding produces the expected output."""
        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
        chunk = sample_pcm[..., 0:pcm_chunk_size]
        encoded = self.mimi.encode(chunk)
        decoded = self.mimi.decode(encoded)
        self.assertIsInstance(decoded, torch.Tensor)

    def test_streaming_encoding_decoding(self):
        """Test that streaming decoding matches non-streaming decoding."""
        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        sample_rate = self.mimi.sample_rate
        max_duration_sec = 10.0
        max_duration_len = int(sample_rate * max_duration_sec)

        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)
        if sample_pcm.shape[-1] > max_duration_len:
            sample_pcm = sample_pcm[..., :max_duration_len]
        sample_pcm = sample_pcm[None].to(device=self.device)

        all_codes = []
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = self.mimi.encode(chunk)
            if codes.shape[-1]:
                all_codes.append(codes)

        all_codes_th = torch.cat(all_codes, dim=-1)

        all_pcms = []
        with self.mimi.streaming(1):
            for i in range(all_codes_th.shape[-1]):
                codes = all_codes_th[..., i : i + 1]
                pcm = self.mimi.decode(codes)
                all_pcms.append(pcm)
        all_pcms = torch.cat(all_pcms, dim=-1)

        pcm_ref = self.mimi.decode(all_codes_th)
        self.assertTrue(torch.allclose(pcm_ref, all_pcms, atol=1e-5))

    def test_exported_decoding(self):
        """Ensure the exported decoder matches the eager reference output."""

        class MimiDecode(nn.Module):
            def __init__(self, mimi: nn.Module):
                super().__init__()
                self.mimi_model = mimi

            def forward(self, x):
                return self.mimi_model.decode(x)

        sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
        pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        chunk = sample_pcm[..., 0:pcm_chunk_size]
        codes = self.mimi.encode(chunk)

        mimi_decode = MimiDecode(self.mimi)
        ref_decode_output = mimi_decode(codes)
        exported_decode: ExportedProgram = export(mimi_decode, (codes,), strict=False)
        ep_decode_output = exported_decode.module()(codes)
        self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))

    def test_exported_encoding(self):
        """Ensure the exported encoder matches the eager reference output."""

        class MimiEncode(nn.Module):
            def __init__(self, mimi: nn.Module):
                super().__init__()
                self.mimi_model = mimi

            def forward(self, x):
                return self.mimi_model.encode(x)

        mimi_encode = MimiEncode(self.mimi)
        chunk = torch.tensor(self.sample_pcm, device=self.device)[None][
            ..., 0 : int(self.mimi.sample_rate / self.mimi.frame_rate)
        ]
        ref_encode_output = mimi_encode(chunk)
        exported_encode = export(mimi_encode, (chunk,), strict=False)
        ep_encode_output = exported_encode.module()(chunk)
        self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))


if __name__ == "__main__":
    unittest.main()
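For context on the pcm_chunk_size computation used throughout these tests, here is a worked example; the 24 kHz / 12.5 Hz figures are the Mimi constants reported in the Moshi paper, assumed here rather than read from the checkpoint:

# Assumed Mimi constants: 24 kHz sample rate, 12.5 Hz frame rate.
sample_rate = 24_000
frame_rate = 12.5
pcm_chunk_size = int(sample_rate / frame_rate)
print(pcm_chunk_size)  # 1920 PCM samples map to exactly one code frame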

install_requirements.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def python_is_compatible():
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION = "dev20250301"
+NIGHTLY_VERSION = "dev20250311"


 def install_requirements(use_pytorch_nightly):
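For illustration, a dated pin like this typically feeds pip version specifiers elsewhere in the script; the following is a hypothetical sketch of that pattern, where the package name, version prefix, and index URL are assumptions rather than values read from this file:

NIGHTLY_VERSION = "dev20250311"

# Hypothetical: turn the dated pin into nightly pip specifiers.
packages = [f"torch==2.7.0.{NIGHTLY_VERSION}"]  # version prefix illustrative
index = "https://download.pytorch.org/whl/nightly/cpu"
print("pip install", " ".join(packages), "--extra-index-url", index)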
