Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ compile:

WAMR_VERSION = 2.2.0
WAMR_DIR = _build/wamr
WASI_NN_DIR = _build/wasi_nn/wasi_nn_backend

GENESIS_WASM_BRANCH = tillathehun0/cu-experimental
GENESIS_WASM_REPO = https://github.com/permaweb/ao.git
Expand Down Expand Up @@ -104,3 +105,19 @@ setup-genesis-wasm: $(GENESIS_WASM_SERVER_DIR)
fi
@cd $(GENESIS_WASM_SERVER_DIR) && npm install > /dev/null 2>&1 && \
echo "Installed [email protected] server."
# ---------------------------------------------------------------------------
# wasi-nn environment
# ---------------------------------------------------------------------------

# Clone the wasi-nn llama backend (shallow clone, pinned to the
# stage/wasi-nn branch).  The target is the checkout directory itself,
# so the clone runs only when $(WASI_NN_DIR) does not exist yet.
$(WASI_NN_DIR):
	@echo "Cloning wasi-nn backend repository..." && \
	git clone --depth=1 -b stage/wasi-nn https://github.com/apuslabs/wasi_nn_backend.git $(WASI_NN_DIR) && \
	echo "Cloned wasi-nn backend to $(WASI_NN_DIR)"

# Configure and build the backend with CMake, then copy the resulting
# shared library next to the NIF sources in ./native/wasi_nn_llama.
# NOTE(review): $(WASI_NN_DIR)/lib is created but nothing below writes
# into it — confirm whether the CMake build expects it to exist.
setup-wasi-nn: $(WASI_NN_DIR)
	@mkdir -p $(WASI_NN_DIR)/lib
	@echo "Building wasi-nn backend..." && \
	cmake \
	$(WAMR_FLAGS) \
	-S $(WASI_NN_DIR) \
	-B $(WASI_NN_DIR)/build && \
	make -C $(WASI_NN_DIR)/build -j8 && \
	cp $(WASI_NN_DIR)/build/libwasi_nn_backend.so ./native/wasi_nn_llama && \
	echo "Successfully built wasi-nn backend"
183 changes: 183 additions & 0 deletions native/wasi_nn_llama/include/wasi_nn_llama.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
 * Copyright (C) 2019 Intel Corporation. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */

/* Public types and exported entry points of the wasi-nn llama backend
 * shared library, loaded at runtime by the Erlang NIF (see wasi_nn_nif.h).
 *
 * NOTE(review): this guard name matches WAMR's own wasi_nn_types.h, so
 * whichever of the two headers is included first in a translation unit
 * wins.  Confirm the shadowing is intentional before renaming the guard —
 * the typedefs below must stay identical to WAMR's for that to be safe. */
#ifndef WASI_NN_TYPES_H
#define WASI_NN_TYPES_H

#include <stdint.h>
#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

/**
 * ERRORS
 *
 */

// sync up with
// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L136 Error
// codes returned by functions in this API.
typedef enum {
    // No error occurred.
    success = 0,
    // Caller module passed an invalid argument.
    invalid_argument,
    // Invalid encoding.
    invalid_encoding,
    // The operation timed out.
    timeout,
    // Runtime Error.
    runtime_error,
    // Unsupported operation.
    unsupported_operation,
    // Graph is too large.
    too_large,
    // Graph not found.
    not_found,
    // The operation is insecure or has insufficient privilege to be performed.
    // e.g., cannot access a hardware feature requested
    security,
    // The operation failed for an unspecified reason.
    unknown,
    // for WasmEdge-wasi-nn
    end_of_sequence = 100, // End of Sequence Found.
    context_full = 101,    // Context Full.
    // (sic: "tool" is a typo for "too"; the identifier is kept as-is because
    // it is part of the public ABI shared with callers.)
    prompt_tool_long = 102, // Prompt Too Long.
    model_not_found = 103,  // Model Not Found.
} wasi_nn_error;

