Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions source/extensions/common/wasm/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,7 @@ bool Context::onStart(absl::string_view vm_configuration, PluginSharedPtr plugin
wasm_->on_context_create_(this, id_, 0);
plugin_.reset();
}
in_vm_context_created_ = true;
if (wasm_->on_vm_start_) {
configuration_ = vm_configuration;
plugin_ = plugin;
Expand Down Expand Up @@ -1178,6 +1179,7 @@ void Context::onCreate(uint32_t parent_context_id) {
Network::FilterStatus Context::onNetworkNewConnection() {
DeferAfterCallActions actions(this);
onCreate(root_context_id_);
in_vm_context_created_ = true;
if (!wasm_->on_new_connection_) {
return Network::FilterStatus::Continue;
}
Expand All @@ -1188,7 +1190,7 @@ Network::FilterStatus Context::onNetworkNewConnection() {
}

Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_stream) {
if (!wasm_->on_downstream_data_) {
if (!in_vm_context_created_ || !wasm_->on_downstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1200,7 +1202,7 @@ Network::FilterStatus Context::onDownstreamData(int data_length, bool end_of_str
}

Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_stream) {
if (!wasm_->on_upstream_data_) {
if (!in_vm_context_created_ || !wasm_->on_upstream_data_) {
return Network::FilterStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1212,7 +1214,7 @@ Network::FilterStatus Context::onUpstreamData(int data_length, bool end_of_strea
}

void Context::onDownstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_downstream_connection_close_) {
if (in_vm_context_created_ && wasm_->on_downstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_downstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
Expand All @@ -1225,7 +1227,7 @@ void Context::onDownstreamConnectionClose(PeerType peer_type) {
}

void Context::onUpstreamConnectionClose(PeerType peer_type) {
if (wasm_->on_upstream_connection_close_) {
if (in_vm_context_created_ && wasm_->on_upstream_connection_close_) {
DeferAfterCallActions actions(this);
wasm_->on_upstream_connection_close_(this, id_, static_cast<uint32_t>(peer_type));
}
Expand All @@ -1252,7 +1254,7 @@ Http::FilterHeadersStatus Context::onRequestHeaders() {
}

Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) {
if (!wasm_->on_request_body_) {
if (!in_vm_context_created_ || !wasm_->on_request_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1278,7 +1280,7 @@ Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) {
}

Http::FilterTrailersStatus Context::onRequestTrailers() {
if (!wasm_->on_request_trailers_) {
if (!in_vm_context_created_ || !wasm_->on_request_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1289,7 +1291,7 @@ Http::FilterTrailersStatus Context::onRequestTrailers() {
}

Http::FilterMetadataStatus Context::onRequestMetadata() {
if (!wasm_->on_request_metadata_) {
if (!in_vm_context_created_ || !wasm_->on_request_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand Down Expand Up @@ -1319,7 +1321,7 @@ Http::FilterHeadersStatus Context::onResponseHeaders() {
}

Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) {
if (!wasm_->on_response_body_) {
if (!in_vm_context_created_ || !wasm_->on_response_body_) {
return Http::FilterDataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1345,7 +1347,7 @@ Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) {
}

Http::FilterTrailersStatus Context::onResponseTrailers() {
if (!wasm_->on_response_trailers_) {
if (!in_vm_context_created_ || !wasm_->on_response_trailers_) {
return Http::FilterTrailersStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand All @@ -1356,7 +1358,7 @@ Http::FilterTrailersStatus Context::onResponseTrailers() {
}

Http::FilterMetadataStatus Context::onResponseMetadata() {
if (!wasm_->on_response_metadata_) {
if (!in_vm_context_created_ || !wasm_->on_response_metadata_) {
return Http::FilterMetadataStatus::Continue;
}
DeferAfterCallActions actions(this);
Expand Down Expand Up @@ -1626,22 +1628,22 @@ void Context::onDestroy() {

bool Context::onDone() {
DeferAfterCallActions actions(this);
if (wasm_->on_done_) {
if (in_vm_context_created_ && wasm_->on_done_) {
return wasm_->on_done_(this, id_).u64_ != 0;
}
return true;
}

void Context::onLog() {
DeferAfterCallActions actions(this);
if (wasm_->on_log_) {
if (in_vm_context_created_ && wasm_->on_log_) {
wasm_->on_log_(this, id_);
}
}

void Context::onDelete() {
DeferAfterCallActions actions(this);
if (wasm_->on_delete_) {
if (in_vm_context_created_ && wasm_->on_delete_) {
wasm_->on_delete_(this, id_);
}
}
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/common/wasm/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ class Context : public Logger::Loggable<Logger::Id::wasm>,

void addAfterVmCallAction(std::function<void()> f);

void setInVmContextCreatedForTesting() { in_vm_context_created_ = true; }

protected:
friend class Wasm;

Expand Down
3 changes: 2 additions & 1 deletion test/extensions/wasm/wasm_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ TEST_P(WasmTest, DivByZero) {
auto context = std::make_unique<TestContext>(wasm.get());
EXPECT_CALL(*context, scriptLog_(spdlog::level::err, Eq("before div by zero")));
EXPECT_TRUE(wasm->initialize(code, false));
wasm->setContext(context.get());
context->setInVmContextCreatedForTesting();

if (GetParam() == "v8") {
EXPECT_THROW_WITH_MESSAGE(
Expand Down Expand Up @@ -401,6 +401,7 @@ TEST_P(WasmTest, StatsHighLevel) {
"{{ test_rundir }}/test/extensions/wasm/test_data/stats_cpp.wasm"));
EXPECT_FALSE(code.empty());
auto context = std::make_unique<TestContext>(wasm.get());
context->setInVmContextCreatedForTesting();

EXPECT_CALL(*context, scriptLog_(spdlog::level::trace, Eq("get counter = 1")));
EXPECT_CALL(*context, scriptLog_(spdlog::level::debug, Eq("get counter = 2")));
Expand Down