@@ -5,7 +5,8 @@ use ndarray::{s, Axis};
5
5
use ndarray_stats:: QuantileExt ;
6
6
use ratchet:: { shape, Device , DeviceRequest , Tensor } ;
7
7
use ratchet_hub:: { ApiBuilder , RepoType } ;
8
- use ratchet_models:: whisper:: { Whisper , WhisperDecoder , WhisperEncoder } ;
8
+ use ratchet_loader:: gguf:: gguf;
9
+ use ratchet_models:: whisper:: { Config , Whisper , WhisperDecoder , WhisperEncoder } ;
9
10
use ratchet_nn:: Module ;
10
11
use std:: path:: PathBuf ;
11
12
use wasm_bindgen:: prelude:: * ;
@@ -21,22 +22,24 @@ fn log_init() {
21
22
async fn tiny_encoder ( ) -> Result < ( ) , JsValue > {
22
23
log_init ( ) ;
23
24
let model_repo = ApiBuilder :: from_hf ( "FL33TW00D-HF/whisper-tiny" , RepoType :: Model ) . build ( ) ;
24
- let model_data = model_repo. get ( "tiny_f32.bin" ) . await ?;
25
+ let model_data = model_repo. get ( "tiny_f32.gguf" ) . await ?;
26
+ let config_data = model_repo. get ( "config.json" ) . await ?;
25
27
26
28
let ground_repo = ApiBuilder :: from_hf ( "FL33TW00D-HF/ratchet-util" , RepoType :: Dataset ) . build ( ) ;
27
29
let input_npy = ground_repo. get ( "jfk_tiny_encoder_input.npy" ) . await ?;
28
30
let ground_npy = ground_repo. get ( "jfk_tiny_encoder_hs.npy" ) . await ?;
29
31
30
32
let mut reader = std:: io:: BufReader :: new ( std:: io:: Cursor :: new ( model_data. to_vec ( ) ) ) ;
31
- let gg = Whisper :: load_ggml ( & mut reader) . unwrap ( ) ;
33
+ let header = gguf:: Header :: read ( & mut reader) . unwrap ( ) ;
34
+ let config: Config = serde_json:: from_slice ( & config_data. to_vec ( ) ) . unwrap ( ) ;
32
35
33
36
let device = Device :: request_device ( DeviceRequest :: GPU ) . await . unwrap ( ) ;
34
37
35
38
let input_data = & input_npy. to_vec ( ) ;
36
39
let input = Tensor :: from_npy_bytes :: < f32 > ( input_data, & device) . unwrap ( ) ;
37
40
let ground = Tensor :: from_npy_bytes :: < f32 > ( & ground_npy. to_vec ( ) , & Device :: CPU ) . unwrap ( ) ;
38
41
39
- let encoder = WhisperEncoder :: load ( & gg , & mut reader, & device) . unwrap ( ) ;
42
+ let encoder = WhisperEncoder :: load ( & header , & config , & mut reader, & device) . unwrap ( ) ;
40
43
let result = encoder. schedule ( input) . unwrap ( ) . resolve ( ) . unwrap ( ) ;
41
44
let ours = result. to ( & Device :: CPU ) . await . unwrap ( ) ;
42
45
ground. all_close ( & ours, 1e-3 , 1e-3 ) . unwrap ( ) ;
@@ -46,18 +49,19 @@ async fn tiny_encoder() -> Result<(), JsValue> {
46
49
#[ wasm_bindgen_test]
47
50
async fn tiny_decoder ( ) -> Result < ( ) , JsValue > {
48
51
let model_repo = ApiBuilder :: from_hf ( "FL33TW00D-HF/whisper-tiny" , RepoType :: Model ) . build ( ) ;
49
- let model_data = model_repo. get ( "tiny_f32.bin" ) . await ?;
52
+ let model_data = model_repo. get ( "tiny_f32.gguf" ) . await ?;
53
+ let config_data = model_repo. get ( "config.json" ) . await ?;
50
54
51
55
let ground_repo = ApiBuilder :: from_hf ( "FL33TW00D-HF/ratchet-util" , RepoType :: Dataset ) . build ( ) ;
52
56
let hs_data = ground_repo. get ( "jfk_tiny_encoder_hs.npy" ) . await ?;
53
57
54
58
let mut reader = std:: io:: BufReader :: new ( std:: io:: Cursor :: new ( model_data. to_vec ( ) ) ) ;
55
- let gg_disk = Whisper :: load_ggml ( & mut reader) . unwrap ( ) ;
56
- assert_eq ! ( gg_disk . tensors . len ( ) , 167 ) ;
59
+ let header = gguf :: Header :: read ( & mut reader) . unwrap ( ) ;
60
+ let config : Config = serde_json :: from_slice ( & config_data . to_vec ( ) ) . unwrap ( ) ;
57
61
58
62
let device = Device :: request_device ( DeviceRequest :: GPU ) . await . unwrap ( ) ;
59
63
let audio_ctx = Tensor :: from_npy_bytes :: < f32 > ( & hs_data. to_vec ( ) , & device) . unwrap ( ) ;
60
- let mut decoder = WhisperDecoder :: load ( & gg_disk , & mut reader, & device) . unwrap ( ) ;
64
+ let mut decoder = WhisperDecoder :: load ( & header , & config , & mut reader, & device) . unwrap ( ) ;
61
65
62
66
let mut tokens = vec ! [ 50258 , 50259 , 50359 ] ;
63
67
let mut all_tokens = tokens. clone ( ) ;
0 commit comments