/**
 * TENSOR
 *
 */

// The dimensions of a tensor.
//
// The array length matches the tensor rank and each element in the array
// describes the size of each dimension.
typedef struct {
    uint32_t *buf;  // one entry per dimension
    uint32_t size;  // number of entries in `buf` (the tensor rank)
} tensor_dimensions;

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
// sync up with
// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L27
// The type of the elements in a tensor.
typedef enum { fp16 = 0, fp32, fp64, bf16, u8, i32, i64 } tensor_type;
#else
// Legacy (non-ephemeral) element types; `up8`/`ip32` follow WAMR's naming to
// avoid clashing with the ephemeral-NN `u8`/`i32` enumerators.
typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

// The tensor data.
//
// Initially conceived as a sparse representation, each empty cell would be
// filled with zeros and the array length must match the product of all of the
// dimensions and the number of bytes in the type (e.g., a 2x2 tensor with
// 4-byte f32 elements would have a data array of length 16). Naturally, this
// representation requires some knowledge of how to lay out data in
// memory--e.g., using row-major ordering--and could perhaps be improved.
typedef uint8_t *tensor_data;

// A tensor.
typedef struct {
    // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
    // represent a tensor containing a single value, use `[1]` for the tensor
    // dimensions.
    tensor_dimensions *dimensions;
    // Describe the type of element in the tensor (e.g., f32).
    tensor_type type;
    // Contains the tensor data.
    tensor_data data;
} tensor;

/**
 * GRAPH
 *
 */

// The graph initialization data.
//
// This consists of an array of buffers because implementing backends may encode
// their graph IR in parts (e.g., OpenVINO stores its IR and weights
// separately).
typedef struct {
    uint8_t *buf;
    uint32_t size;
} graph_builder;

typedef struct {
    graph_builder *buf;
    uint32_t size;
} graph_builder_array;

// An execution graph for performing inference (i.e., a model).
typedef uint32_t graph;

// sync up with
// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75
// Describes the encoding of the graph. This allows the API to be implemented by
// various backends that encode (i.e., serialize) their graph IR with different
// formats.
typedef enum {
    openvino = 0,
    onnx,
    tensorflow,
    pytorch,
    tensorflowlite,
    ggml,
    autodetect,
    unknown_backend,
} graph_encoding;

// Define where the graph should be executed.
typedef enum execution_target { cpu = 0, gpu, tpu } execution_target;

// Bind a `graph` to the input and output tensors for an inference.
typedef uint32_t graph_execution_context;


// Create a backend instance with default settings.  On success *ctx receives
// an opaque handle owned by the backend; release it with deinit_backend().
__attribute__((visibility("default"))) wasi_nn_error
init_backend(void **ctx) ;

// As init_backend(), but with a backend-specific configuration blob of
// `config_len` bytes (format defined by the backend implementation — confirm
// against the .so before relying on it).
__attribute__((visibility("default"))) wasi_nn_error
init_backend_with_config(void **ctx, const char *config, uint32_t config_len);

// Tear down a backend instance previously created by init_backend*().
__attribute__((visibility("default"))) wasi_nn_error
deinit_backend(void *ctx);

// Load a model from `filename` (`filename_len` bytes, not necessarily
// NUL-terminated) with a load-time config blob; on success *g receives the
// graph handle used by the execution-context calls below.
__attribute__((visibility("default"))) wasi_nn_error
load_by_name_with_config(void *ctx, const char *filename, uint32_t filename_len,
                         const char *config, uint32_t config_len, graph *g);

// Open an execution context identified by `session_id`; *exec_ctx receives
// the handle to pass to set_input/compute/get_output/run_inference.
__attribute__((visibility("default"))) wasi_nn_error
init_execution_context(void *ctx, const char *session_id, graph_execution_context *exec_ctx);

// Release an execution context obtained from init_execution_context().
__attribute__((visibility("default"))) wasi_nn_error
close_execution_context(void *ctx, graph_execution_context exec_ctx);

