Skip to content

Commit 27bacff

Browse files
authored
Export Mimi model to ExecuTorch
Differential Revision: D71039057 Pull Request resolved: #8753
1 parent b5d8e3b commit 27bacff

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -x
9+
10+
pip install -U moshi
11+
pip install bitsandbytes
12+
# Run llama2/install requirements for torchao deps
13+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
14+
15+
bash "$SCRIPT_DIR"/../llama/install_requirements.sh
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import io
2+
import os
3+
import random
4+
import unittest
5+
6+
import numpy as np
7+
import requests
8+
import torch
9+
import torch.nn as nn
10+
import torchaudio
11+
12+
from huggingface_hub import hf_hub_download
13+
from moshi.models import loaders
14+
from torch.export import export, ExportedProgram
15+
16+
17+
def read_mp3_from_url(url):
18+
response = requests.get(url)
19+
response.raise_for_status() # Ensure request is successful
20+
audio_stream = io.BytesIO(response.content)
21+
waveform, sample_rate = torchaudio.load(audio_stream, format="mp3")
22+
return waveform.numpy(), sample_rate
23+
24+
25+
class TestMimiModel(unittest.TestCase):
26+
@classmethod
27+
def setUpClass(cls):
28+
"""Setup once for all tests: Load model and prepare test data."""
29+
30+
# Get environment variables (if set), otherwise use default values
31+
mimi_weight = os.getenv("MIMI_WEIGHT", None)
32+
hf_repo = os.getenv("HF_REPO", loaders.DEFAULT_REPO)
33+
device = "cuda" if torch.cuda.device_count() else "cpu"
34+
35+
def seed_all(seed):
36+
torch.manual_seed(seed)
37+
if torch.cuda.is_available():
38+
torch.cuda.manual_seed(seed)
39+
torch.cuda.manual_seed_all(seed)
40+
random.seed(seed)
41+
np.random.seed(seed)
42+
torch.backends.cudnn.deterministic = True
43+
torch.backends.cudnn.benchmark = False
44+
45+
seed_all(42424242)
46+
47+
if mimi_weight is None:
48+
mimi_weight = hf_hub_download(hf_repo, loaders.MIMI_NAME)
49+
cls.mimi = loaders.get_mimi(mimi_weight, device)
50+
cls.device = device
51+
cls.sample_pcm, cls.sample_sr = read_mp3_from_url(
52+
"https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
53+
)
54+
55+
def test_mp3_loading(self):
56+
"""Ensure MP3 file loads correctly."""
57+
self.assertIsInstance(self.sample_pcm, np.ndarray)
58+
self.assertGreater(self.sample_sr, 0)
59+
60+
def test_encoding(self):
61+
"""Ensure encoding produces expected tensor shape."""
62+
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
63+
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)
64+
sample_pcm = sample_pcm[None]
65+
chunk = sample_pcm[..., 0:pcm_chunk_size]
66+
encoded = self.mimi.encode(chunk)
67+
self.assertIsInstance(encoded, torch.Tensor)
68+
self.assertGreater(encoded.shape[-1], 0)
69+
70+
def test_decoding(self):
71+
"""Ensure decoding produces expected output."""
72+
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
73+
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
74+
chunk = sample_pcm[..., 0:pcm_chunk_size]
75+
encoded = self.mimi.encode(chunk)
76+
decoded = self.mimi.decode(encoded)
77+
self.assertIsInstance(decoded, torch.Tensor)
78+
79+
def test_streaming_encoding_decoding(self):
80+
"""Test streaming encoding and decoding consistency."""
81+
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
82+
sample_rate = self.mimi.sample_rate
83+
max_duration_sec = 10.0
84+
max_duration_len = int(sample_rate * max_duration_sec)
85+
86+
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)
87+
if sample_pcm.shape[-1] > max_duration_len:
88+
sample_pcm = sample_pcm[..., :max_duration_len]
89+
sample_pcm = sample_pcm[None].to(device=self.device)
90+
91+
all_codes = []
92+
for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
93+
end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
94+
chunk = sample_pcm[..., start_idx:end_idx]
95+
codes = self.mimi.encode(chunk)
96+
if codes.shape[-1]:
97+
all_codes.append(codes)
98+
99+
all_codes_th = torch.cat(all_codes, dim=-1)
100+
101+
all_pcms = []
102+
with self.mimi.streaming(1):
103+
for i in range(all_codes_th.shape[-1]):
104+
codes = all_codes_th[..., i : i + 1]
105+
pcm = self.mimi.decode(codes)
106+
all_pcms.append(pcm)
107+
all_pcms = torch.cat(all_pcms, dim=-1)
108+
109+
pcm_ref = self.mimi.decode(all_codes_th)
110+
self.assertTrue(torch.allclose(pcm_ref, all_pcms, atol=1e-5))
111+
112+
def test_exported_decoding(self):
113+
"""Ensure exported decoding model is consistent with reference output."""
114+
115+
class MimiDecode(nn.Module):
116+
def __init__(self, mimi: nn.Module):
117+
super().__init__()
118+
self.mimi_model = mimi
119+
120+
def forward(self, x):
121+
return self.mimi_model.decode(x)
122+
123+
sample_pcm = torch.tensor(self.sample_pcm, device=self.device)[None]
124+
pcm_chunk_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
125+
chunk = sample_pcm[..., 0:pcm_chunk_size]
126+
input = self.mimi.encode(chunk)
127+
128+
mimi_decode = MimiDecode(self.mimi)
129+
ref_decode_output = mimi_decode(input)
130+
exported_decode: ExportedProgram = export(mimi_decode, (input,), strict=False)
131+
ep_decode_output = exported_decode.module()(input)
132+
self.assertTrue(torch.allclose(ep_decode_output, ref_decode_output, atol=1e-6))
133+
134+
def test_exported_encoding(self):
135+
"""Ensure exported encoding model is consistent with reference output."""
136+
137+
class MimiEncode(nn.Module):
138+
def __init__(self, mimi: nn.Module):
139+
super().__init__()
140+
self.mimi_model = mimi
141+
142+
def forward(self, x):
143+
return self.mimi_model.encode(x)
144+
145+
mimi_encode = MimiEncode(self.mimi)
146+
chunk = torch.tensor(self.sample_pcm, device=self.device)[None][
147+
..., 0 : int(self.mimi.sample_rate / self.mimi.frame_rate)
148+
]
149+
ref_encode_output = mimi_encode(chunk)
150+
exported_encode = export(mimi_encode, (chunk,), strict=False)
151+
ep_encode_output = exported_encode.module()(chunk)
152+
self.assertTrue(torch.allclose(ep_encode_output, ref_encode_output, atol=1e-6))
153+
154+
155+
if __name__ == "__main__":
156+
unittest.main()

0 commit comments

Comments
 (0)