diff --git a/api/envoy/extensions/filters/http/jwt_authn/v3/config.proto b/api/envoy/extensions/filters/http/jwt_authn/v3/config.proto index 08ef7a09feb20..afc761c07c7e1 100644 --- a/api/envoy/extensions/filters/http/jwt_authn/v3/config.proto +++ b/api/envoy/extensions/filters/http/jwt_authn/v3/config.proto @@ -232,6 +232,35 @@ message RemoteJwks { // Duration after which the cached JWKS should be expired. If not specified, default cache // duration is 5 minutes. google.protobuf.Duration cache_duration = 2; + + // Fetch Jwks asynchronously in the main thread before the listener is activated. + // Fetched Jwks can be used by all worker threads. + // + // If this feature is not enabled: + // + // * The Jwks is fetched on-demand when the requests come. During the fetching, first + // few requests are paused until the Jwks is fetched. + // * Each worker thread fetches its own Jwks since Jwks cache is per worker thread. + // + // If this feature is enabled: + // + // * Fetched Jwks is done in the main thread before the listener is activated. Its fetched + // Jwks can be used by all worker threads. Each worker thread doesn't need to fetch its own. + // * Jwks is ready when the requests come, not need to wait for the Jwks fetching. + // + JwksAsyncFetch async_fetch = 3; +} + +// Fetch Jwks asynchronously in the main thread when the filter config is parsed. +// The listener is activated only after the Jwks is fetched. +// When the Jwks is expired in the cache, it is fetched again in the main thread. +// The fetched Jwks from the main thread can be used by all worker threads. +message JwksAsyncFetch { + // If false, the listener is activated after the initial fetch is completed. + // The initial fetch result can be either successful or failed. + // If true, it is activated without waiting for the initial fetch to complete. + // Default is false. + bool fast_listener = 1; } // This message specifies a header location to extract JWT token. diff --git a/api/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto b/api/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto index 7656f09912e9b..442ba7df061ee 100644 --- a/api/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto +++ b/api/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto @@ -232,6 +232,38 @@ message RemoteJwks { // Duration after which the cached JWKS should be expired. If not specified, default cache // duration is 5 minutes. google.protobuf.Duration cache_duration = 2; + + // Fetch Jwks asynchronously in the main thread before the listener is activated. + // Fetched Jwks can be used by all worker threads. + // + // If this feature is not enabled: + // + // * The Jwks is fetched on-demand when the requests come. During the fetching, first + // few requests are paused until the Jwks is fetched. + // * Each worker thread fetches its own Jwks since Jwks cache is per worker thread. + // + // If this feature is enabled: + // + // * Fetched Jwks is done in the main thread before the listener is activated. Its fetched + // Jwks can be used by all worker threads. Each worker thread doesn't need to fetch its own. + // * Jwks is ready when the requests come, not need to wait for the Jwks fetching. + // + JwksAsyncFetch async_fetch = 3; +} + +// Fetch Jwks asynchronously in the main thread when the filter config is parsed. +// The listener is activated only after the Jwks is fetched. +// When the Jwks is expired in the cache, it is fetched again in the main thread. +// The fetched Jwks from the main thread can be used by all worker threads. +message JwksAsyncFetch { + option (udpa.annotations.versioning).previous_message_type = + "envoy.extensions.filters.http.jwt_authn.v3.JwksAsyncFetch"; + + // If false, the listener is activated after the initial fetch is completed. + // The initial fetch result can be either successful or failed. + // If true, it is activated without waiting for the initial fetch to complete. + // Default is false. + bool fast_listener = 1; } // This message specifies a header location to extract JWT token. diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index 0a13ed9987f94..5f5482c6c0945 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -76,6 +76,7 @@ New Features :ref:`xff ` extension. * http: added the ability to :ref:`unescape slash sequences` in the path. Requests with unescaped slashes can be proxied, rejected or redirected to the new unescaped path. By default this feature is disabled. The default behavior can be overridden through :ref:`http_connection_manager.path_with_escaped_slashes_action` runtime variable. This action can be selectively enabled for a portion of requests by setting the :ref:`http_connection_manager.path_with_escaped_slashes_action_sampling` runtime variable. * http: added upstream and downstream alpha HTTP/3 support! See :ref:`quic_options ` for downstream and the new http3_protocol_options in :ref:`http_protocol_options ` for upstream HTTP/3. +* jwt_authn: added support to fetch remote jwks asynchronously specified by :ref:`async_fetch `. * listener: added ability to change an existing listener's address. * local_rate_limit_filter: added suppoort for locally rate limiting http requests on a per connection basis. This can be enabled by setting the :ref:`local_rate_limit_per_downstream_connection ` field to true. * metric service: added support for sending metric tags as labels. This can be enabled by setting the :ref:`emit_tags_as_labels ` field to true. diff --git a/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v3/config.proto b/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v3/config.proto index 08ef7a09feb20..afc761c07c7e1 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v3/config.proto +++ b/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v3/config.proto @@ -232,6 +232,35 @@ message RemoteJwks { // Duration after which the cached JWKS should be expired. If not specified, default cache // duration is 5 minutes. google.protobuf.Duration cache_duration = 2; + + // Fetch Jwks asynchronously in the main thread before the listener is activated. + // Fetched Jwks can be used by all worker threads. + // + // If this feature is not enabled: + // + // * The Jwks is fetched on-demand when the requests come. During the fetching, first + // few requests are paused until the Jwks is fetched. + // * Each worker thread fetches its own Jwks since Jwks cache is per worker thread. + // + // If this feature is enabled: + // + // * Fetched Jwks is done in the main thread before the listener is activated. Its fetched + // Jwks can be used by all worker threads. Each worker thread doesn't need to fetch its own. + // * Jwks is ready when the requests come, not need to wait for the Jwks fetching. + // + JwksAsyncFetch async_fetch = 3; +} + +// Fetch Jwks asynchronously in the main thread when the filter config is parsed. +// The listener is activated only after the Jwks is fetched. +// When the Jwks is expired in the cache, it is fetched again in the main thread. +// The fetched Jwks from the main thread can be used by all worker threads. +message JwksAsyncFetch { + // If false, the listener is activated after the initial fetch is completed. + // The initial fetch result can be either successful or failed. + // If true, it is activated without waiting for the initial fetch to complete. + // Default is false. + bool fast_listener = 1; } // This message specifies a header location to extract JWT token. diff --git a/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto b/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto index 7656f09912e9b..442ba7df061ee 100644 --- a/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto +++ b/generated_api_shadow/envoy/extensions/filters/http/jwt_authn/v4alpha/config.proto @@ -232,6 +232,38 @@ message RemoteJwks { // Duration after which the cached JWKS should be expired. If not specified, default cache // duration is 5 minutes. google.protobuf.Duration cache_duration = 2; + + // Fetch Jwks asynchronously in the main thread before the listener is activated. + // Fetched Jwks can be used by all worker threads. + // + // If this feature is not enabled: + // + // * The Jwks is fetched on-demand when the requests come. During the fetching, first + // few requests are paused until the Jwks is fetched. + // * Each worker thread fetches its own Jwks since Jwks cache is per worker thread. + // + // If this feature is enabled: + // + // * Fetched Jwks is done in the main thread before the listener is activated. Its fetched + // Jwks can be used by all worker threads. Each worker thread doesn't need to fetch its own. + // * Jwks is ready when the requests come, not need to wait for the Jwks fetching. + // + JwksAsyncFetch async_fetch = 3; +} + +// Fetch Jwks asynchronously in the main thread when the filter config is parsed. +// The listener is activated only after the Jwks is fetched. +// When the Jwks is expired in the cache, it is fetched again in the main thread. +// The fetched Jwks from the main thread can be used by all worker threads. +message JwksAsyncFetch { + option (udpa.annotations.versioning).previous_message_type = + "envoy.extensions.filters.http.jwt_authn.v3.JwksAsyncFetch"; + + // If false, the listener is activated after the initial fetch is completed. + // The initial fetch result can be either successful or failed. + // If true, it is activated without waiting for the initial fetch to complete. + // Default is false. + bool fast_listener = 1; } // This message specifies a header location to extract JWT token. diff --git a/source/extensions/filters/http/jwt_authn/BUILD b/source/extensions/filters/http/jwt_authn/BUILD index 0d5895dfbd5a6..e8d55bf9f3ee6 100644 --- a/source/extensions/filters/http/jwt_authn/BUILD +++ b/source/extensions/filters/http/jwt_authn/BUILD @@ -20,6 +20,33 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "stats_lib", + hdrs = ["stats.h"], + deps = [ + "//include/envoy/stats:stats_macros", + ], +) + +envoy_cc_library( + name = "jwks_async_fetcher_lib", + srcs = ["jwks_async_fetcher.cc"], + hdrs = ["jwks_async_fetcher.h"], + external_deps = [ + "jwt_verify_lib", + ], + deps = [ + ":stats_lib", + "//include/envoy/server:factory_context_interface", + "//source/common/common:minimal_logger_lib", + "//source/common/init:target_lib", + "//source/common/protobuf:utility_lib", + "//source/common/tracing:http_tracer_lib", + "//source/extensions/filters/http/common:jwks_fetcher_lib", + "@envoy_api//envoy/extensions/filters/http/jwt_authn/v3:pkg_cc_proto", + ], +) + envoy_cc_library( name = "jwks_cache_lib", srcs = ["jwks_cache.cc"], @@ -28,9 +55,8 @@ envoy_cc_library( "jwt_verify_lib", ], deps = [ - "//source/common/common:minimal_logger_lib", + "jwks_async_fetcher_lib", "//source/common/config:datasource_lib", - "//source/common/protobuf:utility_lib", "@envoy_api//envoy/extensions/filters/http/jwt_authn/v3:pkg_cc_proto", ], ) @@ -58,7 +84,7 @@ envoy_cc_library( "jwt_verify_lib", ], deps = [ - ":filter_config_interface", + ":filter_config_lib", ":matchers_lib", "//include/envoy/http:filter_interface", "//source/common/http:headers_lib", @@ -109,7 +135,7 @@ envoy_cc_library( ) envoy_cc_library( - name = "filter_config_interface", + name = "filter_config_lib", srcs = ["filter_config.cc"], hdrs = ["filter_config.h"], deps = [ diff --git a/source/extensions/filters/http/jwt_authn/authenticator.cc b/source/extensions/filters/http/jwt_authn/authenticator.cc index 0115e91d061b6..fd960636ac41d 100644 --- a/source/extensions/filters/http/jwt_authn/authenticator.cc +++ b/source/extensions/filters/http/jwt_authn/authenticator.cc @@ -221,6 +221,7 @@ void AuthenticatorImpl::startVerify() { } void AuthenticatorImpl::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) { + jwks_cache_.stats().jwks_fetch_success_.inc(); const Status status = jwks_data_->setRemoteJwks(std::move(jwks))->getStatus(); if (status != Status::Ok) { doneWithStatus(status); @@ -229,7 +230,10 @@ void AuthenticatorImpl::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) { } } -void AuthenticatorImpl::onJwksError(Failure) { doneWithStatus(Status::JwksFetchFail); } +void AuthenticatorImpl::onJwksError(Failure) { + jwks_cache_.stats().jwks_fetch_failed_.inc(); + doneWithStatus(Status::JwksFetchFail); +} void AuthenticatorImpl::onDestroy() { if (fetcher_) { diff --git a/source/extensions/filters/http/jwt_authn/authenticator.h b/source/extensions/filters/http/jwt_authn/authenticator.h index 928a8045b843e..04ee62cede1e1 100644 --- a/source/extensions/filters/http/jwt_authn/authenticator.h +++ b/source/extensions/filters/http/jwt_authn/authenticator.h @@ -2,7 +2,6 @@ #include "envoy/server/filter_config.h" -#include "extensions/filters/http/common/jwks_fetcher.h" #include "extensions/filters/http/jwt_authn/extractor.h" #include "extensions/filters/http/jwt_authn/jwks_cache.h" @@ -21,11 +20,6 @@ using AuthenticatorCallback = std::function; -/** - * CreateJwksFetcherCb is a callback interface for creating a JwksFetcher instance. - */ -using CreateJwksFetcherCb = std::function; - /** * Authenticator object to handle all JWT authentication flow. */ diff --git a/source/extensions/filters/http/jwt_authn/filter_config.cc b/source/extensions/filters/http/jwt_authn/filter_config.cc index a35ddf038824d..475d993daaeca 100644 --- a/source/extensions/filters/http/jwt_authn/filter_config.cc +++ b/source/extensions/filters/http/jwt_authn/filter_config.cc @@ -19,8 +19,7 @@ FilterConfigImpl::FilterConfigImpl( ENVOY_LOG(debug, "Loaded JwtAuthConfig: {}", proto_config_.DebugString()); - jwks_cache_ = - JwksCache::create(proto_config_, time_source_, context.api(), context.threadLocal()); + jwks_cache_ = JwksCache::create(proto_config_, context, Common::JwksFetcher::create, stats_); std::vector names; for (const auto& it : proto_config_.requirement_map()) { diff --git a/source/extensions/filters/http/jwt_authn/filter_config.h b/source/extensions/filters/http/jwt_authn/filter_config.h index d659dbf95135a..78f0e23d98783 100644 --- a/source/extensions/filters/http/jwt_authn/filter_config.h +++ b/source/extensions/filters/http/jwt_authn/filter_config.h @@ -9,6 +9,7 @@ #include "envoy/thread_local/thread_local.h" #include "extensions/filters/http/jwt_authn/matcher.h" +#include "extensions/filters/http/jwt_authn/stats.h" #include "extensions/filters/http/jwt_authn/verifier.h" #include "absl/container/flat_hash_map.h" @@ -18,21 +19,6 @@ namespace Extensions { namespace HttpFilters { namespace JwtAuthn { -/** - * All stats for the Jwt Authn filter. @see stats_macros.h - */ -#define ALL_JWT_AUTHN_FILTER_STATS(COUNTER) \ - COUNTER(allowed) \ - COUNTER(cors_preflight_bypassed) \ - COUNTER(denied) - -/** - * Wrapper struct for jwt_authn filter stats. @see stats_macros.h - */ -struct JwtAuthnFilterStats { - ALL_JWT_AUTHN_FILTER_STATS(GENERATE_COUNTER_STRUCT) -}; - /** * The per-route filter config */ diff --git a/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.cc b/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.cc new file mode 100644 index 0000000000000..0ec5bbdce98f8 --- /dev/null +++ b/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.cc @@ -0,0 +1,98 @@ +#include "extensions/filters/http/jwt_authn/jwks_async_fetcher.h" + +#include "common/protobuf/utility.h" +#include "common/tracing/http_tracer_impl.h" + +using envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks; + +namespace Envoy { +namespace Extensions { +namespace HttpFilters { +namespace JwtAuthn { +namespace { + +// Default cache expiration time in 5 minutes. +constexpr int PubkeyCacheExpirationSec = 600; + +} // namespace + +JwksAsyncFetcher::JwksAsyncFetcher(const RemoteJwks& remote_jwks, + Server::Configuration::FactoryContext& context, + CreateJwksFetcherCb create_fetcher_fn, + JwtAuthnFilterStats& stats, JwksDoneFetched done_fn) + : remote_jwks_(remote_jwks), context_(context), create_fetcher_fn_(create_fetcher_fn), + stats_(stats), done_fn_(done_fn), cache_duration_(getCacheDuration(remote_jwks)), + debug_name_(absl::StrCat("Jwks async fetching url=", remote_jwks_.http_uri().uri())) { + // if async_fetch is not enabled, do nothing. + if (!remote_jwks_.has_async_fetch()) { + return; + } + + cache_duration_timer_ = context_.dispatcher().createTimer([this]() -> void { fetch(); }); + + // For fast_listener, just trigger a fetch, not register with init_manager. + if (remote_jwks_.async_fetch().fast_listener()) { + fetch(); + return; + } + + // Register to init_manager, force the listener to wait for the fetching. + init_target_ = std::make_unique(debug_name_, [this]() -> void { fetch(); }); + context_.initManager().add(*init_target_); +} + +std::chrono::seconds JwksAsyncFetcher::getCacheDuration(const RemoteJwks& remote_jwks) { + if (remote_jwks.has_cache_duration()) { + return std::chrono::seconds(DurationUtil::durationToSeconds(remote_jwks.cache_duration())); + } + return std::chrono::seconds(PubkeyCacheExpirationSec); +} + +void JwksAsyncFetcher::fetch() { + if (fetcher_) { + fetcher_->cancel(); + } + + ENVOY_LOG(debug, "{}: started", debug_name_); + fetcher_ = create_fetcher_fn_(context_.clusterManager()); + fetcher_->fetch(remote_jwks_.http_uri(), Tracing::NullSpan::instance(), *this); +} + +void JwksAsyncFetcher::handleFetchDone() { + if (init_target_) { + init_target_->ready(); + init_target_.reset(); + } + + cache_duration_timer_->enableTimer(cache_duration_); +} + +void JwksAsyncFetcher::onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) { + stats_.jwks_fetch_success_.inc(); + + done_fn_(std::move(jwks)); + handleFetchDone(); + + // Note: not to free fetcher_ within onJwksSuccess or onJwksError function. + // They are passed to fetcher_->fetch() and are called by fetcher_ after fetch is done. + // After calling these callback functions, fetch_ calls its reset() function. + // If fetcher_ is freed by the callback, calling reset() will crash. + + // Not need to free fetcher_. At the next fetch(), it will be freed with a cancel() call. + // The cancel() is needed to cancel the old call before the new one is created. + // But it is a no-op if the call is completed. +} + +void JwksAsyncFetcher::onJwksError(Failure) { + stats_.jwks_fetch_failed_.inc(); + + ENVOY_LOG(warn, "{}: failed", debug_name_); + handleFetchDone(); + + // Note: not to free fetcher_ in this function. Please see comment at onJwksSuccess. +} + +} // namespace JwtAuthn +} // namespace HttpFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.h b/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.h new file mode 100644 index 0000000000000..c5909909b44ed --- /dev/null +++ b/source/extensions/filters/http/jwt_authn/jwks_async_fetcher.h @@ -0,0 +1,85 @@ +#pragma once + +#include + +#include "envoy/extensions/filters/http/jwt_authn/v3/config.pb.h" +#include "envoy/server/factory_context.h" + +#include "common/common/logger.h" +#include "common/init/target_impl.h" + +#include "extensions/filters/http/common/jwks_fetcher.h" +#include "extensions/filters/http/jwt_authn/stats.h" + +namespace Envoy { +namespace Extensions { +namespace HttpFilters { +namespace JwtAuthn { + +/** + * CreateJwksFetcherCb is a callback interface for creating a JwksFetcher instance. + */ +using CreateJwksFetcherCb = std::function; +/** + * JwksDoneFetched is a callback interface to set a Jwks when fetch is done. + */ +using JwksDoneFetched = std::function; + +// This class handles fetching Jwks asynchronously. +// It will be no-op if async_fetch is not enabled. +// At its constructor, it will start to fetch Jwks, register with init_manager if not fast_listener. +// and handle fetching response. When cache is expired, it will fetch again. +// When a Jwks is fetched, done_fn is called to set the Jwks. +class JwksAsyncFetcher : public Logger::Loggable, + public Common::JwksFetcher::JwksReceiver { +public: + JwksAsyncFetcher(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks, + Server::Configuration::FactoryContext& context, CreateJwksFetcherCb fetcher_fn, + JwtAuthnFilterStats& stats, JwksDoneFetched done_fn); + + // Get the remote Jwks cache duration. + static std::chrono::seconds + getCacheDuration(const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks); + +private: + // Fetch the Jwks + void fetch(); + // Handle fetch done. + void handleFetchDone(); + + // Override the functions from Common::JwksFetcher::JwksReceiver + void onJwksSuccess(google::jwt_verify::JwksPtr&& jwks) override; + void onJwksError(Failure reason) override; + + // the remote Jwks config + const envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks& remote_jwks_; + // the factory context + Server::Configuration::FactoryContext& context_; + // the jwks fetcher creator function + const CreateJwksFetcherCb create_fetcher_fn_; + // stats + JwtAuthnFilterStats& stats_; + // the Jwks done function. + const JwksDoneFetched done_fn_; + + // The Jwks fetcher object + Common::JwksFetcherPtr fetcher_; + + // The cache duration. + const std::chrono::seconds cache_duration_; + // The timer to trigger fetch due to cache duration. + Envoy::Event::TimerPtr cache_duration_timer_; + + // The init target. + std::unique_ptr init_target_; + + // Used in logs. + const std::string debug_name_; +}; + +using JwksAsyncFetcherPtr = std::unique_ptr; + +} // namespace JwtAuthn +} // namespace HttpFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/http/jwt_authn/jwks_cache.cc b/source/extensions/filters/http/jwt_authn/jwks_cache.cc index 7204199fe3e59..48199f0f6d4ff 100644 --- a/source/extensions/filters/http/jwt_authn/jwks_cache.cc +++ b/source/extensions/filters/http/jwt_authn/jwks_cache.cc @@ -7,7 +7,6 @@ #include "common/common/logger.h" #include "common/config/datasource.h" -#include "common/protobuf/utility.h" #include "absl/container/node_hash_map.h" #include "jwt_verify_lib/check_audience.h" @@ -23,16 +22,12 @@ namespace HttpFilters { namespace JwtAuthn { namespace { -// Default cache expiration time in 5 minutes. -constexpr int PubkeyCacheExpirationSec = 600; - -using JwksSharedPtr = std::shared_ptr<::google::jwt_verify::Jwks>; - class JwksDataImpl : public JwksCache::JwksData, public Logger::Loggable { public: - JwksDataImpl(const JwtProvider& jwt_provider, TimeSource& time_source, Api::Api& api, - ThreadLocal::SlotAllocator& tls) - : jwt_provider_(jwt_provider), time_source_(time_source), tls_(tls) { + JwksDataImpl(const JwtProvider& jwt_provider, Server::Configuration::FactoryContext& context, + CreateJwksFetcherCb fetcher_cb, JwtAuthnFilterStats& stats) + : jwt_provider_(jwt_provider), time_source_(context.timeSource()), + tls_(context.threadLocal()) { std::vector audiences; for (const auto& aud : jwt_provider_.audiences()) { @@ -42,7 +37,8 @@ class JwksDataImpl : public JwksCache::JwksData, public Logger::Loggable(); }); - const auto inline_jwks = Config::DataSource::read(jwt_provider_.local_jwks(), true, api); + const auto inline_jwks = + Config::DataSource::read(jwt_provider_.local_jwks(), true, context.api()); if (!inline_jwks.empty()) { auto jwks = ::google::jwt_verify::Jwks::createFrom(inline_jwks, ::google::jwt_verify::Jwks::JWKS); @@ -50,7 +46,14 @@ class JwksDataImpl : public JwksCache::JwksData, public Logger::Loggable( + jwt_provider_.remote_jwks(), context, fetcher_cb, stats, + [this](google::jwt_verify::JwksPtr&& jwks) { setJwksToAllThreads(std::move(jwks)); }); } } } @@ -65,44 +68,32 @@ class JwksDataImpl : public JwksCache::JwksData, public Logger::Loggable= tls_->expire_; } - const ::google::jwt_verify::Jwks* setRemoteJwks(::google::jwt_verify::JwksPtr&& jwks) override { + const ::google::jwt_verify::Jwks* setRemoteJwks(JwksConstPtr&& jwks) override { // convert unique_ptr to shared_ptr - JwksSharedPtr shared_jwks(jwks.release()); + JwksConstSharedPtr shared_jwks = std::move(jwks); tls_->jwks_ = shared_jwks; - tls_->expire_ = getRemoteJwksExpirationTime(); + tls_->expire_ = time_source_.monotonicTime() + + JwksAsyncFetcher::getCacheDuration(jwt_provider_.remote_jwks()); return shared_jwks.get(); } private: struct ThreadLocalCache : public ThreadLocal::ThreadLocalObject { // The jwks object. - JwksSharedPtr jwks_; + JwksConstSharedPtr jwks_; // The pubkey expiration time. MonotonicTime expire_; }; // Set jwks shared_ptr to all threads. - void setJwksToAllThreads(::google::jwt_verify::JwksPtr&& jwks, - std::chrono::steady_clock::time_point expire) { - JwksSharedPtr shared_jwks(jwks.release()); - tls_.runOnAllThreads([shared_jwks, expire](OptRef obj) { + void setJwksToAllThreads(JwksConstPtr&& jwks) { + JwksConstSharedPtr shared_jwks = std::move(jwks); + tls_.runOnAllThreads([shared_jwks](OptRef obj) { obj->jwks_ = shared_jwks; - obj->expire_ = expire; + obj->expire_ = std::chrono::steady_clock::time_point::max(); }); } - // Get the expiration time for a remote Jwks - std::chrono::steady_clock::time_point getRemoteJwksExpirationTime() const { - auto expire = time_source_.monotonicTime(); - if (jwt_provider_.has_remote_jwks() && jwt_provider_.remote_jwks().has_cache_duration()) { - expire += std::chrono::milliseconds( - DurationUtil::durationToMilliseconds(jwt_provider_.remote_jwks().cache_duration())); - } else { - expire += std::chrono::seconds(PubkeyCacheExpirationSec); - } - return expire; - } - // The jwt provider config. const JwtProvider& jwt_provider_; // Check audience object @@ -111,6 +102,8 @@ class JwksDataImpl : public JwksCache::JwksData, public Logger::Loggable tls_; + // async fetcher + JwksAsyncFetcherPtr async_fetcher_; }; using JwksDataImplPtr = std::unique_ptr; @@ -118,11 +111,12 @@ using JwksDataImplPtr = std::unique_ptr; class JwksCacheImpl : public JwksCache { public: // Load the config from envoy config. - JwksCacheImpl(const JwtAuthentication& config, TimeSource& time_source, Api::Api& api, - ThreadLocal::SlotAllocator& tls) { + JwksCacheImpl(const JwtAuthentication& config, Server::Configuration::FactoryContext& context, + CreateJwksFetcherCb fetcher_fn, JwtAuthnFilterStats& stats) + : stats_(stats) { for (const auto& it : config.providers()) { const auto& provider = it.second; - auto jwks_data = std::make_unique(provider, time_source, api, tls); + auto jwks_data = std::make_unique(provider, context, fetcher_fn, stats); if (issuer_ptr_map_.find(provider.issuer()) == issuer_ptr_map_.end()) { issuer_ptr_map_.emplace(provider.issuer(), jwks_data.get()); } @@ -148,6 +142,8 @@ class JwksCacheImpl : public JwksCache { NOT_REACHED_GCOVR_EXCL_LINE; } + JwtAuthnFilterStats& stats() override { return stats_; } + private: JwksData* findIssuerMap(const std::string& issuer) { const auto& it = issuer_ptr_map_.find(issuer); @@ -157,6 +153,8 @@ class JwksCacheImpl : public JwksCache { return it->second; } + // stats + JwtAuthnFilterStats& stats_; // The Jwks data map indexed by provider. absl::node_hash_map jwks_data_map_; // The Jwks data pointer map indexed by issuer. @@ -167,8 +165,9 @@ class JwksCacheImpl : public JwksCache { JwksCachePtr JwksCache::create(const envoy::extensions::filters::http::jwt_authn::v3::JwtAuthentication& config, - TimeSource& time_source, Api::Api& api, ThreadLocal::SlotAllocator& tls) { - return std::make_unique(config, time_source, api, tls); + Server::Configuration::FactoryContext& context, CreateJwksFetcherCb fetcher_fn, + JwtAuthnFilterStats& stats) { + return std::make_unique(config, context, fetcher_fn, stats); } } // namespace JwtAuthn diff --git a/source/extensions/filters/http/jwt_authn/jwks_cache.h b/source/extensions/filters/http/jwt_authn/jwks_cache.h index e3b56348a7407..186401c44c965 100644 --- a/source/extensions/filters/http/jwt_authn/jwks_cache.h +++ b/source/extensions/filters/http/jwt_authn/jwks_cache.h @@ -6,6 +6,10 @@ #include "envoy/extensions/filters/http/jwt_authn/v3/config.pb.h" #include "envoy/thread_local/thread_local.h" +#include "extensions/filters/http/common/jwks_fetcher.h" +#include "extensions/filters/http/jwt_authn/jwks_async_fetcher.h" +#include "extensions/filters/http/jwt_authn/stats.h" + #include "jwt_verify_lib/jwks.h" namespace Envoy { @@ -16,6 +20,9 @@ namespace JwtAuthn { class JwksCache; using JwksCachePtr = std::unique_ptr; +using JwksConstPtr = std::unique_ptr; +using JwksConstSharedPtr = std::shared_ptr; + /** * Interface to access all configured Jwt rules and their cached Jwks objects. * It only caches Jwks specified in the config. @@ -57,8 +64,7 @@ class JwksCache { virtual bool isExpired() const PURE; // Set a remote Jwks. - virtual const ::google::jwt_verify::Jwks* - setRemoteJwks(::google::jwt_verify::JwksPtr&& jwks) PURE; + virtual const ::google::jwt_verify::Jwks* setRemoteJwks(JwksConstPtr&& jwks) PURE; }; // Lookup issuer cache map. The cache only stores Jwks specified in the config. @@ -67,10 +73,13 @@ class JwksCache { // Lookup provider cache map. virtual JwksData* findByProvider(const std::string& provider) PURE; + virtual JwtAuthnFilterStats& stats() PURE; + // Factory function to create an instance. static JwksCachePtr create(const envoy::extensions::filters::http::jwt_authn::v3::JwtAuthentication& config, - TimeSource& time_source, Api::Api& api, ThreadLocal::SlotAllocator& tls); + Server::Configuration::FactoryContext& context, CreateJwksFetcherCb fetcher_fn, + JwtAuthnFilterStats& stats); }; } // namespace JwtAuthn diff --git a/source/extensions/filters/http/jwt_authn/stats.h b/source/extensions/filters/http/jwt_authn/stats.h new file mode 100644 index 0000000000000..f1eccce135606 --- /dev/null +++ b/source/extensions/filters/http/jwt_authn/stats.h @@ -0,0 +1,30 @@ +#pragma once + +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace HttpFilters { +namespace JwtAuthn { + +/** + * All stats for the Jwt Authn filter. @see stats_macros.h + */ +#define ALL_JWT_AUTHN_FILTER_STATS(COUNTER) \ + COUNTER(allowed) \ + COUNTER(cors_preflight_bypassed) \ + COUNTER(denied) \ + COUNTER(jwks_fetch_success) \ + COUNTER(jwks_fetch_failed) + +/** + * Wrapper struct for jwt_authn filter stats. @see stats_macros.h + */ +struct JwtAuthnFilterStats { + ALL_JWT_AUTHN_FILTER_STATS(GENERATE_COUNTER_STRUCT) +}; + +} // namespace JwtAuthn +} // namespace HttpFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/http/jwt_authn/BUILD b/test/extensions/filters/http/jwt_authn/BUILD index 4b63f37aaf650..e224b781035a9 100644 --- a/test/extensions/filters/http/jwt_authn/BUILD +++ b/test/extensions/filters/http/jwt_authn/BUILD @@ -68,6 +68,17 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "jwks_async_fetcher_test", + srcs = ["jwks_async_fetcher_test.cc"], + extension_name = "envoy.filters.http.jwt_authn", + deps = [ + "//source/extensions/filters/http/jwt_authn:jwks_async_fetcher_lib", + "//test/extensions/filters/http/jwt_authn:test_common_lib", + "//test/mocks/server:factory_context_mocks", + ], +) + envoy_extension_cc_test( name = "filter_factory_test", srcs = ["filter_factory_test.cc"], @@ -89,8 +100,7 @@ envoy_extension_cc_test( "//source/extensions/filters/http/common:jwks_fetcher_lib", "//source/extensions/filters/http/jwt_authn:jwks_cache_lib", "//test/extensions/filters/http/jwt_authn:test_common_lib", - "//test/mocks/thread_local:thread_local_mocks", - "//test/test_common:simulated_time_system_lib", + "//test/mocks/server:factory_context_mocks", "//test/test_common:utility_lib", "@envoy_api//envoy/extensions/filters/http/jwt_authn/v3:pkg_cc_proto", ], @@ -104,7 +114,7 @@ envoy_extension_cc_test( ":mock_lib", "//source/extensions/filters/http/common:jwks_fetcher_lib", "//source/extensions/filters/http/jwt_authn:authenticator_lib", - "//source/extensions/filters/http/jwt_authn:filter_config_interface", + "//source/extensions/filters/http/jwt_authn:filter_config_lib", "//source/extensions/filters/http/jwt_authn:matchers_lib", "//test/extensions/filters/http/common:mock_lib", "//test/extensions/filters/http/jwt_authn:test_common_lib", @@ -170,7 +180,7 @@ envoy_extension_cc_test( deps = [ ":mock_lib", ":test_common_lib", - "//source/extensions/filters/http/jwt_authn:filter_config_interface", + "//source/extensions/filters/http/jwt_authn:filter_config_lib", "//source/extensions/filters/http/jwt_authn:matchers_lib", "//test/mocks/server:factory_context_mocks", "//test/test_common:utility_lib", diff --git a/test/extensions/filters/http/jwt_authn/authenticator_test.cc b/test/extensions/filters/http/jwt_authn/authenticator_test.cc index e45145eaea6cc..41d7d2253cbd3 100644 --- a/test/extensions/filters/http/jwt_authn/authenticator_test.cc +++ b/test/extensions/filters/http/jwt_authn/authenticator_test.cc @@ -108,6 +108,9 @@ TEST_F(AuthenticatorTest, TestOkJWTandCache) { // Verify the token is removed. EXPECT_FALSE(headers.has(Http::CustomHeaders::get().Authorization)); } + + EXPECT_EQ(1U, filter_config_->stats().jwks_fetch_success_.value()); + EXPECT_EQ(0U, filter_config_->stats().jwks_fetch_failed_.value()); } // This test verifies the Jwt is forwarded if "forward" flag is set. @@ -131,6 +134,9 @@ TEST_F(AuthenticatorTest, TestForwardJwt) { // Payload not set by default EXPECT_EQ(out_name_, ""); + + EXPECT_EQ(1U, filter_config_->stats().jwks_fetch_success_.value()); + EXPECT_EQ(0U, filter_config_->stats().jwks_fetch_failed_.value()); } // This test verifies the Jwt payload is set. @@ -181,6 +187,9 @@ TEST_F(AuthenticatorTest, TestWrongIssuer) { Http::TestRequestHeaderMapImpl headers{ {"Authorization", "Bearer " + std::string(OtherGoodToken)}}; expectVerifyStatus(Status::JwtUnknownIssuer, headers); + + EXPECT_EQ(0U, filter_config_->stats().jwks_fetch_success_.value()); + EXPECT_EQ(0U, filter_config_->stats().jwks_fetch_failed_.value()); } // Jwt "iss" is "other.com", "issuer" in JwtProvider is not specified, diff --git a/test/extensions/filters/http/jwt_authn/filter_integration_test.cc b/test/extensions/filters/http/jwt_authn/filter_integration_test.cc index 7adac3450ce60..0a5792d1c1f76 100644 --- a/test/extensions/filters/http/jwt_authn/filter_integration_test.cc +++ b/test/extensions/filters/http/jwt_authn/filter_integration_test.cc @@ -78,6 +78,20 @@ std::string getAuthFilterConfig(const std::string& config_str, bool use_local_jw return MessageUtil::getJsonStringFromMessageOrDie(filter); } +std::string getAsyncFetchFilterConfig(const std::string& config_str, bool fast_listener) { + JwtAuthentication proto_config; + TestUtility::loadFromYaml(config_str, proto_config); + + auto& provider0 = (*proto_config.mutable_providers())[std::string(ProviderName)]; + auto* async_fetch = provider0.mutable_remote_jwks()->mutable_async_fetch(); + async_fetch->set_fast_listener(fast_listener); + + HttpFilter filter; + filter.set_name(HttpFilterNames::get().JwtAuthn); + filter.mutable_typed_config()->PackFrom(proto_config); + return MessageUtil::getJsonStringFromMessageOrDie(filter); +} + std::string getFilterConfig(bool use_local_jwks) { return getAuthFilterConfig(ExampleConfig, use_local_jwks); } @@ -329,6 +343,18 @@ class RemoteJwksIntegrationTest : public HttpProtocolIntegrationTest { initialize(); } + void initializeAsyncFetchFilter(bool fast_listener) { + config_helper_.addFilter(getAsyncFetchFilterConfig(ExampleConfig, fast_listener)); + + config_helper_.addConfigModifier([](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + auto* jwks_cluster = bootstrap.mutable_static_resources()->add_clusters(); + jwks_cluster->MergeFrom(bootstrap.static_resources().clusters()[0]); + jwks_cluster->set_name("pubkey_cluster"); + }); + + initialize(); + } + void waitForJwksResponse(const std::string& status, const std::string& jwks_body) { AssertionResult result = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_jwks_connection_); @@ -448,6 +474,112 @@ TEST_P(RemoteJwksIntegrationTest, FetchFailedMissingCluster) { cleanup(); } +TEST_P(RemoteJwksIntegrationTest, WithGoodTokenAsyncFetch) { + on_server_init_function_ = [this]() { waitForJwksResponse("200", PublicKey); }; + initializeAsyncFetchFilter(false); + + codec_client_ = makeHttpConnection(lookupPort("http")); + + auto response = codec_client_->makeHeaderOnlyRequest(Http::TestRequestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/"}, + {":scheme", "http"}, + {":authority", "host"}, + {"Authorization", "Bearer " + std::string(GoodToken)}, + }); + + waitForNextUpstreamRequest(); + + const auto payload_entry = + upstream_request_->headers().get(Http::LowerCaseString("sec-istio-auth-userinfo")); + EXPECT_FALSE(payload_entry.empty()); + EXPECT_EQ(payload_entry[0]->value().getStringView(), ExpectedPayloadValue); + // Verify the token is removed. + EXPECT_TRUE(upstream_request_->headers().get(Http::CustomHeaders::get().Authorization).empty()); + + upstream_request_->encodeHeaders(Http::TestResponseHeaderMapImpl{{":status", "200"}}, true); + + ASSERT_TRUE(response->waitForEndStream()); + ASSERT_TRUE(response->complete()); + EXPECT_EQ("200", response->headers().getStatusValue()); + + cleanup(); +} + +TEST_P(RemoteJwksIntegrationTest, WithGoodTokenAsyncFetchFast) { + on_server_init_function_ = [this]() { waitForJwksResponse("200", PublicKey); }; + initializeAsyncFetchFilter(true); + + codec_client_ = makeHttpConnection(lookupPort("http")); + + auto response = codec_client_->makeHeaderOnlyRequest(Http::TestRequestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/"}, + {":scheme", "http"}, + {":authority", "host"}, + {"Authorization", "Bearer " + std::string(GoodToken)}, + }); + + waitForNextUpstreamRequest(); + + const auto payload_entry = + upstream_request_->headers().get(Http::LowerCaseString("sec-istio-auth-userinfo")); + EXPECT_FALSE(payload_entry.empty()); + EXPECT_EQ(payload_entry[0]->value().getStringView(), ExpectedPayloadValue); + // Verify the token is removed. + EXPECT_TRUE(upstream_request_->headers().get(Http::CustomHeaders::get().Authorization).empty()); + + upstream_request_->encodeHeaders(Http::TestResponseHeaderMapImpl{{":status", "200"}}, true); + + ASSERT_TRUE(response->waitForEndStream()); + ASSERT_TRUE(response->complete()); + EXPECT_EQ("200", response->headers().getStatusValue()); + + cleanup(); +} + +TEST_P(RemoteJwksIntegrationTest, WithFailedJwksAsyncFetch) { + on_server_init_function_ = [this]() { waitForJwksResponse("500", ""); }; + initializeAsyncFetchFilter(false); + + codec_client_ = makeHttpConnection(lookupPort("http")); + + auto response = codec_client_->makeHeaderOnlyRequest(Http::TestRequestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/"}, + {":scheme", "http"}, + {":authority", "host"}, + {"Authorization", "Bearer " + std::string(GoodToken)}, + }); + + ASSERT_TRUE(response->waitForEndStream()); + ASSERT_TRUE(response->complete()); + EXPECT_EQ("401", response->headers().getStatusValue()); + + cleanup(); +} + +TEST_P(RemoteJwksIntegrationTest, WithFailedJwksAsyncFetchFast) { + on_server_init_function_ = [this]() { waitForJwksResponse("500", ""); }; + initializeAsyncFetchFilter(true); + + codec_client_ = makeHttpConnection(lookupPort("http")); + + auto response = codec_client_->makeHeaderOnlyRequest(Http::TestRequestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/"}, + {":scheme", "http"}, + {":authority", "host"}, + {"Authorization", "Bearer " + std::string(GoodToken)}, + }); + + ASSERT_TRUE(response->waitForEndStream()); + ASSERT_TRUE(response->complete()); + EXPECT_EQ("401", response->headers().getStatusValue()); + + cleanup(); +} + class PerRouteIntegrationTest : public HttpProtocolIntegrationTest { public: void setup(const std::string& filter_config, const PerRouteConfig& per_route) { diff --git a/test/extensions/filters/http/jwt_authn/jwks_async_fetcher_test.cc b/test/extensions/filters/http/jwt_authn/jwks_async_fetcher_test.cc new file mode 100644 index 0000000000000..1d21e254d7f94 --- /dev/null +++ b/test/extensions/filters/http/jwt_authn/jwks_async_fetcher_test.cc @@ -0,0 +1,248 @@ +#include "extensions/filters/http/jwt_authn/jwks_async_fetcher.h" + +#include "test/extensions/filters/http/jwt_authn/test_common.h" +#include "test/mocks/server/factory_context.h" + +using envoy::extensions::filters::http::jwt_authn::v3::RemoteJwks; +using Envoy::Extensions::HttpFilters::Common::JwksFetcher; +using Envoy::Extensions::HttpFilters::Common::JwksFetcherPtr; + +namespace Envoy { +namespace Extensions { +namespace HttpFilters { +namespace JwtAuthn { +namespace { + +JwtAuthnFilterStats generateMockStats(Stats::Scope& scope) { + return {ALL_JWT_AUTHN_FILTER_STATS(POOL_COUNTER_PREFIX(scope, ""))}; +} + +class MockJwksFetcher : public Common::JwksFetcher { +public: + using SaveJwksReceiverFn = std::function; + MockJwksFetcher(SaveJwksReceiverFn receiver_fn) : receiver_fn_(receiver_fn) {} + + void cancel() override {} + void fetch(const envoy::config::core::v3::HttpUri&, Tracing::Span&, + JwksReceiver& receiver) override { + receiver_fn_(receiver); + } + +private: + SaveJwksReceiverFn receiver_fn_; +}; + +// TestParam is for fast_listener, +class JwksAsyncFetcherTest : public testing::TestWithParam { +public: + JwksAsyncFetcherTest() : stats_(generateMockStats(context_.scope())) {} + + // init manager is used in is_slow_listener mode + bool initManagerUsed() const { + return config_.has_async_fetch() && !config_.async_fetch().fast_listener(); + } + + void setupAsyncFetcher(const std::string& config_str) { + TestUtility::loadFromYaml(config_str, config_); + if (config_.has_async_fetch()) { + // Param is for fast_listener, + if (GetParam()) { + config_.mutable_async_fetch()->set_fast_listener(true); + } + } + + if (initManagerUsed()) { + EXPECT_CALL(context_.init_manager_, add(_)) + .WillOnce(Invoke([this](const Init::Target& target) { + init_target_handle_ = target.createHandle("test"); + })); + } + + // if async_fetch is enabled, timer is created + if (config_.has_async_fetch()) { + timer_ = new NiceMock(&context_.dispatcher_); + expected_duration_ = JwksAsyncFetcher::getCacheDuration(config_); + } + + async_fetcher_ = std::make_unique( + config_, context_, + [this](Upstream::ClusterManager&) { + return std::make_unique( + [this](Common::JwksFetcher::JwksReceiver& receiver) { + fetch_receiver_array_.push_back(&receiver); + }); + }, + stats_, + [this](google::jwt_verify::JwksPtr&& jwks) { out_jwks_array_.push_back(std::move(jwks)); }); + + if (initManagerUsed()) { + init_target_handle_->initialize(init_watcher_); + } + } + + RemoteJwks config_; + JwksAsyncFetcherPtr async_fetcher_; + NiceMock context_; + JwtAuthnFilterStats stats_; + std::vector fetch_receiver_array_; + std::vector out_jwks_array_; + + Init::TargetHandlePtr init_target_handle_; + NiceMock init_watcher_; + Event::MockTimer* timer_{}; + std::chrono::milliseconds expected_duration_; +}; + +INSTANTIATE_TEST_SUITE_P(JwksAsyncFetcherTest, JwksAsyncFetcherTest, + testing::ValuesIn({false, true})); + +TEST_P(JwksAsyncFetcherTest, TestNotAsyncFetch) { + const char config[] = R"( + http_uri: + uri: https://pubkey_server/pubkey_path + cluster: pubkey_cluster +)"; + + setupAsyncFetcher(config); + // fetch is not called + EXPECT_EQ(fetch_receiver_array_.size(), 0); + // Not Jwks output + EXPECT_EQ(out_jwks_array_.size(), 0); + // init_watcher ready is not called. + init_watcher_.expectReady().Times(0); + + EXPECT_EQ(0U, stats_.jwks_fetch_success_.value()); + EXPECT_EQ(0U, stats_.jwks_fetch_failed_.value()); +} + +TEST_P(JwksAsyncFetcherTest, TestGoodFetch) { + const char config[] = R"( + http_uri: + uri: https://pubkey_server/pubkey_path + cluster: pubkey_cluster + async_fetch: {} +)"; + + setupAsyncFetcher(config); + // Jwks response is not received yet + EXPECT_EQ(out_jwks_array_.size(), 0); + + if (initManagerUsed()) { + // Verify ready is not called. + init_watcher_.expectReady().Times(0); + EXPECT_TRUE(::testing::Mock::VerifyAndClearExpectations(&init_watcher_)); + init_watcher_.expectReady(); + } + + // Trigger the Jwks response + EXPECT_EQ(fetch_receiver_array_.size(), 1); + auto jwks = google::jwt_verify::Jwks::createFrom(PublicKey, google::jwt_verify::Jwks::JWKS); + fetch_receiver_array_[0]->onJwksSuccess(std::move(jwks)); + + // Output 1 jwks. + EXPECT_EQ(out_jwks_array_.size(), 1); + + EXPECT_EQ(1U, stats_.jwks_fetch_success_.value()); + EXPECT_EQ(0U, stats_.jwks_fetch_failed_.value()); +} + +TEST_P(JwksAsyncFetcherTest, TestNetworkFailureFetch) { + const char config[] = R"( + http_uri: + uri: https://pubkey_server/pubkey_path + cluster: pubkey_cluster + async_fetch: {} +)"; + + // Just start the Jwks fetch call + setupAsyncFetcher(config); + // Jwks response is not received yet + EXPECT_EQ(out_jwks_array_.size(), 0); + + if (initManagerUsed()) { + // Verify ready is not called. + init_watcher_.expectReady().Times(0); + EXPECT_TRUE(::testing::Mock::VerifyAndClearExpectations(&init_watcher_)); + // Verify ready is called. + init_watcher_.expectReady(); + } + + // Trigger the Jwks response + EXPECT_EQ(fetch_receiver_array_.size(), 1); + fetch_receiver_array_[0]->onJwksError(Common::JwksFetcher::JwksReceiver::Failure::Network); + + // Output 0 jwks. + EXPECT_EQ(out_jwks_array_.size(), 0); + + EXPECT_EQ(0U, stats_.jwks_fetch_success_.value()); + EXPECT_EQ(1U, stats_.jwks_fetch_failed_.value()); +} + +TEST_P(JwksAsyncFetcherTest, TestGoodFetchAndRefresh) { + const char config[] = R"( + http_uri: + uri: https://pubkey_server/pubkey_path + cluster: pubkey_cluster + async_fetch: {} +)"; + + setupAsyncFetcher(config); + // Initial fetch is successful + EXPECT_EQ(fetch_receiver_array_.size(), 1); + auto jwks = google::jwt_verify::Jwks::createFrom(PublicKey, google::jwt_verify::Jwks::JWKS); + fetch_receiver_array_[0]->onJwksSuccess(std::move(jwks)); + + // Output 1 jwks. + EXPECT_EQ(out_jwks_array_.size(), 1); + + // Expect refresh timer is enabled. + EXPECT_CALL(*timer_, enableTimer(expected_duration_, nullptr)); + timer_->invokeCallback(); + + // refetch again after cache duration interval: successful. + EXPECT_EQ(fetch_receiver_array_.size(), 2); + auto jwks1 = google::jwt_verify::Jwks::createFrom(PublicKey, google::jwt_verify::Jwks::JWKS); + fetch_receiver_array_[1]->onJwksSuccess(std::move(jwks1)); + + // Output 2 jwks. + EXPECT_EQ(out_jwks_array_.size(), 2); + EXPECT_EQ(2U, stats_.jwks_fetch_success_.value()); + EXPECT_EQ(0U, stats_.jwks_fetch_failed_.value()); +} + +TEST_P(JwksAsyncFetcherTest, TestNetworkFailureFetchAndRefresh) { + const char config[] = R"( + http_uri: + uri: https://pubkey_server/pubkey_path + cluster: pubkey_cluster + async_fetch: {} +)"; + + // Just start the Jwks fetch call + setupAsyncFetcher(config); + // first fetch: network failure. + EXPECT_EQ(fetch_receiver_array_.size(), 1); + fetch_receiver_array_[0]->onJwksError(Common::JwksFetcher::JwksReceiver::Failure::Network); + + // Output 0 jwks. + EXPECT_EQ(out_jwks_array_.size(), 0); + + // Expect refresh timer is enabled. + EXPECT_CALL(*timer_, enableTimer(expected_duration_, nullptr)); + timer_->invokeCallback(); + + // refetch again after cache duration interval: network failure. + EXPECT_EQ(fetch_receiver_array_.size(), 2); + fetch_receiver_array_[1]->onJwksError(Common::JwksFetcher::JwksReceiver::Failure::Network); + + // Output 0 jwks. + EXPECT_EQ(out_jwks_array_.size(), 0); + EXPECT_EQ(0U, stats_.jwks_fetch_success_.value()); + EXPECT_EQ(2U, stats_.jwks_fetch_failed_.value()); +} + +} // namespace +} // namespace JwtAuthn +} // namespace HttpFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/http/jwt_authn/jwks_cache_test.cc b/test/extensions/filters/http/jwt_authn/jwks_cache_test.cc index a0fcdde765c15..6c22ec49a54bc 100644 --- a/test/extensions/filters/http/jwt_authn/jwks_cache_test.cc +++ b/test/extensions/filters/http/jwt_authn/jwks_cache_test.cc @@ -9,12 +9,12 @@ #include "extensions/filters/http/jwt_authn/jwks_cache.h" #include "test/extensions/filters/http/jwt_authn/test_common.h" -#include "test/mocks/thread_local/mocks.h" -#include "test/test_common/simulated_time_system.h" +#include "test/mocks/server/factory_context.h" #include "test/test_common/utility.h" using envoy::extensions::filters::http::jwt_authn::v3::JwtAuthentication; using ::google::jwt_verify::Status; +using ::testing::MockFunction; namespace Envoy { namespace Extensions { @@ -22,25 +22,32 @@ namespace HttpFilters { namespace JwtAuthn { namespace { +JwtAuthnFilterStats generateMockStats(Stats::Scope& scope) { + return {ALL_JWT_AUTHN_FILTER_STATS(POOL_COUNTER_PREFIX(scope, ""))}; +} + class JwksCacheTest : public testing::Test { protected: - JwksCacheTest() : api_(Api::createApiForTest()) {} + JwksCacheTest() : stats_(generateMockStats(context_.scope())) {} + void SetUp() override { + // fetcher is only called at async_fetch. In this test, it is never called. + EXPECT_CALL(mock_fetcher_, Call(_)).Times(0); setupCache(ExampleConfig); jwks_ = google::jwt_verify::Jwks::createFrom(PublicKey, google::jwt_verify::Jwks::JWKS); } void setupCache(const std::string& config_str) { TestUtility::loadFromYaml(config_str, config_); - cache_ = JwksCache::create(config_, time_system_, *api_, tls_); + cache_ = JwksCache::create(config_, context_, mock_fetcher_.AsStdFunction(), stats_); } - Event::SimulatedTimeSystem time_system_; JwtAuthentication config_; JwksCachePtr cache_; google::jwt_verify::JwksPtr jwks_; - Api::ApiPtr api_; - ::testing::NiceMock tls_; + MockFunction mock_fetcher_; + NiceMock context_; + JwtAuthnFilterStats stats_; }; // Test findByProvider @@ -84,7 +91,7 @@ TEST_F(JwksCacheTest, TestSetRemoteJwks) { auto& provider0 = (*config_.mutable_providers())[std::string(ProviderName)]; // Set cache_duration to 1 second to test expiration provider0.mutable_remote_jwks()->mutable_cache_duration()->set_seconds(1); - cache_ = JwksCache::create(config_, time_system_, *api_, tls_); + cache_ = JwksCache::create(config_, context_, mock_fetcher_.AsStdFunction(), stats_); auto jwks = cache_->findByIssuer("https://example.com"); EXPECT_TRUE(jwks->getJwksObj() == nullptr); @@ -94,7 +101,7 @@ TEST_F(JwksCacheTest, TestSetRemoteJwks) { EXPECT_FALSE(jwks->isExpired()); // cache duration is 1 second, sleep two seconds to expire it - time_system_.advanceTimeWait(std::chrono::seconds(2)); + context_.time_system_.advanceTimeWait(std::chrono::seconds(2)); EXPECT_TRUE(jwks->isExpired()); } @@ -103,7 +110,7 @@ TEST_F(JwksCacheTest, TestSetRemoteJwksWithDefaultCacheDuration) { auto& provider0 = (*config_.mutable_providers())[std::string(ProviderName)]; // Clear cache_duration to use default one. provider0.mutable_remote_jwks()->clear_cache_duration(); - cache_ = JwksCache::create(config_, time_system_, *api_, tls_); + cache_ = JwksCache::create(config_, context_, mock_fetcher_.AsStdFunction(), stats_); auto jwks = cache_->findByIssuer("https://example.com"); EXPECT_TRUE(jwks->getJwksObj() == nullptr); @@ -120,7 +127,7 @@ TEST_F(JwksCacheTest, TestGoodInlineJwks) { auto local_jwks = provider0.mutable_local_jwks(); local_jwks->set_inline_string(PublicKey); - cache_ = JwksCache::create(config_, time_system_, *api_, tls_); + cache_ = JwksCache::create(config_, context_, mock_fetcher_.AsStdFunction(), stats_); auto jwks = cache_->findByIssuer("https://example.com"); EXPECT_FALSE(jwks->getJwksObj() == nullptr); @@ -134,7 +141,7 @@ TEST_F(JwksCacheTest, TestBadInlineJwks) { auto local_jwks = provider0.mutable_local_jwks(); local_jwks->set_inline_string("BAD-JWKS"); - cache_ = JwksCache::create(config_, time_system_, *api_, tls_); + cache_ = JwksCache::create(config_, context_, mock_fetcher_.AsStdFunction(), stats_); auto jwks = cache_->findByIssuer("https://example.com"); EXPECT_TRUE(jwks->getJwksObj() == nullptr);