forked from samcamwilliams/HyperBEAM
-
Notifications
You must be signed in to change notification settings - Fork 71
Device wasi_nn for AI inference #393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Alex-wuhu
wants to merge
7
commits into
permaweb:edge
Choose a base branch
from
apuslabs:edge/PR
base: edge
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 6 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
86dec73
add device wasi nn and compile set up
Alex-wuhu 224b581
Merge remote-tracking branch 'origin/permaweb/edge' into edge/PR
beaf12c
refactor: update model paths and clean up inference logic in dev_wasi…
b3c2eb0
clean up code add UT
f2959f6
remove debug
90fe7f1
using hb_cache for model files
d4442a8
refactor: enhance model retrieval and inference logic by adding optio…
Alex-wuhu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ compile: | |
|
|
||
| WAMR_VERSION = 2.2.0 | ||
| WAMR_DIR = _build/wamr | ||
| WASI_NN_DIR = _build/wasi_nn/wasi_nn_backend | ||
|
|
||
| GENESIS_WASM_BRANCH = tillathehun0/cu-experimental | ||
| GENESIS_WASM_REPO = https://github.com/permaweb/ao.git | ||
|
|
@@ -104,3 +105,19 @@ setup-genesis-wasm: $(GENESIS_WASM_SERVER_DIR) | |
| fi | ||
| @cd $(GENESIS_WASM_SERVER_DIR) && npm install > /dev/null 2>&1 && \ | ||
| echo "Installed [email protected] server." | ||
| # Set up wasi-nn environment | ||
| $(WASI_NN_DIR): | ||
| @echo "Cloning wasi-nn backend repository..." && \ | ||
| git clone --depth=1 -b stage/wasi-nn https://github.com/apuslabs/wasi_nn_backend.git $(WASI_NN_DIR) && \ | ||
| echo "Cloned wasi-nn backend to $(WASI_NN_DIR)" | ||
|
|
||
| setup-wasi-nn: $(WASI_NN_DIR) | ||
| @mkdir -p $(WASI_NN_DIR)/lib | ||
| @echo "Building wasi-nn backend..." && \ | ||
| cmake \ | ||
| $(WAMR_FLAGS) \ | ||
| -S $(WASI_NN_DIR) \ | ||
| -B $(WASI_NN_DIR)/build && \ | ||
| make -C $(WASI_NN_DIR)/build -j8 && \ | ||
| cp $(WASI_NN_DIR)/build/libwasi_nn_backend.so ./native/wasi_nn_llama && \ | ||
| echo "Successfully built wasi-nn backend" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| /* | ||
| * Copyright (C) 2019 Intel Corporation. All rights reserved. | ||
| * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| */ | ||
|
|
||
| #ifndef WASI_NN_TYPES_H | ||
| #define WASI_NN_TYPES_H | ||
|
|
||
| #include <stdint.h> | ||
| #include <stdbool.h> | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /** | ||
| * ERRORS | ||
| * | ||
| */ | ||
|
|
||
| // sync up with | ||
| // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L136 Error | ||
| // codes returned by functions in this API. | ||
| typedef enum { | ||
| // No error occurred. | ||
| success = 0, | ||
| // Caller module passed an invalid argument. | ||
| invalid_argument, | ||
| // Invalid encoding. | ||
| invalid_encoding, | ||
| // The operation timed out. | ||
| timeout, | ||
| // Runtime Error. | ||
| runtime_error, | ||
| // Unsupported operation. | ||
| unsupported_operation, | ||
| // Graph is too large. | ||
| too_large, | ||
| // Graph not found. | ||
| not_found, | ||
| // The operation is insecure or has insufficient privilege to be performed. | ||
| // e.g., cannot access a hardware feature requested | ||
| security, | ||
| // The operation failed for an unspecified reason. | ||
| unknown, | ||
| // for WasmEdge-wasi-nn | ||
| end_of_sequence = 100, // End of Sequence Found. | ||
| context_full = 101, // Context Full. | ||
| prompt_tool_long = 102, // Prompt Too Long. | ||
| model_not_found = 103, // Model Not Found. | ||
| } wasi_nn_error; | ||
|
|
||
| /** | ||
| * TENSOR | ||
| * | ||
| */ | ||
|
|
||
| // The dimensions of a tensor. | ||
| // | ||
| // The array length matches the tensor rank and each element in the array | ||
| // describes the size of each dimension. | ||
| typedef struct { | ||
| uint32_t *buf; | ||
| uint32_t size; | ||
| } tensor_dimensions; | ||
|
|
||
| #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 | ||
| // sync up with | ||
| // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L27 | ||
| // The type of the elements in a tensor. | ||
| typedef enum { fp16 = 0, fp32, fp64, bf16, u8, i32, i64 } tensor_type; | ||
| #else | ||
| typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type; | ||
| #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ | ||
|
|
||
| // The tensor data. | ||
| // | ||
| // Initially conceived as a sparse representation, each empty cell would be | ||
| // filled with zeros and the array length must match the product of all of the | ||
| // dimensions and the number of bytes in the type (e.g., a 2x2 tensor with | ||
| // 4-byte f32 elements would have a data array of length 16). Naturally, this | ||
| // representation requires some knowledge of how to lay out data in | ||
| // memory--e.g., using row-major ordering--and could perhaps be improved. | ||
| typedef uint8_t *tensor_data; | ||
|
|
||
| // A tensor. | ||
| typedef struct { | ||
| // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To | ||
| // represent a tensor containing a single value, use `[1]` for the tensor | ||
| // dimensions. | ||
| tensor_dimensions *dimensions; | ||
| // Describe the type of element in the tensor (e.g., f32). | ||
| tensor_type type; | ||
| // Contains the tensor data. | ||
| tensor_data data; | ||
| } tensor; | ||
|
|
||
| /** | ||
| * GRAPH | ||
| * | ||
| */ | ||
|
|
||
| // The graph initialization data. | ||
| // | ||
| // This consists of an array of buffers because implementing backends may encode | ||
| // their graph IR in parts (e.g., OpenVINO stores its IR and weights | ||
| // separately). | ||
| typedef struct { | ||
| uint8_t *buf; | ||
| uint32_t size; | ||
| } graph_builder; | ||
|
|
||
| typedef struct { | ||
| graph_builder *buf; | ||
| uint32_t size; | ||
| } graph_builder_array; | ||
|
|
||
| // An execution graph for performing inference (i.e., a model). | ||
| typedef uint32_t graph; | ||
|
|
||
| // sync up with | ||
| // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75 | ||
| // Describes the encoding of the graph. This allows the API to be implemented by | ||
| // various backends that encode (i.e., serialize) their graph IR with different | ||
| // formats. | ||
| typedef enum { | ||
| openvino = 0, | ||
| onnx, | ||
| tensorflow, | ||
| pytorch, | ||
| tensorflowlite, | ||
| ggml, | ||
| autodetect, | ||
| unknown_backend, | ||
| } graph_encoding; | ||
|
|
||
| // Define where the graph should be executed. | ||
| typedef enum execution_target { cpu = 0, gpu, tpu } execution_target; | ||
|
|
||
| // Bind a `graph` to the input and output tensors for an inference. | ||
| typedef uint32_t graph_execution_context; | ||
|
|
||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| init_backend(void **ctx) ; | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| init_backend_with_config(void **ctx, const char *config, uint32_t config_len); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| deinit_backend(void *ctx); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| load_by_name_with_config(void *ctx, const char *filename, uint32_t filename_len, | ||
| const char *config, uint32_t config_len, graph *g); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| init_execution_context(void *ctx, const char *session_id, graph_execution_context *exec_ctx); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| close_execution_context(void *ctx, graph_execution_context exec_ctx); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| run_inference(void *ctx, graph_execution_context exec_ctx, uint32_t index, | ||
| tensor *input_tensor,tensor_data output_tensor, uint32_t *output_tensor_size, const char *options); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index, | ||
| tensor *wasi_nn_tensor); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| compute(void *ctx, graph_execution_context exec_ctx); | ||
|
|
||
| __attribute__((visibility("default"))) wasi_nn_error | ||
| get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index, | ||
| tensor_data output_tensor, uint32_t *output_tensor_size); | ||
|
|
||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
| #endif | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| #include <pthread.h> | ||
| #include <stdarg.h> | ||
| #include <string.h> | ||
| #include <time.h> | ||
| #ifndef HB_LOGGING_H | ||
| #define HB_LOGGING_H | ||
|
|
||
|
|
||
| // Enable debug logging by default if not defined | ||
| #define HB_DEBUG 0 | ||
| #ifndef HB_DEBUG | ||
| #endif | ||
|
|
||
|
|
||
| #define DRV_DEBUG(format, ...) beamr_print(HB_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) | ||
| #define DRV_PRINT(format, ...) beamr_print(1, __FILE__, __LINE__, format, ##__VA_ARGS__) | ||
|
|
||
| /* | ||
| * Function: beamr_print | ||
| * -------------------- | ||
| * This function prints a formatted message to the standard output, prefixed with the thread | ||
| * ID, file name, and line number where the log was generated. | ||
| * | ||
| * print: A flag that controls whether the message is printed (1 to print, 0 to skip). | ||
| * file: The source file name where the log was generated. | ||
| * line: The line number where the log was generated. | ||
| * format: The format string for the message. | ||
| * ...: The variables to be printed in the format. | ||
| */ | ||
| void beamr_print(int print, const char* file, int line, const char* format, ...); | ||
|
|
||
|
|
||
|
|
||
| #endif // HB_LOGGING_H | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| #ifndef WASI_NN_NIF_H | ||
| #define WASI_NN_NIF_H | ||
|
|
||
| #include "wasi_nn_llama.h" | ||
| #include <erl_nif.h> | ||
| #include <stdio.h> | ||
| #include <string.h> | ||
| #include <dlfcn.h> | ||
|
|
||
|
|
||
| // Function pointer types | ||
| typedef wasi_nn_error (*init_backend_fn)(void **ctx); | ||
| typedef wasi_nn_error (*init_backend_with_config_fn)(void **ctx, const char *config, uint32_t config_len); | ||
| typedef wasi_nn_error (*deinit_backend_fn)(void *ctx); | ||
| typedef wasi_nn_error (*init_execution_context_fn)(void *ctx, const char *session_id, graph_execution_context *exec_ctx); | ||
| typedef wasi_nn_error (*close_execution_context_fn)(void *ctx, graph_execution_context exec_ctx); | ||
| typedef wasi_nn_error (*load_by_name_with_config_fn)(void *ctx, const char *filename, uint32_t filename_len, | ||
| const char *config, uint32_t config_len, graph *g); | ||
| typedef wasi_nn_error (*run_inference_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, | ||
| tensor *input_tensor,tensor_data output_tensor, uint32_t *output_tensor_size, const char *options); | ||
| typedef wasi_nn_error (*set_input_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor *tensor); | ||
| typedef wasi_nn_error (*compute_fn)(void *ctx, graph_execution_context exec_ctx); | ||
| typedef wasi_nn_error (*get_output_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor_data output, uint32_t *output_size); | ||
|
|
||
| // Structure to hold all function pointers | ||
| typedef struct { | ||
| void* handle; | ||
| init_backend_fn init_backend; | ||
| init_backend_with_config_fn init_backend_with_config; | ||
| deinit_backend_fn deinit_backend; | ||
| load_by_name_with_config_fn load_by_name_with_config; | ||
| init_execution_context_fn init_execution_context; | ||
| close_execution_context_fn close_execution_context; | ||
| run_inference_fn run_inference; | ||
| set_input_fn set_input; | ||
| compute_fn compute; | ||
| get_output_fn get_output; | ||
|
|
||
| } wasi_nn_backend_api; | ||
| #endif // WASI_NN_NIF_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #include "../include/wasi_nn_logging.h" | ||
|
|
||
|
|
||
|
|
||
| void beamr_print(int print, const char* file, int line, const char* format, ...) { | ||
| va_list args; | ||
| va_start(args, format); | ||
| if(print) { | ||
| pthread_t thread_id = pthread_self(); | ||
| printf("[DBG#%p @ %s:%d] ", thread_id, file, line); | ||
| vprintf(format, args); | ||
| printf("\r\n"); | ||
| } | ||
| va_end(args); | ||
| } | ||
|
|
||
| // void send_error(Proc* proc, const char* message_fmt, ...) { | ||
| // va_list args; | ||
| // va_start(args, message_fmt); | ||
| // char* message = driver_alloc(256); | ||
| // vsnprintf(message, 256, message_fmt, args); | ||
| // DRV_DEBUG("Sending error message: %s", message); | ||
| // ErlDrvTermData* msg = driver_alloc(sizeof(ErlDrvTermData) * 7); | ||
| // int msg_index = 0; | ||
| // msg[msg_index++] = ERL_DRV_ATOM; | ||
| // msg[msg_index++] = atom_error; | ||
| // msg[msg_index++] = ERL_DRV_STRING; | ||
| // msg[msg_index++] = (ErlDrvTermData)message; | ||
| // msg[msg_index++] = strlen(message); | ||
| // msg[msg_index++] = ERL_DRV_TUPLE; | ||
| // msg[msg_index++] = 2; | ||
|
|
||
| // int msg_res = erl_drv_output_term(proc->port_term, msg, msg_index); | ||
| // DRV_DEBUG("Sent error message. Res: %d", msg_res); | ||
| // va_end(args); | ||
| // } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a blocker on this PR, but perhaps if we are using these debugging prints more widely now, they should be abstracted into utility
HB_PRINT(etc) functions?