Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,9 @@
*/

/*
 * Configure wasi_nn.h before including it: enable the
 * WASM_ENABLE_WASI_EPHEMERAL_NN code paths and make WASI_NN_NAME() prepend
 * "wasi_ephemeral_nn_" to every identifier passed through it.
 */
#define WASM_ENABLE_WASI_EPHEMERAL_NN 1
#define WASI_NN_NAME(name) wasi_ephemeral_nn_##name

#include "wasi_nn.h"

/* Drop the configuration macros so they do not leak into the includer. */
#undef WASM_ENABLE_WASI_EPHEMERAL_NN
#undef WASI_NN_NAME
54 changes: 32 additions & 22 deletions core/iwasm/libraries/wasi-nn/include/wasi_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@
* @return wasi_nn_error Execution status.
*/
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
execution_target target, graph *g) WASI_NN_IMPORT("load");
WASI_NN_ERROR_TYPE
WASI_NN_NAME(load)
(WASI_NN_NAME(graph_builder) * builder, uint32_t builder_len,
WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
#else
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g) WASI_NN_IMPORT("load");
WASI_NN_ERROR_TYPE
WASI_NN_NAME(load)
(WASI_NN_NAME(graph_builder_array) * builder,
WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
#endif

wasi_nn_error
load_by_name(const char *name, uint32_t name_len, graph *g)
WASI_NN_ERROR_TYPE
WASI_NN_NAME(load_by_name)
(const char *name, uint32_t name_len, WASI_NN_NAME(graph) * g)
WASI_NN_IMPORT("load_by_name");

/**
Expand All @@ -59,8 +64,9 @@ load_by_name(const char *name, uint32_t name_len, graph *g)
* @param ctx Execution context.
* @return wasi_nn_error Execution status.
*/
wasi_nn_error
init_execution_context(graph g, graph_execution_context *ctx)
WASI_NN_ERROR_TYPE
WASI_NN_NAME(init_execution_context)
(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx)
WASI_NN_IMPORT("init_execution_context");

/**
Expand All @@ -71,18 +77,20 @@ init_execution_context(graph g, graph_execution_context *ctx)
* @param tensor Input tensor.
* @return wasi_nn_error Execution status.
*/
wasi_nn_error
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
WASI_NN_IMPORT("set_input");
WASI_NN_ERROR_TYPE
WASI_NN_NAME(set_input)
(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
WASI_NN_NAME(tensor) * tensor) WASI_NN_IMPORT("set_input");

/**
* @brief Compute the inference on the given inputs.
*
* @param ctx Execution context.
* @return wasi_nn_error Execution status.
*/
wasi_nn_error
compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
WASI_NN_ERROR_TYPE
WASI_NN_NAME(compute)
(WASI_NN_NAME(graph_execution_context) ctx) WASI_NN_IMPORT("compute");

/**
* @brief Extract the outputs after inference.
Expand All @@ -97,14 +105,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
* @return wasi_nn_error Execution status.
*/
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t output_tensor_max_size,
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
WASI_NN_ERROR_TYPE
WASI_NN_NAME(get_output)
(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
WASI_NN_NAME(tensor_data) output_tensor, uint32_t output_tensor_max_size,
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
#else
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
/*
 * Variant used when WASM_ENABLE_WASI_EPHEMERAL_NN is 0: it has no
 * output_tensor_max_size parameter. Every project type, including the
 * execution context, goes through WASI_NN_NAME() so the prototype stays
 * valid when the includer supplies a prefix.
 */
WASI_NN_ERROR_TYPE
WASI_NN_NAME(get_output)
(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
WASI_NN_NAME(tensor_data) output_tensor, uint32_t *output_tensor_size)
WASI_NN_IMPORT("get_output");
#endif

Expand Down
107 changes: 70 additions & 37 deletions core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,23 @@
extern "C" {
#endif

/*
 * Naming helpers.
 *
 * Our host logic doesn't use any prefix; neither does the legacy wasi_nn.h.
 * So when this is not a wasm build, or the includer did not provide
 * WASI_NN_NAME, every helper expands to the bare identifier. Otherwise all
 * identifiers are routed through the includer-supplied WASI_NN_NAME, with an
 * extra category tag (error_, type_, encoding_, target_) for enumerators.
 */

#if !defined(__wasm__) || !defined(WASI_NN_NAME)
#define WASI_NN_NAME(name) name
#define WASI_NN_ERROR_NAME(name) name
#define WASI_NN_TYPE_NAME(name) name
#define WASI_NN_ENCODING_NAME(name) name
#define WASI_NN_TARGET_NAME(name) name
#define WASI_NN_ERROR_TYPE wasi_nn_error
#else
#define WASI_NN_ERROR_NAME(name) WASI_NN_NAME(error_##name)
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
/*
 * No trailing semicolon: this macro is used as a type name (typedef name and
 * function return type), so a semicolon here would split every prototype
 * that starts with WASI_NN_ERROR_TYPE into two broken declarations.
 */
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
#endif

/**
* ERRORS
*
Expand All @@ -22,22 +39,22 @@ extern "C" {
// https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L5-L17
// Error codes returned by functions in this API.
typedef enum {
success = 0,
invalid_argument,
invalid_encoding,
missing_memory,
busy,
runtime_error,
unsupported_operation,
too_large,
not_found,
WASI_NN_ERROR_NAME(success) = 0,
WASI_NN_ERROR_NAME(invalid_argument),
WASI_NN_ERROR_NAME(invalid_encoding),
WASI_NN_ERROR_NAME(missing_memory),
WASI_NN_ERROR_NAME(busy),
WASI_NN_ERROR_NAME(runtime_error),
WASI_NN_ERROR_NAME(unsupported_operation),
WASI_NN_ERROR_NAME(too_large),
WASI_NN_ERROR_NAME(not_found),

// for WasmEdge-wasi-nn
end_of_sequence = 100, // End of Sequence Found.
context_full = 101, // Context Full.
prompt_tool_long = 102, // Prompt Too Long.
model_not_found = 103, // Model Not Found.
} wasi_nn_error;
WASI_NN_ERROR_NAME(end_of_sequence) = 100, // End of Sequence Found.
WASI_NN_ERROR_NAME(context_full) = 101, // Context Full.
WASI_NN_ERROR_NAME(prompt_tool_long) = 102, // Prompt Too Long.
WASI_NN_ERROR_NAME(model_not_found) = 103, // Model Not Found.
} WASI_NN_ERROR_TYPE;

/**
* TENSOR
Expand All @@ -51,15 +68,27 @@ typedef enum {
typedef struct {
uint32_t *buf;
uint32_t size;
} tensor_dimensions;
} WASI_NN_NAME(tensor_dimensions);

#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
// sync up with
// https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L19-L28
// The type of the elements in a tensor.
typedef enum { fp16 = 0, fp32, fp64, u8, i32, i64 } tensor_type;
typedef enum {
WASI_NN_TYPE_NAME(fp16) = 0,
WASI_NN_TYPE_NAME(fp32),
WASI_NN_TYPE_NAME(fp64),
WASI_NN_TYPE_NAME(u8),
WASI_NN_TYPE_NAME(i32),
WASI_NN_TYPE_NAME(i64),
} WASI_NN_NAME(tensor_type);
#else
typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
typedef enum {
WASI_NN_TYPE_NAME(fp16) = 0,
WASI_NN_TYPE_NAME(fp32),
WASI_NN_TYPE_NAME(up8),
WASI_NN_TYPE_NAME(ip32),
} WASI_NN_NAME(tensor_type);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */

// The tensor data.
Expand All @@ -70,24 +99,24 @@ typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
// 4-byte f32 elements would have a data array of length 16). Naturally, this
// representation requires some knowledge of how to lay out data in
// memory--e.g., using row-major ordering--and could perhaps be improved.
typedef uint8_t *tensor_data;
typedef uint8_t *WASI_NN_NAME(tensor_data);

// A tensor.
typedef struct {
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
// represent a tensor containing a single value, use `[1]` for the tensor
// dimensions.
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
tensor_dimensions dimensions;
WASI_NN_NAME(tensor_dimensions) dimensions;
#else
tensor_dimensions *dimensions;
WASI_NN_NAME(tensor_dimensions) * dimensions;
#endif
// Describe the type of element in the tensor (e.g., f32).
uint8_t type;
uint8_t _pad[3];
// Contains the tensor data.
tensor_data data;
} tensor;
WASI_NN_NAME(tensor_data) data;
} WASI_NN_NAME(tensor);

/**
* GRAPH
Expand All @@ -102,37 +131,41 @@ typedef struct {
typedef struct {
uint8_t *buf;
uint32_t size;
} graph_builder;
} WASI_NN_NAME(graph_builder);

typedef struct {
graph_builder *buf;
WASI_NN_NAME(graph_builder) * buf;
uint32_t size;
} graph_builder_array;
} WASI_NN_NAME(graph_builder_array);

// An execution graph for performing inference (i.e., a model).
typedef uint32_t graph;
typedef uint32_t WASI_NN_NAME(graph);

// sync up with
// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75
// Describes the encoding of the graph. This allows the API to be implemented by
// various backends that encode (i.e., serialize) their graph IR with different
// formats.
typedef enum {
openvino = 0,
onnx,
tensorflow,
pytorch,
tensorflowlite,
ggml,
autodetect,
unknown_backend,
} graph_encoding;
WASI_NN_ENCODING_NAME(openvino) = 0,
WASI_NN_ENCODING_NAME(onnx),
WASI_NN_ENCODING_NAME(tensorflow),
WASI_NN_ENCODING_NAME(pytorch),
WASI_NN_ENCODING_NAME(tensorflowlite),
WASI_NN_ENCODING_NAME(ggml),
WASI_NN_ENCODING_NAME(autodetect),
WASI_NN_ENCODING_NAME(unknown_backend),
} WASI_NN_NAME(graph_encoding);

// Define where the graph should be executed.
typedef enum execution_target { cpu = 0, gpu, tpu } execution_target;
typedef enum WASI_NN_NAME(execution_target) {
WASI_NN_TARGET_NAME(cpu) = 0,
WASI_NN_TARGET_NAME(gpu),
WASI_NN_TARGET_NAME(tpu),
} WASI_NN_NAME(execution_target);

// Bind a `graph` to the input and output tensors for an inference.
typedef uint32_t graph_execution_context;
typedef uint32_t WASI_NN_NAME(graph_execution_context);

#ifdef __cplusplus
}
Expand Down
49 changes: 26 additions & 23 deletions wamr-wasi-extensions/samples/nn/app.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ print_result(const float *result, size_t sz)
int
main(int argc, char **argv)
{
wasi_nn_error nnret;
wasi_ephemeral_nn_error nnret;
int ret;
void *xml;
size_t xmlsz;
Expand All @@ -112,25 +112,27 @@ main(int argc, char **argv)
exit(1);
}
/* note: openvino takes two buffers, namely IR and weights */
graph_builder builders[2] = { {
.buf = xml,
.size = xmlsz,
},
{
.buf = weights,
.size = weightssz,
} };
graph g;
nnret = load(builders, 2, openvino, cpu, &g);
wasi_ephemeral_nn_graph_builder builders[2] = { {
.buf = xml,
.size = xmlsz,
},
{
.buf = weights,
.size = weightssz,
} };
wasi_ephemeral_nn_graph g;
nnret =
wasi_ephemeral_nn_load(builders, 2, wasi_ephemeral_nn_encoding_openvino,
wasi_ephemeral_nn_target_cpu, &g);
unmap_file(xml, xmlsz);
unmap_file(weights, weightssz);
if (nnret != success) {
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "load failed with %d\n", (int)nnret);
exit(1);
}
graph_execution_context ctx;
nnret = init_execution_context(g, &ctx);
if (nnret != success) {
wasi_ephemeral_nn_graph_execution_context ctx;
nnret = wasi_ephemeral_nn_init_execution_context(g, &ctx);
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "init_execution_context failed with %d\n", (int)nnret);
exit(1);
}
Expand All @@ -142,26 +144,27 @@ main(int argc, char **argv)
strerror(ret));
exit(1);
}
tensor tensor = {
wasi_ephemeral_nn_tensor tensor = {
.dimensions = { .buf = (uint32_t[]){1, 3, 224, 224,}, .size = 4, },
.type = fp32,
.type = wasi_ephemeral_nn_type_fp32,
.data = tensordata,
};
nnret = set_input(ctx, 0, &tensor);
nnret = wasi_ephemeral_nn_set_input(ctx, 0, &tensor);
unmap_file(tensordata, tensordatasz);
if (nnret != success) {
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "set_input failed with %d\n", (int)nnret);
exit(1);
}
nnret = compute(ctx);
if (nnret != success) {
nnret = wasi_ephemeral_nn_compute(ctx);
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "compute failed with %d\n", (int)nnret);
exit(1);
}
float result[1001];
uint32_t resultsz;
nnret = get_output(ctx, 0, (void *)result, sizeof(result), &resultsz);
if (nnret != success) {
nnret = wasi_ephemeral_nn_get_output(ctx, 0, (void *)result, sizeof(result),
&resultsz);
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "get_output failed with %d\n", (int)nnret);
exit(1);
}
Expand Down
Loading