Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions source/extensions/common/wasm/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial
}

void Wasm::registerCallbacks() {
#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler);
#define _REGISTER(_fn) wasm_vm_->registerCallback("envoy", #_fn, &_fn##Handler);
if (is_emscripten_) {
_REGISTER(getTotalMemory);
_REGISTER(_emscripten_get_heap_size);
Expand All @@ -962,8 +962,7 @@ void Wasm::registerCallbacks() {
#undef _REGISTER

// Calls with the "_proxy_" prefix.
#define _REGISTER_PROXY(_fn) \
registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler);
#define _REGISTER_PROXY(_fn) wasm_vm_->registerCallback("envoy", "_proxy_" #_fn, &_fn##Handler);
_REGISTER_PROXY(log);

_REGISTER_PROXY(getRequestStreamInfoProtocol);
Expand Down Expand Up @@ -1018,19 +1017,19 @@ void Wasm::registerCallbacks() {
void Wasm::establishEnvironment() {
if (is_emscripten_) {
wasm_vm_->makeModule("global");
emscripten_NaN_ = makeGlobal(wasm_vm_.get(), "global", "NaN", std::nan("0"));
emscripten_NaN_ = wasm_vm_->makeGlobal("global", "NaN", std::nan("0"));
emscripten_Infinity_ =
makeGlobal(wasm_vm_.get(), "global", "Infinity", std::numeric_limits<double>::infinity());
wasm_vm_->makeGlobal("global", "Infinity", std::numeric_limits<double>::infinity());
}
}

void Wasm::getFunctions() {
#define _GET(_fn) getFunction(wasm_vm_.get(), "_" #_fn, &_fn##_);
#define _GET(_fn) wasm_vm_->getFunction("_" #_fn, &_fn##_);
_GET(malloc);
_GET(free);
#undef _GET

#define _GET_PROXY(_fn) getFunction(wasm_vm_.get(), "_proxy_" #_fn, &_fn##_);
#define _GET_PROXY(_fn) wasm_vm_->getFunction("_proxy_" #_fn, &_fn##_);
_GET_PROXY(onStart);
_GET_PROXY(onConfigure);
_GET_PROXY(onTick);
Expand Down
107 changes: 58 additions & 49 deletions source/extensions/common/wasm/wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,33 @@ class WasmVm;
using Pairs = std::vector<std::pair<absl::string_view, absl::string_view>>;
using PairsWithStringValues = std::vector<std::pair<absl::string_view, std::string>>;

// 1st arg is always a pointer to Context (Context*).
using WasmCall0Void = std::function<void(Context*)>;
using WasmCall1Void = std::function<void(Context*, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall2Void = std::function<void(Context*, uint32_t, uint32_t)>;

using WasmContextCall0Void = std::function<void(Context*, uint32_t context_id)>;
using WasmContextCall7Void = std::function<void(Context*, uint32_t context_id, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t, uint32_t)>;

using WasmContextCall0Int = std::function<uint32_t(Context*, uint32_t context_id)>;
using WasmContextCall2Int =
std::function<uint32_t(Context*, uint32_t context_id, uint32_t, uint32_t)>;
using WasmCall8Void = std::function<void(Context*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t)>;
using WasmCall1Int = std::function<uint32_t(Context*, uint32_t)>;
using WasmCall3Int = std::function<uint32_t(Context*, uint32_t, uint32_t, uint32_t)>;

// 1st arg is always a context_id (uint32_t).
using WasmContextCall0Void = WasmCall1Void;
using WasmContextCall7Void = WasmCall8Void;
using WasmContextCall0Int = WasmCall1Int;
using WasmContextCall2Int = WasmCall3Int;

// 1st arg is always a pointer to raw_context (void*).
using WasmCallback0Void = void (*)(void*);
using WasmCallback1Void = void (*)(void*, uint32_t);
using WasmCallback2Void = void (*)(void*, uint32_t, uint32_t);
using WasmCallback3Void = void (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback4Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback5Void = void (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback0Int = uint32_t (*)(void*);
using WasmCallback3Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t);
using WasmCallback5Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
using WasmCallback9Int = uint32_t (*)(void*, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t,
uint32_t, uint32_t, uint32_t, uint32_t);

// A context which will be the target of callbacks for a particular session
// e.g. a handler of a stream.
Expand Down Expand Up @@ -435,6 +450,40 @@ class WasmVm : public Logger::Loggable<Logger::Id::wasm> {
// Get the contents of the user section with the given name or "" if it does not exist and
// optionally a presence indicator.
virtual absl::string_view getUserSection(absl::string_view name, bool* present = nullptr) PURE;

// Get typed function exported by the WASM module.
virtual void getFunction(absl::string_view functionName, WasmCall0Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall2Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall8Void* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall1Int* f) PURE;
virtual void getFunction(absl::string_view functionName, WasmCall3Int* f) PURE;

// Register typed callbacks exported by the host environment.
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback1Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback2Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback4Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Void f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback0Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback3Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback5Int f) PURE;
virtual void registerCallback(absl::string_view moduleName, absl::string_view functionName,
WasmCallback9Int f) PURE;

// Register typed value exported by the host environment.
virtual std::unique_ptr<Global<double>>
makeGlobal(absl::string_view moduleName, absl::string_view name, double initialValue) PURE;
};

// Create a new low-level WASM VM of the give type (e.g. "envoy.wasm.vm.wavm").
Expand Down Expand Up @@ -466,46 +515,6 @@ class WasmVmException : public EnvoyException {

inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) {}

// Forward declarations for VM implemenations.
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>*);

template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

template <typename R, typename... Args>
void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*f)(Args...)) {
if (vm->vm() == WasmVmNames::get().Wavm) {
registerCallbackWavm(vm, moduleName, functionName, f);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename F> void getFunction(WasmVm* vm, absl::string_view functionName, F* function) {
if (vm->vm() == WasmVmNames::get().Wavm) {
getFunctionWavm(vm, functionName, function);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

template <typename T>
std::unique_ptr<Global<T>> makeGlobal(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue) {
if (vm->vm() == WasmVmNames::get().Wavm) {
return makeGlobalWavm(vm, moduleName, name, initialValue);
} else {
throw WasmVmException("unsupported wasm vm");
}
}

inline void* Wasm::allocMemory(uint32_t size, uint32_t* address) {
uint32_t a = malloc_(generalContext(), size);
*address = a;
Expand Down
45 changes: 45 additions & 0 deletions source/extensions/common/wasm/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ namespace Wasm {

extern thread_local Envoy::Extensions::Common::Wasm::Context* current_context_;

// Forward declarations.
template <typename R, typename... Args>
void getFunctionWavm(WasmVm* vm, absl::string_view functionName,
std::function<R(Context*, Args...)>* function);
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName,
R (*)(Args...));
template <typename T>
std::unique_ptr<Global<T>> makeGlobalWavm(WasmVm* vm, absl::string_view moduleName,
absl::string_view name, T initialValue);

namespace Wavm {

struct Wavm;
Expand Down Expand Up @@ -221,6 +232,40 @@ struct Wavm : public WasmVm {

void getInstantiatedGlobals();

#define _GET_FUNCTION(_type) \
void getFunction(absl::string_view functionName, _type* f) override { \
getFunctionWavm(this, functionName, f); \
};
_GET_FUNCTION(WasmCall0Void);
_GET_FUNCTION(WasmCall1Void);
_GET_FUNCTION(WasmCall2Void);
_GET_FUNCTION(WasmCall8Void);
_GET_FUNCTION(WasmCall1Int);
_GET_FUNCTION(WasmCall3Int);
#undef _GET_FUNCTION

#define _REGISTER_CALLBACK(_type) \
void registerCallback(absl::string_view moduleName, absl::string_view functionName, \
_type f) override { \
registerCallbackWavm(this, moduleName, functionName, f); \
};
_REGISTER_CALLBACK(WasmCallback0Void);
_REGISTER_CALLBACK(WasmCallback1Void);
_REGISTER_CALLBACK(WasmCallback2Void);
_REGISTER_CALLBACK(WasmCallback3Void);
_REGISTER_CALLBACK(WasmCallback4Void);
_REGISTER_CALLBACK(WasmCallback5Void);
_REGISTER_CALLBACK(WasmCallback0Int);
_REGISTER_CALLBACK(WasmCallback3Int);
_REGISTER_CALLBACK(WasmCallback5Int);
_REGISTER_CALLBACK(WasmCallback9Int);
#undef _REGISTER_CALLBACK

std::unique_ptr<Global<double>> makeGlobal(absl::string_view moduleName, absl::string_view name,
double initialValue) override {
return makeGlobalWavm(this, moduleName, name, initialValue);
};

bool hasInstantiatedModule_ = false;
IR::Module irModule_;
WAVM::Runtime::ModuleRef module_ = nullptr;
Expand Down