Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build-scripts/runtime_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ endif ()

if (WAMR_BUILD_WASI_NN EQUAL 1)
include (${IWASM_DIR}/libraries/wasi-nn/cmake/wasi_nn.cmake)
set (WAMR_BUILD_MODULE_INST_CONTEXT 1)
endif ()

if (WAMR_BUILD_LIB_PTHREAD EQUAL 1)
Expand Down
78 changes: 21 additions & 57 deletions core/iwasm/libraries/wasi-nn/src/wasi_nn.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,49 +55,15 @@ struct backends_api_functions {
NN_ERR_PRINTF("Error %s() -> %d", #func, wasi_error); \
} while (0)

/* HashMap utils */
static HashMap *hashmap;

static uint32
hash_func(const void *key)
{
// fnv1a_hash
const uint32 FNV_PRIME = 16777619;
const uint32 FNV_OFFSET_BASIS = 2166136261U;

uint32 hash = FNV_OFFSET_BASIS;
const unsigned char *bytes = (const unsigned char *)key;

for (size_t i = 0; i < sizeof(uintptr_t); ++i) {
hash ^= bytes[i];
hash *= FNV_PRIME;
}

return hash;
}

static bool
key_equal_func(void *key1, void *key2)
{
return key1 == key2;
}

static void
key_destroy_func(void *key1)
{
/* key type is wasm_module_inst_t*. do nothing */
}
static void *wasi_nn_key;

static void
wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
{
NN_DBG_PRINTF("[WASI NN] DEINIT...");

if (wasi_nn_ctx == NULL) {
NN_ERR_PRINTF(
"Error when deallocating memory. WASI-NN context is NULL");
return;
}
NN_DBG_PRINTF("[WASI NN] DEINIT...");
NN_DBG_PRINTF("Freeing wasi-nn");
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
Expand All @@ -116,9 +82,9 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
}

static void
value_destroy_func(void *value)
dtor(wasm_module_inst_t inst, void *ctx)
{
wasi_nn_ctx_destroy((WASINNContext *)value);
wasi_nn_ctx_destroy(ctx);
}

bool
Expand All @@ -131,12 +97,9 @@ wasi_nn_initialize()
return false;
}

// hashmap { instance: wasi_nn_ctx }
hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func,
key_equal_func, key_destroy_func,
value_destroy_func);
if (hashmap == NULL) {
NN_ERR_PRINTF("Error while initializing hashmap");
wasi_nn_key = wasm_runtime_create_context_key(dtor);
if (wasi_nn_key == NULL) {
NN_ERR_PRINTF("Failed to create context key");
os_mutex_destroy(&wasi_nn_lock);
return false;
}
Expand Down Expand Up @@ -170,21 +133,23 @@ static WASINNContext *
wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
{
WASINNContext *wasi_nn_ctx =
(WASINNContext *)bh_hash_map_find(hashmap, (void *)instance);
wasm_runtime_get_context(instance, wasi_nn_key);
if (wasi_nn_ctx == NULL) {
wasi_nn_ctx = wasi_nn_initialize_context();
if (wasi_nn_ctx == NULL)
return NULL;

bool ok =
bh_hash_map_insert(hashmap, (void *)instance, (void *)wasi_nn_ctx);
if (!ok) {
NN_ERR_PRINTF("Error while storing context");
wasi_nn_ctx_destroy(wasi_nn_ctx);
WASINNContext *newctx = wasi_nn_initialize_context();
if (newctx == NULL)
return NULL;
os_mutex_lock(&wasi_nn_lock);
wasi_nn_ctx = wasm_runtime_get_context(instance, wasi_nn_key);
if (wasi_nn_ctx == NULL) {
wasm_runtime_set_context_spread(instance, wasi_nn_key, newctx);
wasi_nn_ctx = newctx;
newctx = NULL;
}
os_mutex_unlock(&wasi_nn_lock);
if (newctx != NULL) {
wasi_nn_ctx_destroy(newctx);
}
}

return wasi_nn_ctx;
}

Expand Down Expand Up @@ -220,8 +185,7 @@ unlock_ctx(WASINNContext *wasi_nn_ctx)
void
wasi_nn_destroy()
{
// destroy hashmap will destroy keys and values
bh_hash_map_destroy(hashmap);
wasm_runtime_destroy_context_key(wasi_nn_key);

// close backends' libraries and registered functions
for (unsigned i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) {
Expand Down
Loading