diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 1c53314b5b259..3e459c770122a 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -1116,6 +1116,7 @@ bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin wasm_->on_context_create_(this, id_, 0); plugin_.reset(); } + in_vm_context_created_ = true; if (wasm_->on_vm_start_) { configuration_ = vm_configuration; plugin_ = plugin; @@ -1178,6 +1179,7 @@ void Context::onCreate(uint32_t parent_context_id) { Network::FilterStatus Context::onNetworkNewConnection() { DeferAfterCallActions actions(this); onCreate(root_context_id_); + in_vm_context_created_ = true; if (!wasm_->on_new_connection_) { return Network::FilterStatus::Continue; } @@ -1188,7 +1190,7 @@ Network::FilterStatus Context::onNetworkNewConnection() { } Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_stream) { - if (!wasm_->on_downstream_data_) { + if (!in_vm_context_created_ || !wasm_->on_downstream_data_) { return Network::FilterStatus::Continue; } DeferAfterCallActions actions(this); @@ -1200,7 +1202,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str } Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_stream) { - if (!wasm_->on_upstream_data_) { + if (!in_vm_context_created_ || !wasm_->on_upstream_data_) { return Network::FilterStatus::Continue; } DeferAfterCallActions actions(this); @@ -1212,7 +1214,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea } void Context::onDownstreamConnectionClose(PeerType peer_type) { - if (wasm_->on_downstream_connection_close_) { + if (in_vm_context_created_ && wasm_->on_downstream_connection_close_) { DeferAfterCallActions actions(this); wasm_->on_downstream_connection_close_(this, id_, static_cast(peer_type)); } @@ -1225,7 +1227,7 @@ void Context::onDownstreamConnectionClose(PeerType peer_type) { } void Context::onUpstreamConnectionClose(PeerType peer_type) { - if (wasm_->on_upstream_connection_close_) { + if (in_vm_context_created_ && wasm_->on_upstream_connection_close_) { DeferAfterCallActions actions(this); wasm_->on_upstream_connection_close_(this, id_, static_cast(peer_type)); } @@ -1252,7 +1254,7 @@ Http::FilterHeadersStatus Context::onRequestHeaders() { } Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) { - if (!wasm_->on_request_body_) { + if (!in_vm_context_created_ || !wasm_->on_request_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1278,7 +1280,7 @@ Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) { } Http::FilterTrailersStatus Context::onRequestTrailers() { - if (!wasm_->on_request_trailers_) { + if (!in_vm_context_created_ || !wasm_->on_request_trailers_) { return Http::FilterTrailersStatus::Continue; } DeferAfterCallActions actions(this); @@ -1289,7 +1291,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() { } Http::FilterMetadataStatus Context::onRequestMetadata() { - if (!wasm_->on_request_metadata_) { + if (!in_vm_context_created_ || !wasm_->on_request_metadata_) { return Http::FilterMetadataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1319,7 +1321,7 @@ Http::FilterHeadersStatus Context::onResponseHeaders() { } Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) { - if (!wasm_->on_response_body_) { + if (!in_vm_context_created_ || !wasm_->on_response_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1345,7 +1347,7 @@ Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) { } Http::FilterTrailersStatus Context::onResponseTrailers() { - if (!wasm_->on_response_trailers_) { + if (!in_vm_context_created_ || !wasm_->on_response_trailers_) { return Http::FilterTrailersStatus::Continue; } DeferAfterCallActions actions(this); @@ -1356,7 +1358,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() { } Http::FilterMetadataStatus Context::onResponseMetadata() { - if (!wasm_->on_response_metadata_) { + if (!in_vm_context_created_ || !wasm_->on_response_metadata_) { return Http::FilterMetadataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1626,7 +1628,7 @@ void Context::onDestroy() { bool Context::onDone() { DeferAfterCallActions actions(this); - if (wasm_->on_done_) { + if (in_vm_context_created_ && wasm_->on_done_) { return wasm_->on_done_(this, id_).u64_ != 0; } return true; @@ -1634,14 +1636,14 @@ bool Context::onDone() { void Context::onLog() { DeferAfterCallActions actions(this); - if (wasm_->on_log_) { + if (in_vm_context_created_ && wasm_->on_log_) { wasm_->on_log_(this, id_); } } void Context::onDelete() { DeferAfterCallActions actions(this); - if (wasm_->on_delete_) { + if (in_vm_context_created_ && wasm_->on_delete_) { wasm_->on_delete_(this, id_); } } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 0a1847ed0ae60..65917a8e9078a 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -331,6 +331,8 @@ class Context : public Logger::Loggable, void addAfterVmCallAction(std::function f); + void setInVmContextCreatedForTesting() { in_vm_context_created_ = true; } + protected: friend class Wasm; diff --git a/test/extensions/wasm/wasm_test.cc b/test/extensions/wasm/wasm_test.cc index 9deb40e86c499..aaa9adbab56ac 100644 --- a/test/extensions/wasm/wasm_test.cc +++ b/test/extensions/wasm/wasm_test.cc @@ -191,7 +191,7 @@ TEST_P(WasmTest, DivByZero) { auto context = std::make_unique(wasm.get()); EXPECT_CALL(*context, scriptLog_(spdlog::level::err, Eq("before div by zero"))); EXPECT_TRUE(wasm->initialize(code, false)); - wasm->setContext(context.get()); + context->setInVmContextCreatedForTesting(); if (GetParam() == "v8") { EXPECT_THROW_WITH_MESSAGE( @@ -401,6 +401,7 @@ TEST_P(WasmTest, StatsHighLevel) { "{{ test_rundir }}/test/extensions/wasm/test_data/stats_cpp.wasm")); EXPECT_FALSE(code.empty()); auto context = std::make_unique(wasm.get()); + context->setInVmContextCreatedForTesting(); EXPECT_CALL(*context, scriptLog_(spdlog::level::trace, Eq("get counter = 1"))); EXPECT_CALL(*context, scriptLog_(spdlog::level::debug, Eq("get counter = 2")));