@@ -2917,59 +2917,63 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
29172917 }
29182918
29192919#ifdef GGML_USE_METAL
2920+ // TODO: add a parameter to enable/disable GPU (Metal) usage
29202921 state->ctx_metal = ggml_metal_init (1 );
29212922 if (!state->ctx_metal ) {
29222923 log (" %s: ggml_metal_init() failed\n " , __func__);
29232924 delete state;
29242925 return nullptr ;
29252926 }
29262927
2927- log (" %s: Metal context initialized\n " , __func__);
2928+ if (state->ctx_metal ) {
2929+ log (" %s: Metal context initialized\n " , __func__);
29282930
2929- // this allocates all Metal resources and memory buffers
2931+ // this allocates all Metal resources and memory buffers
29302932
2931- void * data_ptr = NULL ;
2932- size_t data_size = 0 ;
2933+ void * data_ptr = NULL ;
2934+ size_t data_size = 0 ;
29332935
2934- // TODO: add mmap support
2935- // if (params.use_mmap) {
2936- // data_ptr = ctx->model.mapping->addr;
2937- // data_size = ctx->model.mapping->size;
2938- // } else {
2939- // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2940- // data_size = ggml_get_mem_size (ctx->model.ctx);
2941- // }
2936+ // TODO: add mmap support
2937+ // if (params.use_mmap) {
2938+ // data_ptr = ctx->model.mapping->addr;
2939+ // data_size = ctx->model.mapping->size;
2940+ // } else {
2941+ // data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
2942+ // data_size = ggml_get_mem_size (ctx->model.ctx);
2943+ // }
29422944
2943- data_ptr = ggml_get_mem_buffer (ctx->model .ctx );
2944- data_size = ggml_get_mem_size (ctx->model .ctx );
2945+ data_ptr = ggml_get_mem_buffer (ctx->model .ctx );
2946+ data_size = ggml_get_mem_size (ctx->model .ctx );
29452947
2946- const size_t max_size = ggml_get_max_tensor_size (ctx->model .ctx );
2948+ const size_t max_size = ggml_get_max_tensor_size (ctx->model .ctx );
29472949
2948- log (" %s: max tensor size = %8.2f MB\n " , __func__, max_size/1024.0 /1024.0 );
2950+ log (" %s: max tensor size = %8.2f MB\n " , __func__, max_size/1024.0 /1024.0 );
29492951
29502952#define WHISPER_METAL_CHECK_BUF (result ) \
2951- if (!(result)) { \
2952- log (" %s: failed to add metal buffer\n " , __func__); \
2953- delete state; \
2954- return nullptr ; \
2955- }
2953+ if (!(result)) { \
2954+ log (" %s: failed to add metal buffer\n " , __func__); \
2955+ delete state; \
2956+ return nullptr ; \
2957+ }
29562958
2957- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data" , data_ptr, data_size, max_size));
2959+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data" , data_ptr, data_size, max_size));
29582960
2959- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_conv" , state->alloc_conv .meta .data (), state->alloc_conv .meta .size (), 0 ));
2960- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_encode" , state->alloc_encode .meta .data (), state->alloc_encode .meta .size (), 0 ));
2961- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_cross" , state->alloc_cross .meta .data (), state->alloc_cross .meta .size (), 0 ));
2962- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_decode" , state->alloc_decode .meta .data (), state->alloc_decode .meta .size (), 0 ));
2961+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_conv" , state->alloc_conv .meta .data (), state->alloc_conv .meta .size (), 0 ));
2962+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_encode" , state->alloc_encode .meta .data (), state->alloc_encode .meta .size (), 0 ));
2963+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_cross" , state->alloc_cross .meta .data (), state->alloc_cross .meta .size (), 0 ));
2964+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " meta_decode" , state->alloc_decode .meta .data (), state->alloc_decode .meta .size (), 0 ));
29632965
2964- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_conv" , state->alloc_conv .data .data (), state->alloc_conv .data .size (), 0 ));
2965- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_encode" , state->alloc_encode .data .data (), state->alloc_encode .data .size (), 0 ));
2966- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_cross" , state->alloc_cross .data .data (), state->alloc_cross .data .size (), 0 ));
2967- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_decode" , state->alloc_decode .data .data (), state->alloc_decode .data .size (), 0 ));
2966+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_conv" , state->alloc_conv .data .data (), state->alloc_conv .data .size (), 0 ));
2967+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_encode" , state->alloc_encode .data .data (), state->alloc_encode .data .size (), 0 ));
2968+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_cross" , state->alloc_cross .data .data (), state->alloc_cross .data .size (), 0 ));
2969+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " data_decode" , state->alloc_decode .data .data (), state->alloc_decode .data .size (), 0 ));
29682970
2969- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_cross" , state->kv_cross .buf .data (), state->kv_cross .buf .size (), 0 ));
2971+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_cross" , state->kv_cross .buf .data (), state->kv_cross .buf .size (), 0 ));
29702972
2971- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_self_0" , state->decoders [0 ].kv_self .buf .data (), state->decoders [0 ].kv_self .buf .size (), 0 ));
2973+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , " kv_self_0" , state->decoders [0 ].kv_self .buf .data (), state->decoders [0 ].kv_self .buf .size (), 0 ));
29722974#undef WHISPER_METAL_CHECK_BUF
2975+
2976+ }
29732977#endif
29742978
29752979 state->rng = std::mt19937 (0 );
@@ -4492,17 +4496,19 @@ int whisper_full_with_state(
44924496
44934497 // TODO: not very clean - look for a better way, potentially merging this with the init of decoder 0
44944498#ifdef GGML_USE_METAL
4499+ if (state->ctx_metal ) {
44954500#define WHISPER_METAL_CHECK_BUF (result ) \
4496- if (!(result)) { \
4497- log (" %s: failed to add metal buffer\n " , __func__); \
4498- return 0 ; \
4499- }
4501+ if (!(result)) { \
4502+ log (" %s: failed to add metal buffer\n " , __func__); \
4503+ return 0 ; \
4504+ }
45004505
4501- const std::string kv_name = " kv_self_" + std::to_string (j);
4502- auto & kv_self = decoder.kv_self ;
4506+ const std::string kv_name = " kv_self_" + std::to_string (j);
4507+ auto & kv_self = decoder.kv_self ;
45034508
4504- WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , kv_name.c_str (), kv_self.buf .data (), kv_self.buf .size (), 0 ));
4509+ WHISPER_METAL_CHECK_BUF (ggml_metal_add_buffer (state->ctx_metal , kv_name.c_str (), kv_self.buf .data (), kv_self.buf .size (), 0 ));
45054510#undef WHISPER_METAL_CHECK_BUF
4511+ }
45064512#endif
45074513 }
45084514 }
0 commit comments