diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index b6b917970f..6547d3d9a7 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -343,7 +343,11 @@ WasmResult serializeValue(Filters::Common::Expr::CelValue value, std::string* re // An expression wrapper for the WASM state class WasmStateWrapper : public google::api::expr::runtime::CelMap { public: - WasmStateWrapper(const StreamInfo::FilterState& filter_state) : filter_state_(filter_state) {} + WasmStateWrapper(const StreamInfo::FilterState& filter_state, + const StreamInfo::FilterState* connection_filter_state) + : filter_state_(filter_state), connection_filter_state_(connection_filter_state) {} + WasmStateWrapper(const StreamInfo::FilterState& filter_state) + : filter_state_(filter_state), connection_filter_state_(nullptr) {} absl::optional operator[](google::api::expr::runtime::CelValue key) const override { if (!key.IsString()) { @@ -354,6 +358,15 @@ class WasmStateWrapper : public google::api::expr::runtime::CelMap { const WasmState& result = filter_state_.getDataReadOnly(value); return google::api::expr::runtime::CelValue::CreateBytes(&result.value()); } catch (const EnvoyException& e) { + // If doesn't exist in request filter state, try looking up in connection filter state. + try { + if (connection_filter_state_) { + const WasmState& result = connection_filter_state_->getDataReadOnly(value); + return google::api::expr::runtime::CelValue::CreateBytes(&result.value()); + } + } catch (const EnvoyException& e) { + return {}; + } return {}; } } @@ -365,6 +378,7 @@ class WasmStateWrapper : public google::api::expr::runtime::CelMap { private: const StreamInfo::FilterState& filter_state_; + const StreamInfo::FilterState* connection_filter_state_; }; #define PROPERTY_TOKENS(_f) \ @@ -426,10 +440,17 @@ WasmResult Context::getProperty(absl::string_view path, std::string* result) { case PropertyToken::METADATA: value = CelValue::CreateMessage(&info->dynamicMetadata(), &arena); break; - case PropertyToken::FILTER_STATE: - value = CelValue::CreateMap( - Protobuf::Arena::Create(&arena, info->filterState())); + case PropertyToken::FILTER_STATE: { + const Envoy::Network::Connection* connection = getConnection(); + if (connection) { + value = CelValue::CreateMap(Protobuf::Arena::Create( + &arena, info->filterState(), &connection->streamInfo().filterState())); + } else { + value = CelValue::CreateMap( + Protobuf::Arena::Create(&arena, info->filterState())); + } break; + } case PropertyToken::REQUEST: value = CelValue::CreateMap(Protobuf::Arena::Create( &arena, request_headers, *info)); @@ -974,6 +995,15 @@ StreamInfo::StreamInfo* Context::getRequestStreamInfo() const { return nullptr; } +const Network::Connection* Context::getConnection() const { + if (encoder_callbacks_) { + return encoder_callbacks_->connection(); + } else if (decoder_callbacks_) { + return decoder_callbacks_->connection(); + } + return nullptr; +} + WasmResult Context::setProperty(absl::string_view key, absl::string_view serialized_value) { auto* stream_info = getRequestStreamInfo(); if (!stream_info) { diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 4a84ac9823..60237e8f37 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -85,6 +85,12 @@ class Context : public Logger::Loggable, const StreamInfo::StreamInfo* getConstRequestStreamInfo() const; StreamInfo::StreamInfo* getRequestStreamInfo() const; + // Retrieves the connection object associated with the request (a.k.a active stream). + // It selects a value based on the following order: encoder callback, decoder + // callback. As long as any one of the callbacks is invoked, the value should be + // available. + const Network::Connection* getConnection() const; + // // VM level downcalls into the WASM code on Context(id == 0). //