// One-shot inference: feed `input_tensor` at `index`, run, and copy the
// result into caller-provided `output_tensor`.  *output_tensor_size is
// presumably in/out (capacity in, bytes written out) — confirm in the
// backend implementation.  `options` is a backend-specific options string.
__attribute__((visibility("default"))) wasi_nn_error
run_inference(void *ctx, graph_execution_context exec_ctx, uint32_t index,
              tensor *input_tensor,tensor_data output_tensor, uint32_t *output_tensor_size, const char *options);

// Stage an input tensor at position `index` for a later compute() call.
__attribute__((visibility("default"))) wasi_nn_error
set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
          tensor *wasi_nn_tensor);

// Run inference over the inputs staged with set_input().
__attribute__((visibility("default"))) wasi_nn_error
compute(void *ctx, graph_execution_context exec_ctx);

// Copy the output tensor at `index` into the caller-provided buffer; see the
// *output_tensor_size note on run_inference() above.
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
           tensor_data output_tensor, uint32_t *output_tensor_size);


#ifdef __cplusplus
}
#endif
#endif /* WASI_NN_TYPES_H */

34 changes: 34 additions & 0 deletions native/wasi_nn_llama/include/wasi_nn_logging.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef HB_LOGGING_H
#define HB_LOGGING_H

/* Debug-logging helpers for the wasi-nn NIF.  All headers live inside the
 * include guard (the original placed them before it, defeating the guard's
 * purpose for repeated inclusion). */
#include <pthread.h>
#include <stdarg.h>
#include <string.h>
#include <time.h>

/* Debug logging is OFF by default; builds may override with -DHB_DEBUG=1.
 * Fix: the original unconditionally `#define HB_DEBUG 0` and left an empty
 * `#ifndef HB_DEBUG` block, so an external -DHB_DEBUG=1 both redefined the
 * macro and could never enable DRV_DEBUG output. */
#ifndef HB_DEBUG
#define HB_DEBUG 0
#endif

/* DRV_DEBUG: printed only when HB_DEBUG is non-zero.
 * DRV_PRINT: always printed. */
#define DRV_DEBUG(format, ...) beamr_print(HB_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define DRV_PRINT(format, ...) beamr_print(1, __FILE__, __LINE__, format, ##__VA_ARGS__)

/*
 * Function: beamr_print
 * --------------------
 * This function prints a formatted message to the standard output, prefixed
 * with the thread ID, file name, and line number where the log was generated.
 *
 * print: A flag that controls whether the message is printed (1 to print, 0 to skip).
 * file: The source file name where the log was generated.
 * line: The line number where the log was generated.
 * format: The format string for the message.
 * ...: The variables to be printed in the format.
 *
 * NOTE(review): these DRV_* helpers are duplicated across native modules;
 * consider hoisting them into a shared HB_PRINT utility.
 */
void beamr_print(int print, const char* file, int line, const char* format, ...);

#endif // HB_LOGGING_H
40 changes: 40 additions & 0 deletions native/wasi_nn_llama/include/wasi_nn_nif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef WASI_NN_NIF_H
#define WASI_NN_NIF_H

/* Runtime-dispatch table for the wasi-nn llama backend shared library.
 * The NIF loads the .so (dlfcn.h is included for that purpose — presumably
 * via dlopen/dlsym; confirm in the .c) and fills a wasi_nn_backend_api
 * struct with one function pointer per exported entry point. */
#include "wasi_nn_llama.h"
#include <erl_nif.h>
#include <stdio.h>
#include <string.h>
#include <dlfcn.h>


