Skip to content

Commit 78e8390

Browse files
committed
chore: done
1 parent abdc6d2 commit 78e8390

File tree

5 files changed

+47
-11
lines changed

5 files changed

+47
-11
lines changed

.cargo/config.toml

+1
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@
22
[build]
33
rustflags = [ "--cfg=web_sys_unstable_apis" ]
44
rustdocflags = [ "--cfg=web_sys_unstable_apis" ]
5+
#target = "wasm32-unknown-unknown"

crates/ratchet-models/src/whisper/mod.rs

+1
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ pub mod tokenizer;
1515
pub mod transcribe;
1616
pub mod transcript;
1717

18+
pub use config::Config;
1819
pub use decoder::WhisperDecoder;
1920
pub use encoder::WhisperEncoder;
2021
pub use model::Whisper;

crates/ratchet-models/src/whisper/model.rs

+32-2
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ use ratchet::NDArrayExt;
1111
use hf_hub::api::sync::Api;
1212

1313
#[cfg(target_arch = "wasm32")]
14-
use {ratchet_hub::ApiBuilder, ratchet_hub::RepoType, wasm_bindgen::JsError};
14+
use {js_sys::Uint8Array, ratchet_hub::ApiBuilder, ratchet_hub::RepoType, wasm_bindgen::JsError};
1515

1616
use crate::registry::WhisperVariants;
1717
use crate::whisper::{options::Language, task::DecodingTask, tokenizer::WhisperTokenizer};
@@ -30,6 +30,7 @@ pub struct Whisper {
3030
}
3131

3232
impl Whisper {
33+
#[cfg(not(target_arch = "wasm32"))]
3334
pub fn load<R: BufRead + Seek>(
3435
header: Header,
3536
reader: &mut R,
@@ -57,14 +58,43 @@ impl Whisper {
5758
})
5859
}
5960

61+
#[cfg(target_arch = "wasm32")]
62+
pub async fn load<R: BufRead + Seek>(
63+
header: Header,
64+
reader: &mut R,
65+
device: Device,
66+
) -> Result<Self, JsError> {
67+
let mel_bytes = Self::fetch_resource(WhisperVariants::Tiny, "melfilters.bytes").await?;
68+
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
69+
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(
70+
&mel_bytes,
71+
&mut mel_filters,
72+
);
73+
let specgen = SpectrogramGenerator::new(mel_filters);
74+
75+
let config: Config = serde_json::from_slice(
76+
&Self::fetch_resource(WhisperVariants::Tiny, "config.json").await?,
77+
)?;
78+
let encoder = WhisperEncoder::load(&header, &config, reader, &device).unwrap();
79+
let decoder = WhisperDecoder::load(&header, &config, reader, &device).unwrap();
80+
81+
Ok(Self {
82+
specgen,
83+
encoder,
84+
decoder,
85+
config,
86+
device,
87+
})
88+
}
89+
6090
#[cfg(target_arch = "wasm32")]
6191
pub async fn fetch_resource(
6292
variant: WhisperVariants,
6393
resource: &str,
6494
) -> Result<Vec<u8>, JsError> {
6595
let repo_id = variant.repo_id();
6696
let model_repo = ApiBuilder::from_hf(repo_id, RepoType::Model).build();
67-
model_repo.get(resource).await?
97+
Ok(model_repo.get(resource).await?.to_vec())
6898
}
6999

70100
#[cfg(not(target_arch = "wasm32"))]

