diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index a79cbfad64f0a..973a7017e8b9a 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -1130,6 +1130,7 @@ bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin wasm_->on_context_create_(this, id_, 0); plugin_.reset(); } + in_vm_context_created_ = true; if (wasm_->on_vm_start_) { configuration_ = vm_configuration; plugin_ = plugin; @@ -1192,6 +1193,7 @@ void Context::onCreate(uint32_t parent_context_id) { Network::FilterStatus Context::onNetworkNewConnection() { DeferAfterCallActions actions(this); onCreate(root_context_id_); + in_vm_context_created_ = true; if (!wasm_->on_new_connection_) { return Network::FilterStatus::Continue; } @@ -1202,7 +1204,7 @@ Network::FilterStatus Context::onNetworkNewConnection() { } Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_stream) { - if (!wasm_->on_downstream_data_) { + if (!in_vm_context_created_ || !wasm_->on_downstream_data_) { return Network::FilterStatus::Continue; } DeferAfterCallActions actions(this); @@ -1214,7 +1216,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str } Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_stream) { - if (!wasm_->on_upstream_data_) { + if (!in_vm_context_created_ || !wasm_->on_upstream_data_) { return Network::FilterStatus::Continue; } DeferAfterCallActions actions(this); @@ -1226,7 +1228,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea } void Context::onDownstreamConnectionClose(PeerType peer_type) { - if (wasm_->on_downstream_connection_close_) { + if (in_vm_context_created_ && wasm_->on_downstream_connection_close_) { DeferAfterCallActions actions(this); wasm_->on_downstream_connection_close_(this, id_, static_cast(peer_type)); } @@ -1239,7 +1241,7 @@ void Context::onDownstreamConnectionClose(PeerType peer_type) { } void Context::onUpstreamConnectionClose(PeerType peer_type) { - if (wasm_->on_upstream_connection_close_) { + if (in_vm_context_created_ && wasm_->on_upstream_connection_close_) { DeferAfterCallActions actions(this); wasm_->on_upstream_connection_close_(this, id_, static_cast(peer_type)); } @@ -1266,7 +1268,7 @@ Http::FilterHeadersStatus Context::onRequestHeaders() { } Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) { - if (!wasm_->on_request_body_) { + if (!in_vm_context_created_ || !wasm_->on_request_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1292,7 +1294,7 @@ Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) { } Http::FilterTrailersStatus Context::onRequestTrailers() { - if (!wasm_->on_request_trailers_) { + if (!in_vm_context_created_ || !wasm_->on_request_trailers_) { return Http::FilterTrailersStatus::Continue; } DeferAfterCallActions actions(this); @@ -1303,7 +1305,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() { } Http::FilterMetadataStatus Context::onRequestMetadata() { - if (!wasm_->on_request_metadata_) { + if (!in_vm_context_created_ || !wasm_->on_request_metadata_) { return Http::FilterMetadataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1333,7 +1335,7 @@ Http::FilterHeadersStatus Context::onResponseHeaders() { } Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) { - if (!wasm_->on_response_body_) { + if (!in_vm_context_created_ || !wasm_->on_response_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1359,7 +1361,7 @@ Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) { } Http::FilterTrailersStatus Context::onResponseTrailers() { - if (!wasm_->on_response_trailers_) { + if (!in_vm_context_created_ || !wasm_->on_response_trailers_) { return Http::FilterTrailersStatus::Continue; } DeferAfterCallActions actions(this); @@ -1370,7 +1372,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() { } Http::FilterMetadataStatus Context::onResponseMetadata() { - if (!wasm_->on_response_metadata_) { + if (!in_vm_context_created_ || !wasm_->on_response_metadata_) { return Http::FilterMetadataStatus::Continue; } DeferAfterCallActions actions(this); @@ -1637,7 +1639,7 @@ void Context::onDestroy() { bool Context::onDone() { DeferAfterCallActions actions(this); - if (wasm_->on_done_) { + if (in_vm_context_created_ && wasm_->on_done_) { return wasm_->on_done_(this, id_).u64_ != 0; } return true; @@ -1645,14 +1647,14 @@ bool Context::onDone() { void Context::onLog() { DeferAfterCallActions actions(this); - if (wasm_->on_log_) { + if (in_vm_context_created_ && wasm_->on_log_) { wasm_->on_log_(this, id_); } } void Context::onDelete() { DeferAfterCallActions actions(this); - if (wasm_->on_delete_) { + if (in_vm_context_created_ && wasm_->on_delete_) { wasm_->on_delete_(this, id_); } } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 6c30dc1ebbbc5..84f57038c347d 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -346,6 +346,8 @@ class Context : public Logger::Loggable, void addAfterVmCallAction(std::function f); + void setInVmContextCreatedForTesting() { in_vm_context_created_ = true; } + protected: friend class Wasm; diff --git a/test/extensions/wasm/wasm_test.cc b/test/extensions/wasm/wasm_test.cc index 9deb40e86c499..aaa9adbab56ac 100644 --- a/test/extensions/wasm/wasm_test.cc +++ b/test/extensions/wasm/wasm_test.cc @@ -191,7 +191,7 @@ TEST_P(WasmTest, DivByZero) { auto context = std::make_unique(wasm.get()); EXPECT_CALL(*context, scriptLog_(spdlog::level::err, Eq("before div by zero"))); EXPECT_TRUE(wasm->initialize(code, false)); - wasm->setContext(context.get()); + context->setInVmContextCreatedForTesting(); if (GetParam() == "v8") { EXPECT_THROW_WITH_MESSAGE( @@ -401,6 +401,7 @@ TEST_P(WasmTest, StatsHighLevel) { "{{ test_rundir }}/test/extensions/wasm/test_data/stats_cpp.wasm")); EXPECT_FALSE(code.empty()); auto context = std::make_unique(wasm.get()); + context->setInVmContextCreatedForTesting(); EXPECT_CALL(*context, scriptLog_(spdlog::level::trace, Eq("get counter = 1"))); EXPECT_CALL(*context, scriptLog_(spdlog::level::debug, Eq("get counter = 2")));