diff --git a/source/extensions/common/wasm/wasm.cc b/source/extensions/common/wasm/wasm.cc index cc07486c5c1dc..b1de6cb764acf 100644 --- a/source/extensions/common/wasm/wasm.cc +++ b/source/extensions/common/wasm/wasm.cc @@ -952,7 +952,7 @@ Wasm::Wasm(absl::string_view vm, absl::string_view id, absl::string_view initial } void Wasm::registerFunctions() { -#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), #_fn, &_fn##Handler); +#define _REGISTER(_fn) registerCallback(wasm_vm_.get(), "envoy", #_fn, &_fn##Handler); if (is_emscripten_) { _REGISTER(getTotalMemory); _REGISTER(_emscripten_get_heap_size); @@ -961,7 +961,8 @@ void Wasm::registerFunctions() { #undef _REGISTER // Calls with the "_proxy_" prefix. -#define _REGISTER_PROXY(_fn) registerCallback(wasm_vm_.get(), "_proxy_" #_fn, &_fn##Handler); +#define _REGISTER_PROXY(_fn) \ + registerCallback(wasm_vm_.get(), "envoy", "_proxy_" #_fn, &_fn##Handler); _REGISTER_PROXY(log); _REGISTER_PROXY(getRequestStreamInfoProtocol); diff --git a/source/extensions/common/wasm/wasm.h b/source/extensions/common/wasm/wasm.h index bdafacee71fbb..37074f88136b9 100644 --- a/source/extensions/common/wasm/wasm.h +++ b/source/extensions/common/wasm/wasm.h @@ -403,6 +403,8 @@ class WasmVm : public Logger::Loggable { virtual absl::string_view getMemory(uint32_t pointer, uint32_t size) PURE; // Set a block of memory in the VM, returns true on success, false if the pointer/size is invalid. virtual bool setMemory(uint32_t pointer, uint32_t size, void* data) PURE; + // Make a new intrinsic module (e.g. for Emscripten support). + virtual void makeModule(absl::string_view name) PURE; // Get the contents of the user section with the given name or "" if it does not exist and // optionally a presence indicator. @@ -506,15 +508,17 @@ inline Context::Context(Wasm* wasm) : wasm_(wasm), id_(wasm->allocContextId()) { // Forward declarations for VM implemenations. template -void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, R (*)(Args...)); +void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + R (*)(Args...)); template void getFunctionWavm(WasmVm* vm, absl::string_view functionName, std::function*); template -void registerCallback(WasmVm* vm, absl::string_view functionName, R (*f)(Args...)) { +void registerCallback(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + R (*f)(Args...)) { if (vm->vm() == WasmVmNames::get().Wavm) { - registerCallbackWavm(vm, functionName, f); + registerCallbackWavm(vm, moduleName, functionName, f); } else { throw WasmVmException("unsupoorted wasm vm"); } diff --git a/source/extensions/common/wasm/wavm/BUILD b/source/extensions/common/wasm/wavm/BUILD index e077a80794e66..d302142abe450 100644 --- a/source/extensions/common/wasm/wavm/BUILD +++ b/source/extensions/common/wasm/wavm/BUILD @@ -27,6 +27,7 @@ envoy_cc_library( "wavm_with_llvm", ], deps = [ + "//external:abseil_node_hash_map", "//include/envoy/server:wasm_interface", "//include/envoy/thread_local:thread_local_interface", "//source/common/common:assert_lib", diff --git a/source/extensions/common/wasm/wavm/wavm.cc b/source/extensions/common/wasm/wavm/wavm.cc index 96142cdec1fdf..23bc962d749b8 100644 --- a/source/extensions/common/wasm/wavm/wavm.cc +++ b/source/extensions/common/wasm/wavm/wavm.cc @@ -40,6 +40,7 @@ #include "WAVM/Runtime/RuntimeData.h" #include "WAVM/WASM/WASM.h" #include "WAVM/WASTParse/WASTParse.h" +#include "absl/container/node_hash_map.h" #include "absl/strings/match.h" using namespace WAVM; @@ -189,13 +190,9 @@ struct Wavm : public WasmVm { void* allocMemory(uint32_t size, uint32_t* pointer) override; absl::string_view getMemory(uint32_t pointer, uint32_t size) override; bool setMemory(uint32_t pointer, uint32_t size, void* data) override; + void makeModule(absl::string_view name) override; absl::string_view getUserSection(absl::string_view name, bool* present) override; - WAVM::Runtime::Memory* memory() { return memory_; } - WAVM::Runtime::Context* context() { return context_; } - WAVM::Runtime::ModuleInstance* moduleInstance() { return moduleInstance_; } - WAVM::Runtime::ModuleInstance* envoyModuleInstance() { return moduleInstance_; } - void GetFunctions(); void RegisterCallbacks(); @@ -207,9 +204,10 @@ struct Wavm : public WasmVm { Emscripten::Instance* emscriptenInstance_ = nullptr; WAVM::Runtime::GCPointer compartment_; WAVM::Runtime::GCPointer context_; - Intrinsics::Module envoy_module_; - WAVM::Runtime::GCPointer envoyModuleInstance_ = nullptr; - std::vector> envoy_functions_; + absl::node_hash_map intrinsicModules_; + absl::node_hash_map> + intrinsicModuleInstances_; + std::vector> envoyFunctions_; }; Wavm::~Wavm() { @@ -222,8 +220,9 @@ Wavm::~Wavm() { delete emscriptenInstance_; } context_ = nullptr; - envoyModuleInstance_ = nullptr; - envoy_functions_.clear(); + intrinsicModuleInstances_.clear(); + intrinsicModules_.clear(); + envoyFunctions_.clear(); if (compartment_) { ASSERT(tryCollectCompartment(std::move(compartment_))); } @@ -234,8 +233,10 @@ std::unique_ptr Wavm::clone() { wavm->compartment_ = WAVM::Runtime::cloneCompartment(compartment_); wavm->memory_ = WAVM::Runtime::remapToClonedCompartment(memory_, wavm->compartment_); wavm->context_ = WAVM::Runtime::createContext(wavm->compartment_); - wavm->envoyModuleInstance_ = - WAVM::Runtime::remapToClonedCompartment(envoyModuleInstance_, wavm->compartment_); + for (auto& p : intrinsicModuleInstances_) { + wavm->intrinsicModuleInstances_.emplace( + p.first, WAVM::Runtime::remapToClonedCompartment(p.second, wavm->compartment_)); + } wavm->moduleInstance_ = WAVM::Runtime::remapToClonedCompartment(moduleInstance_, wavm->compartment_); return wavm; @@ -264,13 +265,18 @@ bool Wavm::load(const std::string& code, bool allow_precompiled) { } else { module_ = WAVM::Runtime::loadPrecompiledModule(irModule_, precompiledObjectSection->data); } + makeModule("envoy"); return true; } void Wavm::link(absl::string_view name, bool needs_emscripten) { RootResolver rootResolver(compartment_); - envoyModuleInstance_ = Intrinsics::instantiateModule(compartment_, envoy_module_, "envoy"); - rootResolver.moduleNameToInstanceMap().set("envoy", envoyModuleInstance_); + for (auto& p : intrinsicModules_) { + auto instance = Intrinsics::instantiateModule(compartment_, intrinsicModules_[p.first], + std::string(p.first)); + intrinsicModuleInstances_.emplace(p.first, instance); + rootResolver.moduleNameToInstanceMap().set(p.first, instance); + } if (needs_emscripten) { emscriptenInstance_ = Emscripten::instantiate(compartment_, irModule_); rootResolver.moduleNameToInstanceMap().set("env", emscriptenInstance_->env); @@ -283,6 +289,10 @@ void Wavm::link(absl::string_view name, bool needs_emscripten) { memory_ = getDefaultMemory(moduleInstance_); } +void Wavm::makeModule(absl::string_view name) { + intrinsicModules_.emplace(std::piecewise_construct, std::make_tuple(name), std::make_tuple()); +} + void Wavm::start(Context* context) { auto f = getStartFunction(moduleInstance_); if (f) { @@ -301,7 +311,7 @@ void Wavm::start(Context* context) { } void* Wavm::allocMemory(uint32_t size, uint32_t* address) { - auto f = asFunctionNullable(getInstanceExport(moduleInstance(), "_malloc")); + auto f = asFunctionNullable(getInstanceExport(moduleInstance_, "_malloc")); if (!f) return nullptr; auto values = invokeFunctionChecked(context_, f, {size}); @@ -310,18 +320,18 @@ void* Wavm::allocMemory(uint32_t size, uint32_t* address) { ASSERT(v.type == ValueType::i32); *address = v.u32; return reinterpret_cast( - WAVM::Runtime::memoryArrayPtr(memory(), v.u32, static_cast(size))); + WAVM::Runtime::memoryArrayPtr(memory_, v.u32, static_cast(size))); } absl::string_view Wavm::getMemory(uint32_t pointer, uint32_t size) { return {reinterpret_cast( - WAVM::Runtime::memoryArrayPtr(memory(), pointer, static_cast(size))), + WAVM::Runtime::memoryArrayPtr(memory_, pointer, static_cast(size))), static_cast(size)}; } bool Wavm::setMemory(uint32_t pointer, uint32_t size, void* data) { auto p = reinterpret_cast( - WAVM::Runtime::memoryArrayPtr(memory(), pointer, static_cast(size))); + WAVM::Runtime::memoryArrayPtr(memory_, pointer, static_cast(size))); if (p) { memcpy(p, data, size); return true; @@ -357,69 +367,83 @@ IR::FunctionType inferEnvoyFunctionType(R (*)(void*, Args...)) { using namespace Wavm; template -void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, R (*f)(Args...)) { +void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + R (*f)(Args...)) { auto wavm = static_cast(vm); - wavm->envoy_functions_.emplace_back( - new Intrinsics::Function(wavm->envoy_module_, functionName.data(), reinterpret_cast(f), - inferEnvoyFunctionType(f), IR::CallingConvention::intrinsic)); + wavm->envoyFunctions_.emplace_back(new Intrinsics::Function( + wavm->intrinsicModules_[moduleName], functionName.data(), reinterpret_cast(f), + inferEnvoyFunctionType(f), IR::CallingConvention::intrinsic)); } -template void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, - void (*f)(void*)); -template void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, void (*f)(void*)); +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, void (*f)(void*, U32)); -template void registerCallbackWavm(WasmVm* vm, +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32)); template void registerCallbackWavm(WasmVm* vm, + absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32, U32)); template void -registerCallbackWavm(WasmVm* vm, absl::string_view functionName, +registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32)); + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + void (*f)(void*, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32)); + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + void (*f)(void*, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, void (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32)); -template void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, - U32 (*f)(void*)); -template void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, U32 (*f)(void*)); +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, U32 (*f)(void*, U32)); -template void registerCallbackWavm(WasmVm* vm, absl::string_view functionName, +template void registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, U32 (*f)(void*, U32, U32)); template void registerCallbackWavm(WasmVm* vm, + absl::string_view moduleName, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32)); template void -registerCallbackWavm(WasmVm* vm, absl::string_view functionName, +registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32)); -template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32)); +template void +registerCallbackWavm(WasmVm* vm, absl::string_view moduleName, + absl::string_view functionName, + U32 (*f)(void*, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32)); + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + U32 (*f)(void*, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32)); + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, + U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32)); template void registerCallbackWavm( - WasmVm* vm, absl::string_view functionName, + WasmVm* vm, absl::string_view moduleName, absl::string_view functionName, U32 (*f)(void*, U32, U32, U32, U32, U32, U32, U32, U32, U32, U32)); template @@ -436,10 +460,9 @@ template void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName, std::function* function, uint32_t) { auto wavm = static_cast(vm); - auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance(), std::string(functionName))); + auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName))); if (!f) - f = asFunctionNullable( - getInstanceExport(wavm->envoyModuleInstance(), std::string(functionName))); + f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName))); if (!f) { *function = nullptr; return; @@ -449,7 +472,7 @@ void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName, } *function = [wavm, f](Context* context, Args... args) -> R { UntaggedValue values[] = {args...}; - CALL_WITH_CONTEXT_RETURN(invokeFunctionUnchecked(wavm->context(), f, &values[0]), context, + CALL_WITH_CONTEXT_RETURN(invokeFunctionUnchecked(wavm->context_, f, &values[0]), context, uint32_t, i32); }; } @@ -460,10 +483,9 @@ template void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName, std::function* function, Void) { auto wavm = static_cast(vm); - auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance(), std::string(functionName))); + auto f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName))); if (!f) - f = asFunctionNullable( - getInstanceExport(wavm->envoyModuleInstance(), std::string(functionName))); + f = asFunctionNullable(getInstanceExport(wavm->moduleInstance_, std::string(functionName))); if (!f) { *function = nullptr; return; @@ -473,7 +495,7 @@ void getFunctionWavmReturn(WasmVm* vm, absl::string_view functionName, } *function = [wavm, f](Context* context, Args... args) -> R { UntaggedValue values[] = {args...}; - CALL_WITH_CONTEXT(invokeFunctionUnchecked(wavm->context(), f, &values[0]), context); + CALL_WITH_CONTEXT(invokeFunctionUnchecked(wavm->context_, f, &values[0]), context); }; }