// Function pointer types.  Each must stay byte-for-byte in sync with the
// matching prototype in wasi_nn_llama.h; a silent signature drift here would
// corrupt the call through the resolved symbol.
typedef wasi_nn_error (*init_backend_fn)(void **ctx);
typedef wasi_nn_error (*init_backend_with_config_fn)(void **ctx, const char *config, uint32_t config_len);
typedef wasi_nn_error (*deinit_backend_fn)(void *ctx);
typedef wasi_nn_error (*init_execution_context_fn)(void *ctx, const char *session_id, graph_execution_context *exec_ctx);
typedef wasi_nn_error (*close_execution_context_fn)(void *ctx, graph_execution_context exec_ctx);
typedef wasi_nn_error (*load_by_name_with_config_fn)(void *ctx, const char *filename, uint32_t filename_len,
                                                     const char *config, uint32_t config_len, graph *g);
typedef wasi_nn_error (*run_inference_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index,
                                          tensor *input_tensor,tensor_data output_tensor, uint32_t *output_tensor_size, const char *options);
// NOTE(review): the parameter below is named `tensor`, which shadows the
// `tensor` typedef within this prototype — legal C, but worth renaming.
typedef wasi_nn_error (*set_input_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor *tensor);
typedef wasi_nn_error (*compute_fn)(void *ctx, graph_execution_context exec_ctx);
typedef wasi_nn_error (*get_output_fn)(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor_data output, uint32_t *output_size);

// Structure to hold all function pointers for one loaded backend library.
typedef struct {
    void* handle;  // shared-library handle (presumably from dlopen; owner closes it)
    init_backend_fn init_backend;
    init_backend_with_config_fn init_backend_with_config;
    deinit_backend_fn deinit_backend;
    load_by_name_with_config_fn load_by_name_with_config;
    init_execution_context_fn init_execution_context;
    close_execution_context_fn close_execution_context;
    run_inference_fn run_inference;
    set_input_fn set_input;
    compute_fn compute;
    get_output_fn get_output;

} wasi_nn_backend_api;
#endif // WASI_NN_NIF_H
36 changes: 36 additions & 0 deletions native/wasi_nn_llama/src/wasi_nn_logging.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "../include/wasi_nn_logging.h"



/*
 * beamr_print — conditionally write one formatted log line to stdout,
 * prefixed with the calling thread's id and the source file:line.
 *
 * print:  1 to emit the line, 0 to suppress it (varargs still consumed).
 * file:   source file name (normally __FILE__ via the DRV_* macros).
 * line:   source line number (normally __LINE__).
 * format: printf-style format string for the message body.
 */
void beamr_print(int print, const char* file, int line, const char* format, ...) {
    va_list args;
    va_start(args, format);
    if (print) {
        /* Fix: the original passed a raw pthread_t to "%p", which is a
         * format-specifier mismatch (undefined behavior) on platforms where
         * pthread_t is an integral type, e.g. glibc.  Cast to unsigned long
         * and print as hex for a stable, portable thread tag. */
        unsigned long thread_id = (unsigned long) pthread_self();
        printf("[DBG#%#lx @ %s:%d] ", thread_id, file, line);
        vprintf(format, args);
        printf("\r\n");
    }
    va_end(args);
}

// void send_error(Proc* proc, const char* message_fmt, ...) {
// va_list args;
// va_start(args, message_fmt);
// char* message = driver_alloc(256);
// vsnprintf(message, 256, message_fmt, args);
// DRV_DEBUG("Sending error message: %s", message);
// ErlDrvTermData* msg = driver_alloc(sizeof(ErlDrvTermData) * 7);
// int msg_index = 0;
// msg[msg_index++] = ERL_DRV_ATOM;
// msg[msg_index++] = atom_error;
// msg[msg_index++] = ERL_DRV_STRING;
// msg[msg_index++] = (ErlDrvTermData)message;
// msg[msg_index++] = strlen(message);
// msg[msg_index++] = ERL_DRV_TUPLE;
// msg[msg_index++] = 2;

// int msg_res = erl_drv_output_term(proc->port_term, msg, msg_index);
// DRV_DEBUG("Sent error message. Res: %d", msg_res);
// va_end(args);
// }
Loading