crates/ratchet-models/src/whisper/transcribe.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ pub async fn transcribe(
9090
callback: Option<impl Fn(StreamedSegment)>,
9191
) -> anyhow::Result<TranscriptionResult> {
9292
let runtime = Instant::now();
93-
let n_mels = model.hparams.n_mels as usize;
93+
let n_mels = model.config.n_mels as usize;
9494
let mel = model.specgen.generate(audio)?.to(&model.device).await?;
9595
let content_frames = mel.shape()[mel.rank() - 1] - N_FRAMES;
9696

crates/ratchet-models/tests/whisper.rs

+12-8
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,8 @@ use ndarray::{s, Axis};
55
use ndarray_stats::QuantileExt;
66
use ratchet::{shape, Device, DeviceRequest, Tensor};
77
use ratchet_hub::{ApiBuilder, RepoType};
8-
use ratchet_models::whisper::{Whisper, WhisperDecoder, WhisperEncoder};
8+
use ratchet_loader::gguf::gguf;
9+
use ratchet_models::whisper::{Config, Whisper, WhisperDecoder, WhisperEncoder};
910
use ratchet_nn::Module;
1011
use std::path::PathBuf;
1112
use wasm_bindgen::prelude::*;
@@ -21,22 +22,24 @@ fn log_init() {
2122
async fn tiny_encoder() -> Result<(), JsValue> {
2223
log_init();
2324
let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build();
24-
let model_data = model_repo.get("tiny_f32.bin").await?;
25+
let model_data = model_repo.get("tiny_f32.gguf").await?;
26+
let config_data = model_repo.get("config.json").await?;
2527

2628
let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build();
2729
let input_npy = ground_repo.get("jfk_tiny_encoder_input.npy").await?;
2830
let ground_npy = ground_repo.get("jfk_tiny_encoder_hs.npy").await?;
2931

3032
let mut reader = std::io::BufReader::new(std::io::Cursor::new(model_data.to_vec()));
31-
let gg = Whisper::load_ggml(&mut reader).unwrap();
33+
let header = gguf::Header::read(&mut reader).unwrap();
34+
let config: Config = serde_json::from_slice(&config_data.to_vec()).unwrap();
3235

3336
let device = Device::request_device(DeviceRequest::GPU).await.unwrap();
3437

3538
let input_data = &input_npy.to_vec();
3639
let input = Tensor::from_npy_bytes::<f32>(input_data, &device).unwrap();
3740
let ground = Tensor::from_npy_bytes::<f32>(&ground_npy.to_vec(), &Device::CPU).unwrap();
3841

39-
let encoder = WhisperEncoder::load(&gg, &mut reader, &device).unwrap();
42+
let encoder = WhisperEncoder::load(&header, &config, &mut reader, &device).unwrap();
4043
let result = encoder.schedule(input).unwrap().resolve().unwrap();
4144
let ours = result.to(&Device::CPU).await.unwrap();
4245
ground.all_close(&ours, 1e-3, 1e-3).unwrap();
@@ -46,18 +49,19 @@ async fn tiny_encoder() -> Result<(), JsValue> {
4649
#[wasm_bindgen_test]
4750
async fn tiny_decoder() -> Result<(), JsValue> {
4851
let model_repo = ApiBuilder::from_hf("FL33TW00D-HF/whisper-tiny", RepoType::Model).build();
49-
let model_data = model_repo.get("tiny_f32.bin").await?;
52+
let model_data = model_repo.get("tiny_f32.gguf").await?;
53+
let config_data = model_repo.get("config.json").await?;
5054

5155
let ground_repo = ApiBuilder::from_hf("FL33TW00D-HF/ratchet-util", RepoType::Dataset).build();
5256
let hs_data = ground_repo.get("jfk_tiny_encoder_hs.npy").await?;
5357

5458
let mut reader = std::io::BufReader::new(std::io::Cursor::new(model_data.to_vec()));
55-
let gg_disk = Whisper::load_ggml(&mut reader).unwrap();
56-
assert_eq!(gg_disk.tensors.len(), 167);
59+
let header = gguf::Header::read(&mut reader).unwrap();
60+
let config: Config = serde_json::from_slice(&config_data.to_vec()).unwrap();
5761

5862
let device = Device::request_device(DeviceRequest::GPU).await.unwrap();
5963
let audio_ctx = Tensor::from_npy_bytes::<f32>(&hs_data.to_vec(), &device).unwrap();
60-
let mut decoder = WhisperDecoder::load(&gg_disk, &mut reader, &device).unwrap();
64+
let mut decoder = WhisperDecoder::load(&header, &config, &mut reader, &device).unwrap();
6165

6266
let mut tokens = vec![50258, 50259, 50359];
6367
let mut all_tokens = tokens.clone();

0 commit comments

Comments
 (0)