Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 49 additions & 40 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
return ret;
}

static wasi_nn_error
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
WASINNContext **wasi_nn_ctx_ptr)
{
wasi_nn_error res;

graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(encoding, &loaded_backend)) {
res = invalid_encoding;
NN_ERR_PRINTF("load backend failed");
goto fail;
}

WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
if (wasi_nn_ctx->is_backend_ctx_initialized) {
if (wasi_nn_ctx->backend != loaded_backend) {
res = unsupported_operation;
goto fail;
}
}
else {
wasi_nn_ctx->backend = loaded_backend;

/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
if (res != success)
goto fail;

wasi_nn_ctx->is_backend_ctx_initialized = true;
}
*wasi_nn_ctx_ptr = wasi_nn_ctx;
return success;
fail:
return res;
}

/* WASI-NN implementation */

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
Expand All @@ -410,14 +447,15 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{
wasi_nn_error res;

NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
target);

wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
if (!instance)
return runtime_error;

wasi_nn_error res;
graph_builder_array builder_native = { 0 };
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (success
Expand All @@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
goto fail;
}

graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(encoding, &loaded_backend)) {
res = invalid_encoding;
NN_ERR_PRINTF("load backend failed");
goto fail;
}

WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;

/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
if (res != success)
goto fail;

Expand All @@ -473,6 +500,8 @@ wasi_nn_error
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
graph *g)
{
wasi_nn_error res;

wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
if (!instance) {
return runtime_error;
Expand All @@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,

NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);

graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
NN_ERR_PRINTF("load backend failed");
return invalid_encoding;
}

WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;

wasi_nn_error res;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
if (res != success)
return res;

Expand All @@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
int32_t name_len, char *config,
int32_t config_len, graph *g)
{
wasi_nn_error res;

wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
if (!instance) {
return runtime_error;
Expand Down Expand Up @@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,

NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);

graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
NN_ERR_PRINTF("load backend failed");
return invalid_encoding;
}

WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;

wasi_nn_error res;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
if (res != success)
return res;

Expand Down
1 change: 1 addition & 0 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "wasm_export.h"

typedef struct {
bool is_backend_ctx_initialized;
bool is_model_loaded;
graph_encoding backend;
void *backend_ctx;
Expand Down
Loading