diff --git a/Makefile b/Makefile index 01fb5f46f..fd9d66437 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ compile: WAMR_VERSION = 2.2.0 WAMR_DIR = _build/wamr +WASI_NN_DIR = _build/wasi_nn/wasi_nn_backend GENESIS_WASM_BRANCH = tillathehun0/cu-experimental GENESIS_WASM_REPO = https://github.com/permaweb/ao.git @@ -104,3 +105,19 @@ setup-genesis-wasm: $(GENESIS_WASM_SERVER_DIR) fi @cd $(GENESIS_WASM_SERVER_DIR) && npm install > /dev/null 2>&1 && \ echo "Installed genesis-wasm@1.0 server." +# Set up wasi-nn environment +$(WASI_NN_DIR): + @echo "Cloning wasi-nn backend repository..." && \ + git clone --depth=1 -b stage/wasi-nn https://github.com/apuslabs/wasi_nn_backend.git $(WASI_NN_DIR) && \ + echo "Cloned wasi-nn backend to $(WASI_NN_DIR)" + +setup-wasi-nn: $(WASI_NN_DIR) + @mkdir -p $(WASI_NN_DIR)/lib + @echo "Building wasi-nn backend..." && \ + cmake \ + $(WAMR_FLAGS) \ + -S $(WASI_NN_DIR) \ + -B $(WASI_NN_DIR)/build && \ + make -C $(WASI_NN_DIR)/build -j8 && \ + cp $(WASI_NN_DIR)/build/libwasi_nn_backend.so ./native/wasi_nn_llama && \ + echo "Successfully built wasi-nn backend" diff --git a/native/wasi_nn_llama/include/wasi_nn_llama.h b/native/wasi_nn_llama/include/wasi_nn_llama.h new file mode 100644 index 000000000..5e81eca69 --- /dev/null +++ b/native/wasi_nn_llama/include/wasi_nn_llama.h @@ -0,0 +1,183 @@ +/* + * Copyright (C) 2019 Intel Corporation. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + */ + + #ifndef WASI_NN_TYPES_H + #define WASI_NN_TYPES_H + +#include +#include + + #ifdef __cplusplus + extern "C" { + #endif + + /** + * ERRORS + * + */ + + // sync up with + // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L136 Error + // codes returned by functions in this API. + typedef enum { + // No error occurred. + success = 0, + // Caller module passed an invalid argument. + invalid_argument, + // Invalid encoding. + invalid_encoding, + // The operation timed out. + timeout, + // Runtime Error. + runtime_error, + // Unsupported operation. + unsupported_operation, + // Graph is too large. + too_large, + // Graph not found. + not_found, + // The operation is insecure or has insufficient privilege to be performed. + // e.g., cannot access a hardware feature requested + security, + // The operation failed for an unspecified reason. + unknown, + // for WasmEdge-wasi-nn + end_of_sequence = 100, // End of Sequence Found. + context_full = 101, // Context Full. + prompt_tool_long = 102, // Prompt Too Long. + model_not_found = 103, // Model Not Found. + } wasi_nn_error; + + /** + * TENSOR + * + */ + + // The dimensions of a tensor. + // + // The array length matches the tensor rank and each element in the array + // describes the size of each dimension. + typedef struct { + uint32_t *buf; + uint32_t size; + } tensor_dimensions; + + #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + // sync up with + // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L27 + // The type of the elements in a tensor. + typedef enum { fp16 = 0, fp32, fp64, bf16, u8, i32, i64 } tensor_type; + #else + typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type; + #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ + + // The tensor data. + // + // Initially conceived as a sparse representation, each empty cell would be + // filled with zeros and the array length must match the product of all of the + // dimensions and the number of bytes in the type (e.g., a 2x2 tensor with + // 4-byte f32 elements would have a data array of length 16). 
Naturally, this + // representation requires some knowledge of how to lay out data in + // memory--e.g., using row-major ordering--and could perhaps be improved. + typedef uint8_t *tensor_data; + + // A tensor. + typedef struct { + // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To + // represent a tensor containing a single value, use `[1]` for the tensor + // dimensions. + tensor_dimensions *dimensions; + // Describe the type of element in the tensor (e.g., f32). + tensor_type type; + // Contains the tensor data. + tensor_data data; + } tensor; + + /** + * GRAPH + * + */ + + // The graph initialization data. + // + // This consists of an array of buffers because implementing backends may encode + // their graph IR in parts (e.g., OpenVINO stores its IR and weights + // separately). + typedef struct { + uint8_t *buf; + uint32_t size; + } graph_builder; + + typedef struct { + graph_builder *buf; + uint32_t size; + } graph_builder_array; + + // An execution graph for performing inference (i.e., a model). + typedef uint32_t graph; + + // sync up with + // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75 + // Describes the encoding of the graph. This allows the API to be implemented by + // various backends that encode (i.e., serialize) their graph IR with different + // formats. + typedef enum { + openvino = 0, + onnx, + tensorflow, + pytorch, + tensorflowlite, + ggml, + autodetect, + unknown_backend, + } graph_encoding; + + // Define where the graph should be executed. + typedef enum execution_target { cpu = 0, gpu, tpu } execution_target; + + // Bind a `graph` to the input and output tensors for an inference. + typedef uint32_t graph_execution_context; + + + __attribute__((visibility("default"))) wasi_nn_error + init_backend(void **ctx) ; + + __attribute__((visibility("default"))) wasi_nn_error + init_backend_with_config(void **ctx, const char *config, uint32_t config_len); + + __attribute__((visibility("default"))) wasi_nn_error + deinit_backend(void *ctx); + + __attribute__((visibility("default"))) wasi_nn_error + load_by_name_with_config(void *ctx, const char *filename, uint32_t filename_len, + const char *config, uint32_t config_len, graph *g); + + __attribute__((visibility("default"))) wasi_nn_error + init_execution_context(void *ctx, const char *session_id, graph_execution_context *exec_ctx); + + __attribute__((visibility("default"))) wasi_nn_error + close_execution_context(void *ctx, graph_execution_context exec_ctx); + + __attribute__((visibility("default"))) wasi_nn_error + run_inference(void *ctx, graph_execution_context exec_ctx, uint32_t index, + tensor *input_tensor,tensor_data output_tensor, uint32_t *output_tensor_size, const char *options); + + __attribute__((visibility("default"))) wasi_nn_error + set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index, + tensor *wasi_nn_tensor); + + __attribute__((visibility("default"))) wasi_nn_error + compute(void *ctx, graph_execution_context exec_ctx); + + __attribute__((visibility("default"))) wasi_nn_error + get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index, + tensor_data output_tensor, uint32_t *output_tensor_size); + + + #ifdef __cplusplus + } + #endif + #endif + \ No newline at end of file diff --git a/native/wasi_nn_llama/include/wasi_nn_logging.h b/native/wasi_nn_llama/include/wasi_nn_logging.h new file mode 100644 index 000000000..5fdb5feee --- /dev/null +++ b/native/wasi_nn_llama/include/wasi_nn_logging.h @@ -0,0 +1,34 @@ +#include +#include +#include 
+#include <pthread.h>
+#ifndef HB_LOGGING_H
+#define HB_LOGGING_H
+
+
+// HB_DEBUG gates DRV_DEBUG output; default to disabled unless overridden at
+// build time (e.g. -DHB_DEBUG=1).
+#ifndef HB_DEBUG
+#define HB_DEBUG 0
+#endif
+
+
+#define DRV_DEBUG(format, ...) beamr_print(HB_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
+#define DRV_PRINT(format, ...) beamr_print(1, __FILE__, __LINE__, format, ##__VA_ARGS__)
+
+/*
+ * Function: beamr_print
+ * --------------------
+ * Prints a formatted message to standard output, prefixed with the thread
+ * ID, file name, and line number where the log was generated.
+ *
+ * print:  A flag that controls whether the message is printed (1 to print, 0 to skip).
+ * file:   The source file name where the log was generated.
+ * line:   The line number where the log was generated.
+ * format: The format string for the message.
+ * ...:    The variables to be printed in the format.
+ */
+void beamr_print(int print, const char* file, int line, const char* format, ...);
+
+
+
+#endif // HB_LOGGING_H
\ No newline at end of file
diff --git a/native/wasi_nn_llama/include/wasi_nn_nif.h b/native/wasi_nn_llama/include/wasi_nn_nif.h
new file mode 100644
index 000000000..c99ad1bee
--- /dev/null
+++ b/native/wasi_nn_llama/include/wasi_nn_nif.h
@@ -0,0 +1,40 @@
+#ifndef WASI_NN_NIF_H
+#define WASI_NN_NIF_H
+
+#include "wasi_nn_llama.h"
+#include <erl_nif.h>
+#include <dlfcn.h>
+#include <stdlib.h>
+#include <string.h>
+
+
+// Function pointer types
+typedef wasi_nn_error (*init_backend_fn)(void **ctx);
+typedef wasi_nn_error (*init_backend_with_config_fn)(void **ctx, const char *config, uint32_t config_len);
+typedef wasi_nn_error (*deinit_backend_fn)(void *ctx);
+typedef wasi_nn_error (*init_execution_context_fn)(void *ctx, const char *session_id, graph_execution_context *exec_ctx);
+typedef wasi_nn_error (*close_execution_context_fn)(void *ctx, graph_execution_context exec_ctx);
+typedef wasi_nn_error (*load_by_name_with_config_fn)(void *ctx, const char *filename, uint32_t filename_len,
+                                                     const char *config, uint32_t config_len, graph *g);
+typedef wasi_nn_error (*run_inference_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index,
+                                          tensor *input_tensor, tensor_data output_tensor,
+                                          uint32_t *output_tensor_size, const char *options);
+typedef wasi_nn_error (*set_input_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor *wasi_nn_tensor);
+typedef wasi_nn_error (*compute_fn)(void *ctx, graph_execution_context exec_ctx);
+typedef wasi_nn_error (*get_output_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor_data output, uint32_t *output_size);
+
+// Structure to hold all function pointers resolved from the backend shared library
+typedef struct {
+    void* handle;
+    init_backend_fn init_backend;
+    init_backend_with_config_fn init_backend_with_config;
+    deinit_backend_fn deinit_backend;
+    load_by_name_with_config_fn load_by_name_with_config;
+    init_execution_context_fn init_execution_context;
+    close_execution_context_fn close_execution_context;
+    run_inference_fn run_inference;
+    set_input_fn set_input;
+    compute_fn compute;
+    get_output_fn get_output;
+} wasi_nn_backend_api;
+#endif // WASI_NN_NIF_H
\ No newline at end of file
diff --git a/native/wasi_nn_llama/src/wasi_nn_logging.c b/native/wasi_nn_llama/src/wasi_nn_logging.c
new file mode 100644
index 000000000..b2cf33c78
--- /dev/null
+++ b/native/wasi_nn_llama/src/wasi_nn_logging.c
@@ -0,0 +1,36 @@
+#include "../include/wasi_nn_logging.h"
+
+
+
+void beamr_print(int print, const char* file, int line, const char* format, ...)
{ + va_list args; + va_start(args, format); + if(print) { + pthread_t thread_id = pthread_self(); + printf("[DBG#%p @ %s:%d] ", thread_id, file, line); + vprintf(format, args); + printf("\r\n"); + } + va_end(args); +} + +// void send_error(Proc* proc, const char* message_fmt, ...) { +// va_list args; +// va_start(args, message_fmt); +// char* message = driver_alloc(256); +// vsnprintf(message, 256, message_fmt, args); +// DRV_DEBUG("Sending error message: %s", message); +// ErlDrvTermData* msg = driver_alloc(sizeof(ErlDrvTermData) * 7); +// int msg_index = 0; +// msg[msg_index++] = ERL_DRV_ATOM; +// msg[msg_index++] = atom_error; +// msg[msg_index++] = ERL_DRV_STRING; +// msg[msg_index++] = (ErlDrvTermData)message; +// msg[msg_index++] = strlen(message); +// msg[msg_index++] = ERL_DRV_TUPLE; +// msg[msg_index++] = 2; + +// int msg_res = erl_drv_output_term(proc->port_term, msg, msg_index); +// DRV_DEBUG("Sent error message. Res: %d", msg_res); +// va_end(args); +// } \ No newline at end of file diff --git a/native/wasi_nn_llama/src/wasi_nn_nif.c b/native/wasi_nn_llama/src/wasi_nn_nif.c new file mode 100644 index 000000000..235fd2477 --- /dev/null +++ b/native/wasi_nn_llama/src/wasi_nn_nif.c @@ -0,0 +1,292 @@ +#include "../include/wasi_nn_nif.h" +#include "../include/wasi_nn_logging.h" +#define LIB_PATH "./native/wasi_nn_llama/libwasi_nn_backend.so" +#define MAX_MODEL_PATH 256 +#define MAX_INPUT_SIZE 4096 +#define MAX_CONFIG_SIZE 1024 +#define MAX_OUTPUT_SIZE 8192 + +typedef struct { + void* ctx; + graph g; + graph_execution_context exec_ctx; +} LlamaContext; + +static wasi_nn_backend_api g_wasi_nn_functions = {0}; +static ErlNifResourceType* llama_context_resource; + +static void llama_context_destructor(ErlNifEnv* env, void* obj) +{ + + LlamaContext* ctx = (LlamaContext*)obj; + if (ctx) { + // Cleanup backend context + if (ctx->ctx && g_wasi_nn_functions.deinit_backend) { + g_wasi_nn_functions.deinit_backend(ctx->ctx); + ctx->ctx = NULL; + } + // No need to cleanup shared library here since it's managed globally + // Clear the context structure + memset(ctx, 0, sizeof(LlamaContext)); + } + +} + +static int load(ErlNifEnv* env, void** priv_data, ERL_NIF_TERM load_info) +{ + DRV_DEBUG("Load nif start\n"); + g_wasi_nn_functions.handle = dlopen(LIB_PATH, RTLD_LAZY); + if (!g_wasi_nn_functions.handle) { + DRV_DEBUG("Failed to load wasi library: %s\n", dlerror()); + return 1; + } + // Load all required functions once + g_wasi_nn_functions.init_backend = (init_backend_fn)dlsym(g_wasi_nn_functions.handle, "init_backend"); + g_wasi_nn_functions.init_backend_with_config = (init_backend_with_config_fn)dlsym(g_wasi_nn_functions.handle, "init_backend_with_config"); + g_wasi_nn_functions.deinit_backend = (deinit_backend_fn)dlsym(g_wasi_nn_functions.handle, "deinit_backend"); + g_wasi_nn_functions.init_execution_context = (init_execution_context_fn)dlsym(g_wasi_nn_functions.handle, "init_execution_context"); + g_wasi_nn_functions.close_execution_context = (close_execution_context_fn)dlsym(g_wasi_nn_functions.handle, "close_execution_context"); + g_wasi_nn_functions.set_input = (set_input_fn)dlsym(g_wasi_nn_functions.handle, "set_input"); + g_wasi_nn_functions.compute = (compute_fn)dlsym(g_wasi_nn_functions.handle, "compute"); + g_wasi_nn_functions.get_output = (get_output_fn)dlsym(g_wasi_nn_functions.handle, "get_output"); + g_wasi_nn_functions.load_by_name_with_config = (load_by_name_with_config_fn)dlsym(g_wasi_nn_functions.handle, "load_by_name_with_config"); + g_wasi_nn_functions.run_inference = 
(run_inference_fn)dlsym(g_wasi_nn_functions.handle, "run_inference"); + if (!g_wasi_nn_functions.init_backend ||!g_wasi_nn_functions.deinit_backend || + !g_wasi_nn_functions.init_execution_context ||!g_wasi_nn_functions.close_execution_context || + !g_wasi_nn_functions.set_input ||!g_wasi_nn_functions.compute || + !g_wasi_nn_functions.get_output ||!g_wasi_nn_functions.load_by_name_with_config ||!g_wasi_nn_functions.run_inference) { + dlclose(g_wasi_nn_functions.handle); + return 1; + } + DRV_DEBUG("Load nif Finished\n"); + llama_context_resource = enif_open_resource_type(env, NULL, "llama_context", + llama_context_destructor, ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER, NULL); + return llama_context_resource ? 0 : 1; +} + + + + +static ERL_NIF_TERM nif_init_backend(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + LlamaContext* ctx = enif_alloc_resource(llama_context_resource, sizeof(LlamaContext)); + if (!ctx) { + DRV_DEBUG("Failed to allocate LlamaContext resource\n"); + return enif_make_tuple2(env, enif_make_atom(env, "error"), + enif_make_atom(env, "allocation_failed")); + } + DRV_DEBUG("Initializing backend...\n"); + wasi_nn_error err = g_wasi_nn_functions.init_backend(&ctx->ctx); + if (err != success) { + DRV_DEBUG("Backend initialization failed with error: %d\n", err); + enif_release_resource(ctx); + return enif_make_tuple2(env, enif_make_atom(env, "error"), + enif_make_atom(env, "init_failed")); + } + DRV_DEBUG("nif_init_backend finished \n"); + ERL_NIF_TERM ctx_term = enif_make_resource(env, ctx); + return enif_make_tuple2(env, enif_make_atom(env, "ok"), ctx_term); +} +static ERL_NIF_TERM nif_load_by_name_with_config(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + LlamaContext* ctx; + char *model_path = (char *)malloc(MAX_MODEL_PATH * sizeof(char)); + char *config = (char *)malloc(MAX_CONFIG_SIZE * sizeof(char)); + ERL_NIF_TERM ret_term; // Variable to hold the return term + // if allocate failed + if (!model_path || !config) { + DRV_DEBUG("Memory allocation failed for model_path or config\n"); + free(model_path); // free(NULL) is safe + free(config); + return enif_make_tuple2(env, enif_make_atom(env, "error"), + enif_make_atom(env, "allocation_failed")); + } + // Get the context from the first argument + if(!enif_get_resource(env, argv[0], llama_context_resource, (void**)&ctx)) + { + DRV_DEBUG("Invalid context\n"); + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_context")); + goto cleanup; // Use goto for centralized cleanup + } + // Get the model path from the second argument + if (!enif_get_string(env, argv[1], model_path, MAX_MODEL_PATH, ERL_NIF_LATIN1)) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"),enif_make_atom(env, "invalid_model_path")); + goto cleanup; + } + // Get the config from the third argument + if (!enif_get_string(env, argv[2], config, MAX_CONFIG_SIZE, ERL_NIF_LATIN1)) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), + enif_make_atom(env, "invalid_config")); + goto cleanup; + } + DRV_DEBUG("Loading model: %s config : %s\n", model_path, config); + + if (g_wasi_nn_functions.load_by_name_with_config(ctx->ctx, model_path, strlen(model_path), + config, strlen(config), &ctx->g) != success) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "load_failed")); + goto cleanup; + } + + ret_term = enif_make_atom(env, "ok"); + + +cleanup: + free(model_path); + free(config); + return ret_term; +} + +static ERL_NIF_TERM 
nif_init_execution_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + DRV_DEBUG("Init context Start \n" ); + LlamaContext* ctx; + char session_id[256]; + + if (!enif_get_resource(env, argv[0], llama_context_resource, (void**)&ctx)) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_args_init_execution")); + } + + if (!enif_get_string(env, argv[1], session_id, sizeof(session_id), ERL_NIF_LATIN1)) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_session_id")); + } + + if (g_wasi_nn_functions.init_execution_context(ctx->ctx, session_id, &ctx->exec_ctx)!= success) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "init_execution_failed")); + } + DRV_DEBUG("Init context finished for session: %s\n", session_id); + return enif_make_tuple2(env, enif_make_atom(env, "ok"), enif_make_ulong(env, ctx->exec_ctx)); +} + +static ERL_NIF_TERM nif_close_execution_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + LlamaContext* ctx; + unsigned long exec_ctx_id; + + if (!enif_get_resource(env, argv[0], llama_context_resource, (void**)&ctx)) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_args")); + } + + if (!enif_get_ulong(env, argv[1], &exec_ctx_id)) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_exec_ctx")); + } + + if (g_wasi_nn_functions.close_execution_context(ctx->ctx, (graph_execution_context)exec_ctx_id) != success) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "close_execution_failed")); + } + + return enif_make_atom(env, "ok"); +} +static ERL_NIF_TERM nif_run_inference(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + DRV_DEBUG("Start to run_inference \n" ); + LlamaContext* ctx; + unsigned long exec_ctx_id; + char *input = NULL; + tensor_data output; + ERL_NIF_TERM ret_term; // Variable for the return term + ERL_NIF_TERM result_bin; + uint32_t output_size = 0; // Initialize output_size + + // Allocate memory + input = (char *)malloc(MAX_INPUT_SIZE * sizeof(char)); + output = (uint8_t *)malloc(MAX_OUTPUT_SIZE * sizeof(uint8_t)); + // Check allocations + if (!input || !output ) { + fprintf(stderr, "Initial memory allocation failed\n"); + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "allocation_failed")); + goto cleanup; // Jump to cleanup section + } + // Get the context from the first argument + if (!enif_get_resource(env, argv[0], llama_context_resource, (void**)&ctx)) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_args")); + goto cleanup; + } + + // Get the execution context ID from the second argument + if (!enif_get_ulong(env, argv[1], &exec_ctx_id)) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_exec_ctx")); + goto cleanup; + } + + //Get input from the third argument + if (!enif_get_string(env, argv[2], input, MAX_INPUT_SIZE, ERL_NIF_LATIN1)) { + DRV_DEBUG("Invalid input\n"); + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_input")); + goto cleanup; + } + + tensor input_tensor = { + .dimensions = NULL, + .type = fp32, + .data = (tensor_data)input, + }; + // Run inference with session-specific execution context + if (g_wasi_nn_functions.run_inference(ctx->ctx, (graph_execution_context)exec_ctx_id, 0, &input_tensor, output, &output_size, NULL) 
!= success) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "run_inference_failed")); + goto cleanup; + } + DRV_DEBUG("Output size: %d\n", output_size); + DRV_DEBUG("Output %s\n", output); + // TODO limit output size + unsigned char* bin_data = enif_make_new_binary(env, output_size, &result_bin); + if (!bin_data) { + ret_term = enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "binary_creation_failed")); + // Output buffer still needs freeing in cleanup + goto cleanup; + } + + // Copy the output_buffer into the Erlang binary + memcpy(bin_data, output, output_size); + ret_term = enif_make_tuple2(env, enif_make_atom(env, "ok"), result_bin); +cleanup: + // Free all allocated memory. free(NULL) is safe. + free(input); + free(output); + DRV_DEBUG("Clean all"); + return ret_term; +} +static ERL_NIF_TERM nif_set_input(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + // TBD + return enif_make_atom(env, "ok"); +} +static ERL_NIF_TERM nif_compute(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + // TBD + return enif_make_atom(env, "ok"); +} +static ERL_NIF_TERM nif_get_output(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + // TBD + return enif_make_atom(env, "ok"); +} +static ERL_NIF_TERM nif_deinit_backend(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) +{ + LlamaContext* ctx; + if (!enif_get_resource(env, argv[0], llama_context_resource, (void**)&ctx)) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "invalid_args")); + } + if (g_wasi_nn_functions.deinit_backend(ctx->ctx)!= success) { + return enif_make_tuple2(env, enif_make_atom(env, "error"), enif_make_atom(env, "deinit_failed")); + } + return enif_make_atom(env, "ok"); +} +static ErlNifFunc nif_funcs[] = { + {"init_backend", 0, nif_init_backend}, + {"load_by_name_with_config", 3, nif_load_by_name_with_config}, + {"init_execution_context", 2, nif_init_execution_context}, + {"close_execution_context", 2, nif_close_execution_context}, + {"deinit_backend", 1, nif_deinit_backend}, + {"run_inference", 3, nif_run_inference}, +}; + + +static void unload(ErlNifEnv* env, void* priv_data) +{ + // The resource destructor will be called automatically for any remaining resources + if (g_wasi_nn_functions.handle) { + dlclose(g_wasi_nn_functions.handle); + g_wasi_nn_functions.handle = NULL; + } +} +ERL_NIF_INIT(dev_wasi_nn_nif, nif_funcs, load, NULL, NULL, unload) \ No newline at end of file diff --git a/rebar.config b/rebar.config index 8bfb84323..6750586b7 100644 --- a/rebar.config +++ b/rebar.config @@ -52,6 +52,21 @@ {add, cowboy, [{erl_opts, [{d, 'COWBOY_QUICER', 1}]}]}, {add, gun, [{erl_opts, [{d, 'GUN_QUICER', 1}]}]} ]} + ]}, + {wasi_nn, [ + {erl_opts, [{d, 'ENABLE_WASI_NN', true}]}, + {pre_hooks, [ + {compile, "make -C \"${REBAR_ROOT_DIR}\" setup-wasi-nn"} + ]}, + {port_specs, [ + {"./priv/wasi_nn.so", [ + "./native/wasi_nn_llama/src/wasi_nn_nif.c", + "./native/wasi_nn_llama/src/wasi_nn_logging.c" + ]} + ]}, + {post_hooks, [ + { compile, "rm -f native/wasi_nn_llama/src/*.o native/wasi_nn_llama/src/*.d"} + ]} ]} ]}. diff --git a/src/dev_wasi_nn.erl b/src/dev_wasi_nn.erl new file mode 100644 index 000000000..a63007e42 --- /dev/null +++ b/src/dev_wasi_nn.erl @@ -0,0 +1,296 @@ +%%% @doc A WASI-NN device implementation for HyperBEAM that provides AI inference +%%% capabilities. This device supports loading models from Arweave transactions +%%% and performing inference with session management for optimal performance. 
+%%% Models are cached locally to avoid repeated downloads. +-module(dev_wasi_nn). +-export([info/1, info/3, infer/3]). +-export([read_model_by_ID/2]). +-include("include/hb.hrl"). +-include_lib("eunit/include/eunit.hrl"). +%% @doc Get device information and exported functions. +%% Returns the list of functions that are exposed via the device API. +%% This is used by the HyperBEAM runtime to determine which endpoints +%% are available for this device. +%% +%% @param _ Ignored parameter. +%% @returns A map containing the list of exported functions. +info(_) -> + #{ exports => [info, infer] }. + +%% @doc Provide HTTP info response about this device. +%% Returns comprehensive information about the WASI-NN device including +%% its capabilities, version, and API documentation. This endpoint helps +%% users understand how to interact with the AI inference functionality. +%% +%% @param _Msg1 Ignored parameter. +%% @param _Msg2 Ignored parameter. +%% @param _Opts Ignored parameter. +%% @returns {ok, InfoBody} containing device information and API documentation. +info(_Msg1, _Msg2, _Opts) -> + InfoBody = #{ + <<"description">> => <<"GPU device for handling LLM Inference">>, + <<"version">> => <<"1.0">>, + <<"api">> => #{ + <<"infer">> => #{ + <<"description">> => <<"LLM Inference">>, + <<"method">> => <<"GET or POST">>, + <<"required_params">> => #{ + <<"prompt">> => <<"Prompt for Infer">>, + <<"model-id">> => <<"Arweave TXID of the model file">> + } + } + } + }, + {ok, InfoBody}. + +%% @doc Perform AI inference using a specified model and prompt. +%% This function handles the complete inference workflow including model +%% retrieval (either from local cache or Arweave), session management, +%% and inference execution. Models are automatically downloaded and cached +%% locally for improved performance on subsequent requests. +%% +%% @param _M1 Ignored parameter. +%% @param M2 The request message containing inference parameters: +%% - <<"model-id">>: Arweave transaction ID of the model file +%% - <<"config">>: JSON configuration for the model (optional) +%% - <<"prompt">>: The input prompt for inference +%% - <<"session-id">>: Session identifier for context reuse (optional) +%% @param Opts A map of configuration options. +%% @returns {ok, #{<<"result">> := Result, <<"session-id">> := SessionId}} on success, +%% {error, Reason} on failure. 
+infer(_M1, M2, Opts) ->
+    TxID = hb_ao:get(<<"model-id">>, M2, undefined, Opts),
+    ModelConfig = hb_ao:get(<<"config">>, M2,
+        "{\"n_gpu_layers\":96,\"ctx_size\":64000,\"batch_size\":64000}", Opts),
+    Prompt = hb_ao:get(<<"prompt">>, M2, Opts),
+    SessionId = hb_ao:get(<<"session-id">>, M2, undefined, Opts),
+    ?event(dev_wasi_nn, {infer, {tx_id, TxID}, {session_id, SessionId}}),
+    %% Fall back to the default model when no model-id is given; both paths
+    %% then share the same retrieval and inference flow.
+    ModelTxID = case TxID of
+        undefined ->
+            ?event(dev_wasi_nn, {infer, {fallback_to_default_model}}),
+            <<"ISrbGzQot05rs_HKC08O_SmkipYQnqgB1yC3mjZZeEo">>;
+        _ ->
+            ?event(dev_wasi_nn, {infer, {downloading_model, TxID}}),
+            TxID
+    end,
+    case read_model_by_ID(ModelTxID, Opts) of
+        {ok, LocalModelPath} ->
+            ?event(dev_wasi_nn, {infer, {model_ready, LocalModelPath}}),
+            load_and_infer(LocalModelPath, ModelConfig, Prompt, SessionId, Opts);
+        {error, Reason} ->
+            ?event(dev_wasi_nn, {infer, {model_download_failed, Reason}}),
+            {error, {model_download_failed, Reason}}
+    end.
+%%%--------------------------------------------------------------------
+%%% Helper Functions
+%%%--------------------------------------------------------------------
+%% @doc Load a model and perform inference using persistent context management.
+%% Handles the complete inference workflow including model loading, session
+%% management, and inference execution. Uses session IDs to maintain context
+%% across multiple requests for improved performance. If no session ID is
+%% provided, a new unique session ID is generated.
+%%
+%% @param ModelPath The local file path to the model.
+%% @param ModelConfig JSON configuration string for the model.
+%% @param Prompt The input prompt for inference.
+%% @param ProvidedSessionId Optional session ID for context reuse. If undefined,
+%%        a new session ID will be generated.
+%% @returns {ok, #{<<"result">> := Result, <<"session-id">> := SessionId}} on success,
+%%          {error, Reason} on failure.
+load_and_infer(ModelPath, ModelConfig, Prompt, ProvidedSessionId, _Opts) ->
+    SessionId = case ProvidedSessionId of
+        undefined -> hb_util:human_id(crypto:strong_rand_bytes(32));
+        _ -> ProvidedSessionId
+    end,
+    ?event(dev_wasi_nn, {load_and_infer, {model_path, ModelPath}, {session_id, SessionId}}),
+    % Use persistent context management (fast if the model is already loaded)
+    case dev_wasi_nn_nif:switch_model(ModelPath, ModelConfig) of
+        {ok, Context} ->
+            % Create or reuse a session-specific execution context
+            case dev_wasi_nn_nif:init_execution_context_once(Context, binary_to_list(SessionId)) of
+                {ok, ExecContextId} ->
+                    % Run inference with the session-specific context
+                    case dev_wasi_nn_nif:run_inference(Context, ExecContextId, binary_to_list(Prompt)) of
+                        {ok, Output} ->
+                            ?event(output, Output),
+                            {ok, #{
+                                <<"result">> => Output,
+                                <<"session-id">> => SessionId
+                            }};
+                        {error, Reason} ->
+                            ?event(dev_wasi_nn, {inference_failed, SessionId, Reason}),
+                            {error, Reason}
+                    end;
+                {error, Reason2} ->
+                    ?event(dev_wasi_nn, {session_init_failed, SessionId, Reason2}),
+                    {error, Reason2}
+            end;
+        {error, Reason3} ->
+            ?event(dev_wasi_nn, {model_load_failed, SessionId, ModelPath, Reason3}),
+            {error, Reason3}
+    end.
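+
+%% Example (editor's illustration, not exercised by the tests below): driving
+%% infer/3 directly and reusing the returned session id so that a follow-up
+%% prompt keeps the same llama execution context warm. The model-id is the
+%% default model TxID used elsewhere in this module; the prompts and variable
+%% names are placeholders.
+%%
+%%   Req1 = #{
+%%       <<"model-id">> => <<"ISrbGzQot05rs_HKC08O_SmkipYQnqgB1yC3mjZZeEo">>,
+%%       <<"prompt">> => <<"Hello, who are you?">>
+%%   },
+%%   {ok, #{<<"result">> := _First, <<"session-id">> := SessionId}} =
+%%       dev_wasi_nn:infer(#{}, Req1, #{}),
+%%   %% Follow-up request in the same session: the NIF layer reuses the
+%%   %% execution context cached for this session id.
+%%   Req2 = Req1#{
+%%       <<"prompt">> => <<"Summarise your previous answer in one sentence.">>,
+%%       <<"session-id">> => SessionId
+%%   },
+%%   {ok, #{<<"result">> := _Second}} = dev_wasi_nn:infer(#{}, Req2, #{}).
+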
+%% @doc Configure options with model storage settings.
+%% This helper function extends base options with appropriate model storage
+%% configuration. It allows users to customize the model store if desired,
+%% otherwise it uses sensible defaults for local filesystem caching.
+%%
+%% @param BaseOpts The base options to extend with model storage configuration.
+%% @returns Extended options map with model store configuration.
+opts(BaseOpts) ->
+    %% Allow the user to configure the model store, or use the default
+    DefaultModelStore = #{
+        <<"store-module">> => hb_store_fs,
+        <<"name">> => <<"model-cache">>
+    },
+    ModelStore = hb_opts:get(model_store, DefaultModelStore, BaseOpts),
+    %% Extend base options with the model store configuration
+    BaseOpts#{
+        store => [ModelStore | hb_opts:get(store, [], BaseOpts)]
+    }.
+
+%% @doc Download and retrieve a model by Arweave transaction ID.
+%% This function handles the complete model retrieval workflow including:
+%% - Starting the HTTP server for Arweave gateway access
+%% - Configuring local filesystem caching to avoid repeated downloads
+%% - Downloading the model from Arweave if not already cached
+%% - Resolving the local file path where the model is stored
+%%
+%% The function uses a two-tier storage strategy:
+%% 1. First checks the local cache (hb_store_fs) for an existing model
+%% 2. Falls back to the Arweave gateway (hb_store_gateway) if not cached
+%% 3. Automatically caches downloaded models locally for future use
+%%
+%% @param TxID The Arweave transaction ID containing the model file, as a binary.
+%% @param Opts The base options to extend with model storage configuration.
+%% @returns {ok, LocalFilePath} where LocalFilePath is a string path to the
+%%          cached model file, or {error, Reason} on failure.
+read_model_by_ID(TxID, Opts) ->
+    %% Start the HTTP server (required for gateway access)
+    hb_http_server:start_node(#{}),
+    %% Configure options with model storage settings
+    ConfiguredOpts = opts(Opts),
+    %% opts/1 places the model store at the head of the store list; take it
+    %% back out so cache paths are resolved against the same store.
+    [ModelStore | _] = hb_opts:get(store, [], ConfiguredOpts),
+    StoreName = hb_maps:get(<<"name">>, ModelStore, undefined, ConfiguredOpts),
+    %% Attempt to read the model from cache or download it from Arweave
+    case hb_cache:read(TxID, ConfiguredOpts) of
+        {ok, Message} ->
+            ?event(cache, {successfully_read_message_from_arweave}),
+            %% Extract the data reference from the message. This is either a
+            %% link to existing cached data or the binary data itself.
+            DataLink = hb_maps:get(<<"data">>, Message, undefined, ConfiguredOpts),
+            ?event(cache, {data_link, DataLink}),
+            %% Handle the two different data storage formats
+            case DataLink of
+                %% Case 1: Data is stored as a link reference to an existing cached file
+                {link, DataPath, _LinkOpts} ->
+                    ?event(cache, {extracted_data_path, DataPath}),
+                    %% Resolve the relative path to an absolute filesystem path.
+                    %% The store resolves internal paths to actual file locations.
+                    ResolvedPath = hb_store:resolve(ModelStore, DataPath),
+                    %% Construct the full path: "model-cache/resolved/path/to/file"
+                    ActualFilePath = iolist_to_binary([StoreName, "/", ResolvedPath]),
+                    ?event(cache, {actual_file_path, ActualFilePath}),
+                    %% Convert the binary path to a string for external API compatibility
+                    {ok, binary_to_list(ActualFilePath)};
+                %% Case 2: Data is stored as direct binary content (needs a hash-based path)
+                _ ->
+                    %% Load the binary data into memory if not already loaded
+                    LoadedData = hb_cache:ensure_loaded(DataLink, ConfiguredOpts),
+                    ?event(cache, {loaded_data_size, byte_size(LoadedData)}),
+                    %% Generate the content-based hash path for the storage location.
+                    %% This ensures identical files share the same storage location.
+                    Hashpath = hb_path:hashpath(LoadedData, ConfiguredOpts),
+                    ?event(cache, {calculated_hashpath, Hashpath}),
+                    %% Construct the standardized data path using the content hash
+                    DataPath = <<"data/", Hashpath/binary>>,
+                    ?event(cache, {data_path, DataPath}),
+                    %% Resolve to the actual filesystem path and construct the full path
+                    ResolvedPath = hb_store:resolve(ModelStore, DataPath),
+                    ActualFilePath = iolist_to_binary([StoreName, "/", ResolvedPath]),
+                    ?event(cache, {actual_file_path, ActualFilePath}),
+                    %% Convert the binary path to a string for external API compatibility
+                    {ok, binary_to_list(ActualFilePath)}
+            end;
+        not_found ->
+            %% Model transaction ID not found on the Arweave network
+            ?event({string, <<"Message not found on Arweave">>}),
+            {error, not_found}
+    end.
+
+%%% Tests
+
+%% @doc Unit test for the complete inference API.
+%% Validates the end-to-end pipeline from model retrieval to inference
+%% execution by calling infer/3 directly, simulating real API usage.
+%%
+%% IMPORTANT: This test requires the model to be available locally.
+%% Run read_model_by_ID_test/0 first to ensure the model is downloaded.
+%%
+%% The test performs the following steps:
+%% 1. Creates a test message with a model ID and prompt
+%% 2. Calls the infer/3 function with the test parameters
+%% 3. Validates the response format and content
+%% 4. Ensures the inference result is meaningful
+%%
+%% @returns ok on success, throws an error on failure.
+infer_test() ->
+    % Create a test message with inference parameters
+    % - model-id: Arweave transaction ID of the model to use
+    % - prompt: Input text for inference
+    M2 = #{
+        <<"model-id">> => <<"ISrbGzQot05rs_HKC08O_SmkipYQnqgB1yC3mjZZeEo">>,
+        <<"prompt">> => <<"Hello, who are you?">>
+    },
+    % Empty options map for this test
+    Opts = #{},
+    % Execute the inference API call
+    case infer(#{}, M2, Opts) of
+        {ok, #{<<"result">> := Result, <<"session-id">> := SessionId}} ->
+            ?event(dev_wasi_nn, {infer_test, {result, Result}, {session_id, SessionId}}),
+            % Ensure the result is a non-empty binary
+            ?assert(is_binary(Result)),
+            ?assert(byte_size(Result) > 0),
+            % Validate that a session ID is present
+            ?assert(is_binary(SessionId)),
+            ?assert(byte_size(SessionId) > 0);
+        {error, Reason} ->
+            ?event(dev_wasi_nn, {infer_test, {inference_failed, Reason}}),
+            ?assert(false, Reason)
+    end.
+
+%% @doc Test model retrieval by Arweave transaction ID.
+read_model_by_ID_test() ->
+    ID = <<"ISrbGzQot05rs_HKC08O_SmkipYQnqgB1yC3mjZZeEo">>,
+    case read_model_by_ID(ID, #{}) of
+        {ok, LocalModelPath} ->
+            ?event(dev_wasi_nn, {read_model_by_ID_test, {model_ready, LocalModelPath}}),
+            ?assert(is_list(LocalModelPath)),
+            ?assert(length(LocalModelPath) > 0);
+        {error, Reason} ->
+            ?event(dev_wasi_nn, {read_model_by_ID_test, {model_read_failed, Reason}}),
+            ?assert(false, {model_read_failed, Reason})
+    end.
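+
+%% Example (editor's illustration): pointing the model cache at a custom
+%% store. opts/1 reads the model_store key from the caller's options, so a
+%% request can override the default "model-cache" filesystem store. The store
+%% name below is a placeholder; any hb_store_fs prefix reachable by the node
+%% will do.
+%%
+%%   CustomOpts = #{
+%%       model_store => #{
+%%           <<"store-module">> => hb_store_fs,
+%%           <<"name">> => <<"big-disk-models">>
+%%       }
+%%   },
+%%   dev_wasi_nn:infer(#{}, #{<<"prompt">> => <<"Hi">>}, CustomOpts).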
\ No newline at end of file diff --git a/src/dev_wasi_nn_nif.erl b/src/dev_wasi_nn_nif.erl new file mode 100644 index 000000000..bcfe6d5a0 --- /dev/null +++ b/src/dev_wasi_nn_nif.erl @@ -0,0 +1,357 @@ +%%% @doc WASI-NN NIF module for HyperBEAM. +%%% Implements native functions for AI model loading and inference. +%%% This module provides the NIF interface for the dev_wasi_nn module. +-module(dev_wasi_nn_nif). +-include("include/hb.hrl"). +-include_lib("eunit/include/eunit.hrl"). +-export([ + init_backend/0, + load_by_name_with_config/3, + init_execution_context/2, + close_execution_context/2, + deinit_backend/1, + run_inference/3 +]). +-export([init_execution_context_once/2, switch_model/2]). +-export([cleanup_model_contexts/1, cleanup_all_contexts/0, get_current_model_info/0]). + +-on_load(init/0). +-define(NOT_LOADED, not_loaded(?LINE)). +%% Module-level cache +-define(CACHE_TAB, wasi_nn_cache). +-define(SINGLETON_KEY, global_cache). +-define(CACHE_OWNER_NAME, wasi_nn_cache_owner). % Registered name for cache owner process +%% @doc Start the dedicated ETS table owner process. +%% Creates a persistent process that owns the ETS table to ensure it remains +%% available even if the calling process terminates. +%% +%% @returns {ok, Pid} where Pid is the process ID of the cache owner. +start_cache_owner() -> + case whereis(?CACHE_OWNER_NAME) of + undefined -> + % No owner process exists, create one + Pid = spawn(fun() -> + % Create the table if it doesn't exist + case ets:info(?CACHE_TAB) of + undefined -> + ?event(dev_wasi_nn_nif, {cache_owner_creating_table, ?CACHE_TAB}), + ets:new(?CACHE_TAB, [set, named_table, public]); + _ -> + ?event(dev_wasi_nn_nif, {cache_table_already_exists, ?CACHE_TAB}) + end, + % Register the process with a name for easy lookup + register(?CACHE_OWNER_NAME, self()), + cache_owner_loop() + end), + {ok, Pid}; + Pid -> + % Owner process already exists + {ok, Pid} + end. + +%% @doc Loop function for the cache owner process. +%% Keeps the process alive to maintain ownership of the ETS table. +%% Handles stop messages and ping requests. +%% +%% @returns ok when the process is stopped. +cache_owner_loop() -> + receive + stop -> + ?event(dev_wasi_nn_nif, {cache_owner_stopping}), + ok; + {From, ping} -> + From ! {self(), pong}, + cache_owner_loop(); + _ -> + cache_owner_loop() + after + 3600000 -> % Stay alive for a long time (1 hour), then check again + cache_owner_loop() + end. + +%% @doc Initialize the NIF library and cache management. +%% This function is automatically called when the module is loaded. +%% It starts the cache owner process and loads the NIF library. +%% +%% @returns ok on success, exits with error on failure. +init() -> + PrivDir = code:priv_dir(hb), + Path = filename:join(PrivDir, "wasi_nn"), + ?event(dev_wasi_nn_nif, {loading_nif_from, Path}), + % Start the dedicated cache owner process + start_cache_owner(), + % Load the NIF library + case erlang:load_nif(Path, 0) of + ok -> + ?event(dev_wasi_nn_nif, {nif_loaded_successfully}), + ok; + {error, {load_failed, Reason}} -> + ?event(dev_wasi_nn_nif, {failed_to_load_nif, Reason}), + exit({load_failed, {load_failed, Reason}}); + {error, Reason} -> + ?event(dev_wasi_nn_nif, {failed_to_load_nif_with_error, Reason}), + exit({load_failed, Reason}) + end. + +%% Error handler for NIF functions that are not loaded +not_loaded(Line) -> + erlang:nif_error({not_loaded, [{module, ?MODULE}, {line, Line}]}). + +init_backend() -> + ?NOT_LOADED. + +load_by_name_with_config(_Context, _Path, _Config) -> + ?NOT_LOADED. 
+ +init_execution_context(_Context, _SessionId) -> + ?NOT_LOADED. + +close_execution_context(_Context, _ExecContextId) -> + ?NOT_LOADED. + +deinit_backend(_Context) -> + ?NOT_LOADED. +run_inference(_Context, _ExecContextId, _Prompt) -> + ?NOT_LOADED. + +%% ============================================================================ +%% GLOBAL PERSISTENT CONTEXT MANAGEMENT +%% ============================================================================ + +%% @doc Switch to a different model, creating a new context for each model. +%% Checks if the model is already loaded with the same configuration, +%% and reuses the existing context if possible. Otherwise, creates a new context. +%% +%% @param ModelPath Path to the model file. +%% @param Config Configuration for the model. +%% @returns {ok, Context} on success, {error, Reason} on failure. +switch_model(ModelPath, Config) -> + ensure_cache_table(), + ModelKey = {?SINGLETON_KEY, model_context, ModelPath}, + case ets:lookup(?CACHE_TAB, ModelKey) of + [{_, {ok, Context, CachedConfig}}] when CachedConfig =:= Config -> + ?event(dev_wasi_nn_nif, {model_already_loaded, ModelPath, reusing_context}), + % Update current model reference + ets:insert(?CACHE_TAB, {{?SINGLETON_KEY, current_model}, {ModelPath, Config, Context}}), + {ok, Context}; + [{_, {ok, OldContext, _OldConfig}}] -> + ?event(dev_wasi_nn_nif, {model_different_config, ModelPath, reinitializing}), + % Cleanup old context for this model + deinit_backend(OldContext), + % Create new context for this model + create_model_context(ModelPath, Config); + [] -> + ?event(dev_wasi_nn_nif, {model_not_loaded, ModelPath, creating_new_context}), + create_model_context(ModelPath, Config) + end. + +%% @doc Create a new model context. +%% Initializes a new backend context and loads the model. +%% +%% @param ModelPath Path to the model file. +%% @param Config Configuration for the model. +%% @returns {ok, Context} on success, {error, Reason} on failure. +create_model_context(ModelPath, Config) -> + ensure_cache_table(), + ModelKey = {?SINGLETON_KEY, global_backend, ModelPath}, + % Get or create the global backend context for this model + case ets:lookup(?CACHE_TAB, ModelKey) of + [{_, {ok, Context}}] -> + ?event(dev_wasi_nn_nif, {using_existing_global_backend, ModelPath}), + load_model_with_context(Context, ModelPath, Config); + [] -> + ?event(dev_wasi_nn_nif, {creating_new_global_backend, ModelPath}), + case init_backend() of + {ok, Context} -> + ets:insert(?CACHE_TAB, {ModelKey, {ok, Context}}), + ?event(dev_wasi_nn_nif, {global_backend_created, ModelPath}), + load_model_with_context(Context, ModelPath, Config); + Error -> + ?event(dev_wasi_nn_nif, {failed_to_create_global_backend, ModelPath, Error}), + Error + end + end. + +%% @doc Load a model with an existing context. +%% Uses an existing backend context to load a model. +%% +%% @param Context The backend context. +%% @param ModelPath Path to the model file. +%% @param Config Configuration for the model. +%% @returns {ok, Context} on success, {error, Reason} on failure. 
+load_model_with_context(Context, ModelPath, Config) -> + case load_by_name_with_config(Context, ModelPath, Config) of + ok -> + ModelKey = {?SINGLETON_KEY, model_context, ModelPath}, + ets:insert(?CACHE_TAB, {ModelKey, {ok, Context, Config}}), + ets:insert(?CACHE_TAB, {{?SINGLETON_KEY, current_model}, {ModelPath, Config, Context}}), + ?event(dev_wasi_nn_nif, {model_context_created, ModelPath}), + {ok, Context}; + Error -> + ?event(dev_wasi_nn_nif, {failed_to_load_model, ModelPath, Error}), + % Cleanup the backend context since model loading failed + deinit_backend(Context), + ets:delete(?CACHE_TAB, {?SINGLETON_KEY, global_backend, ModelPath}), + {error, {model_load_failed, Error}} + end. + +%% @doc Get information about the currently loaded model. +%% Retrieves the model path, configuration, and context for the currently loaded model. +%% +%% @returns {ok, {ModelPath, Config, Context}} on success, +%% {error, no_model_loaded} if no model is loaded. +get_current_model_info() -> + ensure_cache_table(), + case ets:lookup(?CACHE_TAB, {?SINGLETON_KEY, current_model}) of + [{_, {ModelPath, Config, Context}}] -> {ok, {ModelPath, Config, Context}}; + [] -> {error, no_model_loaded} + end. + +%% @doc Clean up all contexts for a specific model. +%% Removes all execution contexts and the model context for a specific model. +%% +%% @param ModelPath Path to the model file. +%% @returns ok. +cleanup_model_contexts(ModelPath) -> + ensure_cache_table(), + % Clean up all execution contexts for this model + ets:match_delete(?CACHE_TAB, {{?SINGLETON_KEY, context_initialized, ModelPath, '_'}, '_'}), + % Clean up the model context + case ets:lookup(?CACHE_TAB, {?SINGLETON_KEY, model_context, ModelPath}) of + [{_, {ok, Context, _Config}}] -> + deinit_backend(Context), + ets:delete(?CACHE_TAB, {?SINGLETON_KEY, model_context, ModelPath}), + ets:delete(?CACHE_TAB, {?SINGLETON_KEY, global_backend, ModelPath}), + ?event(dev_wasi_nn_nif, {cleaned_up_contexts, ModelPath}), + ok; + [] -> + ?event(dev_wasi_nn_nif, {no_context_to_cleanup, ModelPath}), + ok + end. + +%% @doc Clean up all cached contexts. +%% Removes all model contexts and execution contexts from the cache. +%% Useful for testing or memory management. +%% +%% @returns ok. +cleanup_all_contexts() -> + ensure_cache_table(), + % Get all model contexts and clean them up + ModelContexts = ets:match(?CACHE_TAB, {{?SINGLETON_KEY, model_context, '$1'}, {ok, '$2', '$3'}}), + lists:foreach(fun([ModelPath, Context, _Config]) -> + deinit_backend(Context), + ?event(dev_wasi_nn_nif, {cleaned_up_context, ModelPath}) + end, ModelContexts), + % Clear the entire cache + ets:delete_all_objects(?CACHE_TAB), + ?event(dev_wasi_nn_nif, {all_contexts_cleaned_up}), + ok. + +%% @doc Helper function to safely access the ETS table. +%% Ensures that the ETS table exists and has an owner process. +%% If the table doesn't exist, it starts the cache owner process. +%% If the table exists but has no owner, it restarts the owner process. +%% +%% @returns ok if the table exists and has an owner, +%% {ok, Pid} if a new owner process was started. 
+ensure_cache_table() -> + case ets:info(?CACHE_TAB) of + undefined -> + % Start the cache owner which will create the table + ?event(dev_wasi_nn_nif, {table_doesnt_exist, starting_cache_owner}), + start_cache_owner(); + _ -> + % Table exists, ensure owner process is running + case whereis(?CACHE_OWNER_NAME) of + undefined -> + % Strange case: table exists but no owner - restart owner + ?event(dev_wasi_nn_nif, {table_exists_no_owner, restarting_owner}), + start_cache_owner(); + _ -> + % All good, table exists and owner is running + ok + end + end. + +%% @doc Function to ensure execution context is only initialized once per session and model. +%% Checks if an execution context already exists for the given session and model, +%% and reuses it if possible. Otherwise, creates a new execution context. +%% +%% @param Context The model context. +%% @param SessionId The session identifier. +%% @returns {ok, ExecContextId} on success, {error, Reason} on failure. +init_execution_context_once(Context, SessionId) -> + ensure_cache_table(), + % Get current model info to create a unique session key per model + case get_current_model_info() of + {ok, {ModelPath, _Config, _Context}} -> + SessionKey = {?SINGLETON_KEY, context_initialized, ModelPath, SessionId}, + case ets:lookup(?CACHE_TAB, SessionKey) of + [{_, {ok, ExecContextId}}] -> + ?event(dev_wasi_nn_nif, {execution_context_already_initialized, SessionId, ModelPath}), + {ok, ExecContextId}; + [] -> + Result = init_execution_context(Context, SessionId), + case Result of + {ok, ExecContextId} -> + ets:insert(?CACHE_TAB, {SessionKey, {ok, ExecContextId}}), + ?event(dev_wasi_nn_nif, {execution_context_initialized, SessionId, ModelPath}), + {ok, ExecContextId}; + Error -> + ?event(dev_wasi_nn_nif, {failed_to_initialize_execution_context, SessionId, ModelPath, Error}), + Error + end + end; + {error, no_model_loaded} -> + ?event(dev_wasi_nn_nif, {no_model_loaded, cannot_initialize_execution_context}), + {error, no_model_loaded} + end. + + +%% @doc Test WASI-NN inference with a single model. +%% This test validates the complete inference pipeline including model loading, +%% session management, and inference execution. The test uses a model from Arweave +%% to avoid network dependencies during inference testing. +%% +%% The test performs the following steps: +%% 1. Loads a model from Arweave +%% 2. Creates a model context using the NIF +%% 3. Initializes an execution context for the session +%% 4. Runs inference with a test prompt +%% 5. Validates the output is not empty +%% 6. Cleans up all contexts +%% +%% @returns ok on success, throws an error on failure. 
+run_inference_test() ->
+    % Fetch the model via dev_wasi_nn; it is cached locally after the first
+    % download, so subsequent runs are fast.
+    {ok, Path} = dev_wasi_nn:read_model_by_ID(
+        <<"ISrbGzQot05rs_HKC08O_SmkipYQnqgB1yC3mjZZeEo">>, #{}),
+    ?event(dev_wasi_nn_nif, {model_path, Path}),
+    % Model configuration for GPU inference
+    % - n_gpu_layers: Number of layers to offload to the GPU
+    % - ctx_size: Context window size for the model
+    % - stream-stdout: Enable streaming output
+    % - enable_debug_log: Enable debug logging
+    Config =
+        "{\"n_gpu_layers\":98,\"ctx_size\":2048,\"stream-stdout\":true,\"enable_debug_log\":true}",
+    % Session identifier for context management
+    SessionId = "test_session_1",
+    % Test prompt for inference
+    Prompt = "Who are you?",
+    % Step 1: Load the model and create a context.
+    % This will either create a new context or reuse an existing one.
+    {ok, Context} = switch_model(Path, Config),
+    ?event(dev_wasi_nn_nif, {model_loaded, Context, Path, Config}),
+    % Step 2: Create or reuse the execution context for the session
+    {ok, ExecContextId} = init_execution_context_once(Context, SessionId),
+    % Step 3: Run inference with the provided prompt
+    {ok, Output} = run_inference(Context, ExecContextId, Prompt),
+    ?event(output, Output),
+    % Step 4: Validate that the inference output is a non-empty binary
+    ?assert(is_binary(Output)),
+    ?assert(byte_size(Output) > 0),
+    % Step 5: Clean up all contexts to free resources
+    cleanup_all_contexts(),
+    ok.
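+
+%% Cache layout (editor's note, derived from the ets:insert calls above): all
+%% state lives in the public named ETS table ?CACHE_TAB, owned by the
+%% ?CACHE_OWNER_NAME process, under the following keys:
+%%
+%%   {global_cache, global_backend, ModelPath}               -> {ok, Context}
+%%   {global_cache, model_context, ModelPath}                -> {ok, Context, Config}
+%%   {global_cache, current_model}                           -> {ModelPath, Config, Context}
+%%   {global_cache, context_initialized, ModelPath, Session} -> {ok, ExecContextId}
+%%
+%% For debugging, the whole cache can be inspected with, for example:
+%%
+%%   ets:tab2list(wasi_nn_cache).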