diff --git a/api/envoy/config/grpc_credential/v2alpha/BUILD b/api/envoy/config/grpc_credential/v2alpha/BUILD index ca0a71eaef6cc..29e2a740c76b5 100644 --- a/api/envoy/config/grpc_credential/v2alpha/BUILD +++ b/api/envoy/config/grpc_credential/v2alpha/BUILD @@ -2,6 +2,16 @@ licenses(["notice"]) # Apache 2 load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") +api_proto_library_internal( + name = "aws_iam", + srcs = ["aws_iam.proto"], +) + +api_go_proto_library( + name = "aws_iam", + proto = ":aws_iam", +) + api_proto_library_internal( name = "file_based_metadata", srcs = ["file_based_metadata.proto"], diff --git a/api/envoy/config/grpc_credential/v2alpha/aws_iam.proto b/api/envoy/config/grpc_credential/v2alpha/aws_iam.proto new file mode 100644 index 0000000000000..5531b658be2d7 --- /dev/null +++ b/api/envoy/config/grpc_credential/v2alpha/aws_iam.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +// [#protodoc-title: Grpc Credentials AWS IAM] +// Configuration for AWS IAM Grpc Credentials Plugin + +package envoy.config.grpc_credential.v2alpha; +option java_package = "io.envoyproxy.envoy.config.grpc_credential.v2alpha"; +option java_multiple_files = true; +option go_package = "v2alpha"; + +import "validate/validate.proto"; + +message AwsIamConfig { + // The `service namespace + // `_ + // of the Grpc endpoint. + // + // Example: appmesh + string service_name = 1 [(validate.rules).string.min_bytes = 1]; + + // The `region `_ hosting the Grpc + // endpoint. + // + // Example: us-west-2 + string region = 2; +} diff --git a/docs/root/intro/version_history.rst b/docs/root/intro/version_history.rst index 17a47a8c772f3..52b09b2bb9db7 100644 --- a/docs/root/intro/version_history.rst +++ b/docs/root/intro/version_history.rst @@ -8,6 +8,7 @@ Version history * config: removed deprecated_v1 sds_config from :ref:`Bootstrap config `. * config: removed REST_LEGACY as a valid :ref:`ApiType `. * cors: added :ref:`filter_enabled & shadow_enabled RuntimeFractionalPercent flags ` to filter. +* grpc: added AWS IAM grpc credentials extension for AWS-managed xDS. * http: added new grpc_http1_reverse_bridge filter for converting gRPC requests into HTTP/1.1 requests. * tls: enabled TLS 1.3 on the server-side (non-FIPS builds). diff --git a/include/envoy/grpc/BUILD b/include/envoy/grpc/BUILD index 087a5f8c61d2c..521179147796d 100644 --- a/include/envoy/grpc/BUILD +++ b/include/envoy/grpc/BUILD @@ -37,6 +37,7 @@ envoy_cc_library( "grpc", ], deps = [ + "//include/envoy/api:api_interface", "@envoy_api//envoy/api/v2/core:grpc_service_cc", ], ) diff --git a/include/envoy/grpc/google_grpc_creds.h b/include/envoy/grpc/google_grpc_creds.h index 9228fc6990532..a994250304405 100644 --- a/include/envoy/grpc/google_grpc_creds.h +++ b/include/envoy/grpc/google_grpc_creds.h @@ -2,6 +2,7 @@ #include +#include "envoy/api/api.h" #include "envoy/api/v2/core/grpc_service.pb.h" #include "envoy/common/pure.h" @@ -10,6 +11,15 @@ namespace Envoy { namespace Grpc { +class GoogleGrpcCredentialsFactoryContext { +public: + virtual ~GoogleGrpcCredentialsFactoryContext() = default; + + virtual Api::Api& api() PURE; + + virtual Event::TimeSystem& timeSystem() PURE; +}; + /** * Interface for all Google gRPC credentials factories. */ @@ -25,11 +35,13 @@ class GoogleGrpcCredentialsFactory { * CompositeCallCredentials to combine multiple credentials. * * @param grpc_service_config contains configuration options + * @param context provides the factory's context * @return std::shared_ptr to be used to authenticate a Google gRPC * channel. */ virtual std::shared_ptr - getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config) PURE; + getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config, + GoogleGrpcCredentialsFactoryContext& context) PURE; /** * @return std::string the identifying name for a particular implementation of diff --git a/source/common/aws/BUILD b/source/common/aws/BUILD new file mode 100644 index 0000000000000..29ddb3e577685 --- /dev/null +++ b/source/common/aws/BUILD @@ -0,0 +1,99 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "signer_lib", + hdrs = ["signer.h"], + deps = [ + "//include/envoy/http:message_interface", + ], +) + +envoy_cc_library( + name = "signer_impl_lib", + srcs = ["signer_impl.cc"], + hdrs = ["signer_impl.h"], + external_deps = ["ssl"], + deps = [ + ":credentials_provider_lib", + ":region_provider_lib", + ":signer_lib", + "//source/common/buffer:buffer_lib", + "//source/common/common:assert_lib", + "//source/common/common:hex_lib", + "//source/common/common:logger_lib", + "//source/common/common:stack_array", + "//source/common/common:utility_lib", + "//source/common/http:headers_lib", + ], +) + +envoy_cc_library( + name = "credentials_provider_lib", + hdrs = ["credentials_provider.h"], + external_deps = ["abseil_optional"], +) + +envoy_cc_library( + name = "credentials_provider_impl_lib", + srcs = [ + "credentials_provider_impl.cc", + ], + hdrs = [ + "credentials_provider_impl.h", + ], + deps = [ + ":credentials_provider_lib", + ":metadata_fetcher_impl_lib", + "//source/common/common:lock_guard_lib", + "//source/common/common:logger_lib", + "//source/common/common:utility_lib", + "//source/common/json:json_loader_lib", + ], +) + +envoy_cc_library( + name = "metadata_fetcher_lib", + hdrs = ["metadata_fetcher.h"], + external_deps = ["abseil_optional"], + deps = ["//include/envoy/event:dispatcher_interface"], +) + +envoy_cc_library( + name = "metadata_fetcher_impl_lib", + srcs = ["metadata_fetcher_impl.cc"], + hdrs = ["metadata_fetcher_impl.h"], + external_deps = ["grpc"], + deps = [ + ":metadata_fetcher_lib", + "//include/envoy/event:dispatcher_interface", + "//include/envoy/http:header_map_interface", + "//include/envoy/network:transport_socket_interface", + "//source/common/common:logger_lib", + "//source/common/http/http1:codec_lib", + "//source/common/network:filter_lib", + "//source/common/network:raw_buffer_socket_lib", + ], +) + +envoy_cc_library( + name = "region_provider_lib", + hdrs = ["region_provider.h"], +) + +envoy_cc_library( + name = "region_provider_impl_lib", + srcs = ["region_provider_impl.cc"], + hdrs = ["region_provider_impl.h"], + deps = [ + ":region_provider_lib", + "//source/common/common:logger_lib", + ], +) diff --git a/source/common/aws/credentials_provider.h b/source/common/aws/credentials_provider.h new file mode 100644 index 0000000000000..d65aeda49fe64 --- /dev/null +++ b/source/common/aws/credentials_provider.h @@ -0,0 +1,61 @@ +#pragma once + +#include + +#include "envoy/common/pure.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class Credentials { +public: + Credentials() = default; + ~Credentials() = default; + + Credentials(const std::string& access_key_id, const std::string& secret_access_key) + : access_key_id_(access_key_id), secret_access_key_(secret_access_key) {} + + Credentials(const std::string& access_key_id, const std::string& secret_access_key, + const std::string& session_token) + : access_key_id_(access_key_id), secret_access_key_(secret_access_key), + session_token_(session_token) {} + + void setAccessKeyId(const std::string& access_key_id) { + access_key_id_ = absl::optional(access_key_id); + } + + const absl::optional& accessKeyId() const { return access_key_id_; } + + void setSecretAccessKey(const std::string& secret_key) { + secret_access_key_ = absl::optional(secret_key); + } + + const absl::optional& secretAccessKey() const { return secret_access_key_; } + + void setSessionToken(const std::string& session_token) { + session_token_ = absl::optional(session_token); + } + + const absl::optional& sessionToken() const { return session_token_; } + +private: + absl::optional access_key_id_; + absl::optional secret_access_key_; + absl::optional session_token_; +}; + +class CredentialsProvider { +public: + virtual ~CredentialsProvider() = default; + + virtual Credentials getCredentials() PURE; +}; + +typedef std::shared_ptr CredentialsProviderSharedPtr; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/credentials_provider_impl.cc b/source/common/aws/credentials_provider_impl.cc new file mode 100644 index 0000000000000..5685a4e82f64f --- /dev/null +++ b/source/common/aws/credentials_provider_impl.cc @@ -0,0 +1,241 @@ +#include "common/aws/credentials_provider_impl.h" + +#include + +#include "envoy/common/exception.h" + +#include "common/aws/metadata_fetcher_impl.h" +#include "common/common/lock_guard.h" +#include "common/common/utility.h" +#include "common/http/utility.h" +#include "common/json/json_loader.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +static const char AWS_ACCESS_KEY_ID[] = "AWS_ACCESS_KEY_ID"; +static const char AWS_SECRET_ACCESS_KEY[] = "AWS_SECRET_ACCESS_KEY"; +static const char AWS_SESSION_TOKEN[] = "AWS_SESSION_TOKEN"; + +static const char ACCESS_KEY_ID[] = "AccessKeyId"; +static const char SECRET_ACCESS_KEY[] = "SecretAccessKey"; +static const char TOKEN[] = "Token"; +static const char EXPIRATION[] = "Expiration"; +static const char EXPIRATION_FORMAT[] = "%Y%m%dT%H%M%S%z"; +static const char TRUE[] = "true"; + +static const char AWS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"; +static const char AWS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI"; +static const char AWS_CONTAINER_AUTHORIZATION_TOKEN[] = "AWS_CONTAINER_AUTHORIZATION_TOKEN"; +static const char AWS_EC2_METADATA_DISABLED[] = "AWS_EC2_METADATA_DISABLED"; + +static const std::chrono::hours REFRESH_INTERVAL{1}; +static const std::chrono::seconds REFRESH_GRACE_PERIOD{5}; +static const char EC2_METADATA_HOST[] = "169.254.169.254:80"; +static const char CONTAINER_METADATA_HOST[] = "169.254.170.2:80"; +static const char SECURITY_CREDENTIALS_PATH[] = "/latest/meta-data/iam/security-credentials"; + +Credentials EnvironmentCredentialsProvider::getCredentials() { + const auto access_key_id = std::getenv(AWS_ACCESS_KEY_ID); + Credentials credentials; + if (access_key_id == nullptr) { + return credentials; + } + ENVOY_LOG(debug, "Found environment credential {}={}", AWS_ACCESS_KEY_ID, access_key_id); + credentials.setAccessKeyId(access_key_id); + const auto secret_access_key = std::getenv(AWS_SECRET_ACCESS_KEY); + if (secret_access_key != nullptr) { + ENVOY_LOG(debug, "Found environment credential {}=*****", AWS_SECRET_ACCESS_KEY); + credentials.setSecretAccessKey(secret_access_key); + } + const auto session_token = std::getenv(AWS_SESSION_TOKEN); + if (session_token != nullptr) { + ENVOY_LOG(debug, "Found environment credential {}=*****", AWS_SESSION_TOKEN); + credentials.setSessionToken(session_token); + } + return credentials; +} + +void MetadataCredentialsProviderBase::refreshIfNeeded() { + Thread::LockGuard lock(lock_); + if (!needsRefresh()) { + return; + } + refresh(); +} + +bool InstanceProfileCredentialsProvider::needsRefresh() { + if (time_system_.systemTime() - last_updated_ > REFRESH_INTERVAL) { + return true; + } + return false; +} + +void InstanceProfileCredentialsProvider::refresh() { + auto dispatcher = api_.allocateDispatcher(time_system_); + ENVOY_LOG(debug, "Getting default credentials for ec2 instance"); + // Get the list of credential names + const auto credential_listing = + fetcher_->getMetadata(*dispatcher, EC2_METADATA_HOST, SECURITY_CREDENTIALS_PATH); + if (!credential_listing) { + ENVOY_LOG(error, "Could not retrieve credentials listing"); + return; + } + const auto credential_names = + StringUtil::splitToken(StringUtil::trim(credential_listing.value()), "\n"); + if (credential_names.empty()) { + ENVOY_LOG(error, "No credentials were found"); + return; + } + ENVOY_LOG(debug, "Credentials found:\n{}", credential_listing.value()); + const auto credential_path = std::string(SECURITY_CREDENTIALS_PATH) + "/" + + std::string(credential_names[0].data(), credential_names[0].size()); + ENVOY_LOG(debug, "Loading credentials document from {}", credential_path); + const auto credential_document = + fetcher_->getMetadata(*dispatcher, EC2_METADATA_HOST, credential_path); + if (!credential_document) { + ENVOY_LOG(error, "Could not load credentials document"); + return; + } + Json::ObjectSharedPtr document_json; + try { + document_json = Json::Factory::loadFromString(credential_document.value()); + } catch (EnvoyException& e) { + ENVOY_LOG(error, "Could not parse credentials document: {}", e.what()); + return; + } + Credentials credentials; + const auto access_key_id = document_json->getString(ACCESS_KEY_ID, ""); + if (!access_key_id.empty()) { + ENVOY_LOG(debug, "Found instance credential {}={}", ACCESS_KEY_ID, access_key_id); + credentials.setAccessKeyId(access_key_id); + } + const auto secret_access_key = document_json->getString(SECRET_ACCESS_KEY, ""); + if (!secret_access_key.empty()) { + ENVOY_LOG(debug, "Found instance credential {}=*****", SECRET_ACCESS_KEY); + credentials.setSecretAccessKey(secret_access_key); + } + const auto token = document_json->getString(TOKEN, ""); + if (!token.empty()) { + ENVOY_LOG(debug, "Found instance credential {}=*****", TOKEN); + credentials.setSessionToken(token); + } + cached_credentials_ = credentials; + last_updated_ = time_system_.systemTime(); +} + +bool TaskRoleCredentialsProvider::needsRefresh() { + if (time_system_.systemTime() - last_updated_ > REFRESH_INTERVAL) { + return true; + } + if (expiration_time_ - time_system_.systemTime() < REFRESH_GRACE_PERIOD) { + return true; + } + return false; +} + +void TaskRoleCredentialsProvider::refresh() { + auto dispatcher = api_.allocateDispatcher(time_system_); + ENVOY_LOG(debug, "Getting ecs credentials"); + ENVOY_LOG(debug, "Loading credentials document from {}", credential_uri_); + absl::string_view host_view; + absl::string_view path_view; + Http::Utility::extractHostPathFromUri(credential_uri_, host_view, path_view); + const auto credential_document = + fetcher_->getMetadata(*dispatcher, std::string(host_view.data(), host_view.size()), + std::string(path_view.data(), path_view.size()), authorization_token_); + if (!credential_document) { + ENVOY_LOG(error, "Could not load credentials document"); + return; + } + Json::ObjectSharedPtr document_json; + try { + document_json = Json::Factory::loadFromString(credential_document.value()); + } catch (EnvoyException& e) { + ENVOY_LOG(error, "Could not parse credentials document: {}", e.what()); + return; + } + Credentials credentials; + const auto access_key_id = document_json->getString(ACCESS_KEY_ID, ""); + if (!access_key_id.empty()) { + ENVOY_LOG(debug, "Found task role credential {}={}", ACCESS_KEY_ID, access_key_id); + credentials.setAccessKeyId(access_key_id); + } + const auto secret_access_key = document_json->getString(SECRET_ACCESS_KEY, ""); + if (!secret_access_key.empty()) { + ENVOY_LOG(debug, "Found task role credential {}=*****", SECRET_ACCESS_KEY); + credentials.setSecretAccessKey(secret_access_key); + } + const auto token = document_json->getString(TOKEN, ""); + if (!token.empty()) { + ENVOY_LOG(debug, "Found task role credential {}=*****", TOKEN); + credentials.setSessionToken(token); + } + const auto expiration = document_json->getString(EXPIRATION, ""); + if (!expiration.empty()) { + std::tm timestamp{}; + if (strptime(expiration.c_str(), EXPIRATION_FORMAT, ×tamp) == + (expiration.c_str() + expiration.size())) { + ENVOY_LOG(debug, "Found task role credential {}={}", EXPIRATION, expiration); + expiration_time_ = SystemTime::clock::from_time_t(std::mktime(×tamp)); + } + } + cached_credentials_ = credentials; + last_updated_ = time_system_.systemTime(); +} + +Credentials CredentialsProviderChain::getCredentials() { + for (auto& provider : providers_) { + const auto credentials = provider->getCredentials(); + if (credentials.accessKeyId() && credentials.secretAccessKey()) { + return credentials; + } + } + ENVOY_LOG(debug, "No credentials found. Using anonymous credentials"); + return Credentials(); +} + +DefaultCredentialsProviderChain::DefaultCredentialsProviderChain( + Api::Api& api, Event::TimeSystem& time_system, + const CredentialsProviderChainFactories& factories) { + ENVOY_LOG(debug, "Using environment credentials provider"); + add(factories.createEnvironmentCredentialsProvider()); + const auto relative_uri = std::getenv(AWS_CONTAINER_CREDENTIALS_RELATIVE_URI); + const auto full_uri = std::getenv(AWS_CONTAINER_CREDENTIALS_FULL_URI); + const auto metadata_disabled = std::getenv(AWS_EC2_METADATA_DISABLED); + if (relative_uri != nullptr) { + const auto uri = std::string(CONTAINER_METADATA_HOST) + relative_uri; + ENVOY_LOG(debug, "Using task role credentials provider with URI: {}", uri); + add(factories.createTaskRoleCredentialsProvider( + api, time_system, factories.createMetadataFetcher(), uri, absl::optional())); + } else if (full_uri != nullptr) { + const auto authorization_token = std::getenv(AWS_CONTAINER_AUTHORIZATION_TOKEN); + if (authorization_token != nullptr) { + ENVOY_LOG(debug, "Using task role credentials provider with URI: {} and authorization token", + full_uri); + add(factories.createTaskRoleCredentialsProvider( + api, time_system, factories.createMetadataFetcher(), full_uri, authorization_token)); + } else { + ENVOY_LOG(debug, "Using task role credentials provider with URI: {}", full_uri); + add(factories.createTaskRoleCredentialsProvider(api, time_system, + factories.createMetadataFetcher(), full_uri, + absl::optional())); + } + } else if (metadata_disabled == nullptr || strncmp(metadata_disabled, TRUE, strlen(TRUE)) != 0) { + ENVOY_LOG(debug, "Using instance profile credentials provider"); + add(factories.createInstanceProfileCredentialsProvider(api, time_system, + factories.createMetadataFetcher())); + } +} + +MetadataFetcherPtr createMetadataFetcher(Api::Api& api, Event::TimeSystem& time_system); + +MetadataFetcherPtr DefaultCredentialsProviderChain::createMetadataFetcher() const { + return std::make_unique(); +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/credentials_provider_impl.h b/source/common/aws/credentials_provider_impl.h new file mode 100644 index 0000000000000..74dc01d663b6a --- /dev/null +++ b/source/common/aws/credentials_provider_impl.h @@ -0,0 +1,145 @@ +#pragma once + +#include + +#include "envoy/api/api.h" +#include "envoy/event/timer.h" + +#include "common/aws/credentials_provider.h" +#include "common/aws/metadata_fetcher.h" +#include "common/common/logger.h" +#include "common/common/thread.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class EnvironmentCredentialsProvider : public CredentialsProvider, + public Logger::Loggable { +public: + Credentials getCredentials() override; +}; + +class MetadataCredentialsProviderBase : public CredentialsProvider, + public Logger::Loggable { +public: + MetadataCredentialsProviderBase(Api::Api& api, Event::TimeSystem& time_system, + MetadataFetcherPtr&& fetcher) + : api_(api), time_system_(time_system), fetcher_(std::move(fetcher)) {} + + Credentials getCredentials() override { + refreshIfNeeded(); + return cached_credentials_; + } + +protected: + Api::Api& api_; + Event::TimeSystem& time_system_; + MetadataFetcherPtr fetcher_; + SystemTime last_updated_; + Credentials cached_credentials_; + Thread::MutexBasicLockable lock_; + + void refreshIfNeeded(); + + virtual bool needsRefresh() PURE; + virtual void refresh() PURE; +}; + +class InstanceProfileCredentialsProvider : public MetadataCredentialsProviderBase { +public: + InstanceProfileCredentialsProvider(Api::Api& api, Event::TimeSystem& time_system, + MetadataFetcherPtr&& fetcher) + : MetadataCredentialsProviderBase(api, time_system, std::move(fetcher)) {} + +private: + bool needsRefresh() override; + void refresh() override; +}; + +class TaskRoleCredentialsProvider : public MetadataCredentialsProviderBase { +public: + TaskRoleCredentialsProvider( + Api::Api& api, Event::TimeSystem& time_system, MetadataFetcherPtr&& fetcher, + const std::string& credential_uri, + const absl::optional& authorization_token = absl::optional()) + : MetadataCredentialsProviderBase(api, time_system, std::move(fetcher)), + credential_uri_(credential_uri), authorization_token_(authorization_token) {} + +private: + SystemTime expiration_time_; + std::string credential_uri_; + absl::optional authorization_token_; + + bool needsRefresh() override; + void refresh() override; +}; + +class CredentialsProviderChain : public CredentialsProvider, + public Logger::Loggable { +public: + virtual ~CredentialsProviderChain() = default; + + void add(const CredentialsProviderSharedPtr& credentials_provider) { + providers_.emplace_back(credentials_provider); + } + + Credentials getCredentials() override; + +protected: + std::list providers_; +}; + +class CredentialsProviderChainFactories { +public: + virtual ~CredentialsProviderChainFactories() = default; + + virtual MetadataFetcherPtr createMetadataFetcher() const PURE; + + virtual CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const PURE; + + virtual CredentialsProviderSharedPtr createTaskRoleCredentialsProvider( + Api::Api& api, Event::TimeSystem& time_system, MetadataFetcherPtr&& fetcher, + const std::string& credential_uri, + const absl::optional& authorization_token) const PURE; + + virtual CredentialsProviderSharedPtr + createInstanceProfileCredentialsProvider(Api::Api& api, Event::TimeSystem& time_system, + MetadataFetcherPtr&& fetcher) const PURE; +}; + +class DefaultCredentialsProviderChain : public CredentialsProviderChain, + public CredentialsProviderChainFactories { +public: + DefaultCredentialsProviderChain(Api::Api& api, Event::TimeSystem& time_system) + : DefaultCredentialsProviderChain(api, time_system, *this) {} + + DefaultCredentialsProviderChain(Api::Api& api, Event::TimeSystem& time_system, + const CredentialsProviderChainFactories& factories); + +private: + virtual MetadataFetcherPtr createMetadataFetcher() const override; + + CredentialsProviderSharedPtr createEnvironmentCredentialsProvider() const override { + return std::make_shared(); + } + + CredentialsProviderSharedPtr createTaskRoleCredentialsProvider( + Api::Api& api, Event::TimeSystem& time_system, MetadataFetcherPtr&& fetcher, + const std::string& credential_uri, + const absl::optional& authorization_token) const override { + return std::make_shared(api, time_system, std::move(fetcher), + credential_uri, authorization_token); + } + + CredentialsProviderSharedPtr + createInstanceProfileCredentialsProvider(Api::Api& api, Event::TimeSystem& time_system, + MetadataFetcherPtr&& fetcher) const override { + return std::make_shared(api, time_system, + std::move(fetcher)); + } +}; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/metadata_fetcher.h b/source/common/aws/metadata_fetcher.h new file mode 100644 index 0000000000000..2bacd62b29aec --- /dev/null +++ b/source/common/aws/metadata_fetcher.h @@ -0,0 +1,33 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/event/dispatcher.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class MetadataFetcher { +public: + virtual ~MetadataFetcher() = default; + + /** + * Fetch instance metadata from a known host. + * @param dispatcher the dispatcher to execute the http request on. + * @param host the instance metadata host. Example: 169.254.169.254:80 + * @param path the instance metadata path. Example: /latest/meta-data/iam/info + * @param auth_token an optional authorization token for requesting the metdata. + * @return an optional string containing the instance metadata if it can be found. + */ + virtual absl::optional getMetadata( + Event::Dispatcher& dispatcher, const std::string& host, const std::string& path, + const absl::optional& auth_token = absl::optional()) const PURE; +}; + +typedef std::unique_ptr MetadataFetcherPtr; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/metadata_fetcher_impl.cc b/source/common/aws/metadata_fetcher_impl.cc new file mode 100644 index 0000000000000..b21fa9b1449ff --- /dev/null +++ b/source/common/aws/metadata_fetcher_impl.cc @@ -0,0 +1,132 @@ +#include "common/aws/metadata_fetcher_impl.h" + +#include "envoy/common/exception.h" +#include "envoy/network/transport_socket.h" + +#include "common/http/headers.h" +#include "common/http/http1/codec_impl.h" +#include "common/network/raw_buffer_socket.h" +#include "common/network/utility.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +const size_t MetadataFetcherImpl::MAX_RETRIES = 4; +const std::chrono::milliseconds MetadataFetcherImpl::NO_DELAY{0}; +const std::chrono::milliseconds MetadataFetcherImpl::RETRY_DELAY{1000}; +const std::chrono::milliseconds MetadataFetcherImpl::TIMEOUT{5000}; + +absl::optional +MetadataFetcherImpl::getMetadata(Event::Dispatcher& dispatcher, const std::string& host, + const std::string& path, + const absl::optional& auth_token) const { + // default to port 80 + auto host_with_port = host; + if (host.find(':') == std::string::npos) { + host_with_port = host + ":80"; + } + Http::HeaderMapImpl headers; + headers.insertMethod().value().setReference(Http::Headers::get().MethodValues.Get); + headers.insertHost().value().setReference(host_with_port); + headers.insertPath().value().setReference(path); + if (auth_token) { + headers.insertAuthorization().value().setReference(auth_token.value()); + } + auto delay = NO_DELAY; + for (size_t retries = 0; retries < MAX_RETRIES; retries++) { + ENVOY_LOG(debug, "Connecting to http://{}{} to retrieve metadata. Try {}/{}", host_with_port, + path, retries + 1, MAX_RETRIES); + auto decoder = std::unique_ptr(decoder_factory_()); + std::unique_ptr session; + const auto timer = dispatcher.createTimer([&]() { + auto socket = std::make_unique(); + auto options = std::make_shared(); + const auto address = + Network::Utility::resolveUrl(Network::Utility::TCP_SCHEME + host_with_port); + session = std::unique_ptr(session_factory_( + dispatcher.createClientConnection(address, nullptr, std::move(socket), options), + dispatcher, *decoder, *decoder, headers, codec_factory_)); + decoder->setCompleteCallback([&session, &dispatcher]() { + session->close(); + dispatcher.exit(); + }); + }); + timer->enableTimer(delay); + dispatcher.run(Event::Dispatcher::RunType::Block); + const auto& body = decoder->body(); + if (!body.empty()) { + ENVOY_LOG(debug, "Found metadata at {}{}", host_with_port, path); + return absl::optional(body); + } + delay = RETRY_DELAY; + } + ENVOY_LOG(error, "Could not find metadata at {}{}", host_with_port, path); + return absl::optional(); +} + +Http::ClientConnection* MetadataFetcherImpl::createCodec(Network::Connection& connection, + Http::ConnectionCallbacks& callbacks) { + return new Http::Http1::ClientConnectionImpl(connection, callbacks); +} + +MetadataFetcherImpl::MetadataSession::MetadataSession(Network::ClientConnectionPtr&& connection, + Event::Dispatcher& dispatcher, + Http::StreamDecoder& decoder, + Http::StreamCallbacks& callbacks, + const Http::HeaderMap& headers, + HttpCodecFactory codec_factory) + : connection_(std::move(connection)) { + codec_ = Http::ClientConnectionPtr{codec_factory(*connection_, *this)}; + connection_->addReadFilter(Network::ReadFilterSharedPtr{new CodecReadFilter(*this)}); + connection_->addConnectionCallbacks(*this); + connection_->connect(); + connection_->noDelay(true); + encoder_ = &codec_->newStream(decoder); + encoder_->getStream().addCallbacks(callbacks); + encoder_->encodeHeaders(headers, true); + timeout_timer_ = dispatcher.createTimer([this]() { close(); }); + timeout_timer_->enableTimer(TIMEOUT); +} + +void MetadataFetcherImpl::MetadataSession::onEvent(Network::ConnectionEvent event) { + if (event == Network::ConnectionEvent::Connected) { + connected_ = true; + } + if (event == Network::ConnectionEvent::RemoteClose) { + Buffer::OwnedImpl empty; + onData(empty); + } + if (event == Network::ConnectionEvent::RemoteClose || + event == Network::ConnectionEvent::LocalClose) { + encoder_->getStream().resetStream(connected_ ? Http::StreamResetReason::ConnectionTermination + : Http::StreamResetReason::ConnectionFailure); + } +} + +void MetadataFetcherImpl::MetadataSession::onData(Buffer::Instance& data) { + try { + codec_->dispatch(data); + } catch (EnvoyException& e) { + close(); + } +} + +void MetadataFetcherImpl::StringBufferDecoder::decodeHeaders(Envoy::Http::HeaderMapPtr&&, + bool end_stream) { + if (end_stream) { + complete(); + } +} + +void MetadataFetcherImpl::StringBufferDecoder::decodeData(Envoy::Buffer::Instance& data, + bool end_stream) { + body_.append(data.toString()); + if (end_stream) { + complete(); + } +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/metadata_fetcher_impl.h b/source/common/aws/metadata_fetcher_impl.h new file mode 100644 index 0000000000000..4d4419b327ba8 --- /dev/null +++ b/source/common/aws/metadata_fetcher_impl.h @@ -0,0 +1,148 @@ +#pragma once + +#include "envoy/api/api.h" +#include "envoy/http/codec.h" +#include "envoy/http/header_map.h" +#include "envoy/network/connection.h" + +#include "common/aws/metadata_fetcher.h" +#include "common/common/logger.h" +#include "common/network/filter_impl.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class MetadataFetcherImpl : public MetadataFetcher, Logger::Loggable { +public: + // Factory abstraction used for creating test http codecs + typedef std::function + HttpCodecFactory; + + class MetadataSession; + typedef std::function + MetadataSessionFactory; + + class StringBufferDecoder; + typedef std::function StringBufferDecoderFactory; + + MetadataFetcherImpl(MetadataSessionFactory session_factory = createSession, + StringBufferDecoderFactory decoder_factory = createDecoder, + HttpCodecFactory codec_factory = createCodec) + : session_factory_(session_factory), decoder_factory_(decoder_factory), + codec_factory_(codec_factory) {} + + absl::optional getMetadata( + Event::Dispatcher& dispatcher, const std::string& host, const std::string& path, + const absl::optional& auth_token = absl::optional()) const override; + + class MetadataSession : public Network::ConnectionCallbacks, public Http::ConnectionCallbacks { + public: + MetadataSession(Network::ClientConnectionPtr&& connection, Event::Dispatcher& dispatcher, + Http::StreamDecoder& decoder, Http::StreamCallbacks& callbacks, + const Http::HeaderMap& headers, HttpCodecFactory codec_factory); + + virtual ~MetadataSession() = default; + + // Network::ConnectionCallbacks + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + // Http::ConnectionCallbacks + void onGoAway() override {} + + void onData(Buffer::Instance& data); + + virtual void close() { connection_->close(Network::ConnectionCloseType::NoFlush); } + + protected: + MetadataSession(){}; + + private: + Network::ClientConnectionPtr connection_{}; + Http::ClientConnectionPtr codec_{}; + Event::TimerPtr timeout_timer_{}; + Http::StreamEncoder* encoder_{}; + bool connected_{}; + }; + + class StringBufferDecoder : public Http::StreamDecoder, public Http::StreamCallbacks { + public: + virtual ~StringBufferDecoder() = default; + + // Http::StreamDecoder + void decode100ContinueHeaders(Http::HeaderMapPtr&&) override {} + void decodeHeaders(Http::HeaderMapPtr&&, bool end_stream) override; + void decodeData(Buffer::Instance& data, bool end_stream) override; + void decodeTrailers(Http::HeaderMapPtr&&) override {} + void decodeMetadata(Http::MetadataMapPtr&&) override {} + + // Http::StreamCallbacks + void onResetStream(Http::StreamResetReason) override { complete(); } + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + virtual const std::string& body() const { return body_; } + + void setCompleteCallback(std::function complete_cb) { complete_cb_ = complete_cb; } + + private: + friend class MetadataFetcherImplTest; + + void complete() { + if (complete_cb_) { + complete_cb_(); + } + } + + std::string body_; + std::function complete_cb_{}; + }; + + class CodecReadFilter : public Network::ReadFilterBaseImpl { + public: + CodecReadFilter(MetadataSession& parent) : parent_(parent) {} + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool) override { + parent_.onData(data); + return Network::FilterStatus::StopIteration; + } + + private: + MetadataSession& parent_; + }; + +private: + friend class MetadataFetcherImplTest; + + static Http::ClientConnection* createCodec(Network::Connection& connection, + Http::ConnectionCallbacks& callbacks); + + static MetadataSession* createSession(Network::ClientConnectionPtr&& connection, + Event::Dispatcher& dispatcher, Http::StreamDecoder& decoder, + Http::StreamCallbacks& callbacks, + const Http::HeaderMap& headers, + HttpCodecFactory codec_factory) { + return new MetadataSession(std::move(connection), dispatcher, decoder, callbacks, headers, + codec_factory); + } + + static StringBufferDecoder* createDecoder() { return new StringBufferDecoder(); } + + static const size_t MAX_RETRIES; + static const std::chrono::milliseconds NO_DELAY; + static const std::chrono::milliseconds RETRY_DELAY; + static const std::chrono::milliseconds TIMEOUT; + + MetadataSessionFactory session_factory_; + StringBufferDecoderFactory decoder_factory_; + HttpCodecFactory codec_factory_; +}; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/region_provider.h b/source/common/aws/region_provider.h new file mode 100644 index 0000000000000..50ac08069381d --- /dev/null +++ b/source/common/aws/region_provider.h @@ -0,0 +1,32 @@ +#pragma once + +#include "envoy/common/pure.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class RegionProvider { +public: + virtual ~RegionProvider() = default; + + virtual absl::optional getRegion() PURE; +}; + +class StaticRegionProvider : public RegionProvider { +public: + StaticRegionProvider(const std::string& region) : region_(region) {} + + absl::optional getRegion() override { return absl::optional(region_); } + +private: + const std::string region_; +}; + +typedef std::shared_ptr RegionProviderSharedPtr; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/region_provider_impl.cc b/source/common/aws/region_provider_impl.cc new file mode 100644 index 0000000000000..791c96bbf0ff1 --- /dev/null +++ b/source/common/aws/region_provider_impl.cc @@ -0,0 +1,20 @@ +#include "common/aws/region_provider_impl.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +static const char AWS_REGION[] = "AWS_REGION"; + +absl::optional EnvironmentRegionProvider::getRegion() { + const auto region = std::getenv(AWS_REGION); + if (region == nullptr) { + return absl::optional(); + } + ENVOY_LOG(debug, "Found environment region {}={}", AWS_REGION, region); + return absl::optional(region); +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/region_provider_impl.h b/source/common/aws/region_provider_impl.h new file mode 100644 index 0000000000000..f3803b889fd92 --- /dev/null +++ b/source/common/aws/region_provider_impl.h @@ -0,0 +1,17 @@ +#pragma once + +#include "common/aws/region_provider.h" +#include "common/common/logger.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class EnvironmentRegionProvider : public RegionProvider, public Logger::Loggable { +public: + absl::optional getRegion() override; +}; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/signer.h b/source/common/aws/signer.h new file mode 100644 index 0000000000000..5504f205bdcce --- /dev/null +++ b/source/common/aws/signer.h @@ -0,0 +1,25 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/http/message.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class Signer { +public: + virtual ~Signer() = default; + + /** + * Sign an AWS request. + * @param message an + */ + virtual void sign(Http::Message& message) const PURE; +}; + +typedef std::shared_ptr SignerSharedPtr; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/source/common/aws/signer_impl.cc b/source/common/aws/signer_impl.cc new file mode 100644 index 0000000000000..da9be5fea4111 --- /dev/null +++ b/source/common/aws/signer_impl.cc @@ -0,0 +1,252 @@ +#include "common/aws/signer_impl.h" + +#include "envoy/common/exception.h" + +#include "common/buffer/buffer_impl.h" +#include "common/common/assert.h" +#include "common/common/hex.h" +#include "common/common/stack_array.h" +#include "common/http/headers.h" + +#include "openssl/evp.h" +#include "openssl/hmac.h" +#include "openssl/sha.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +static const std::string AWS4{"AWS4"}; +static const std::string AWS4_HMAC_SHA256{"AWS4-HMAC-SHA256"}; +static const std::string AWS4_REQUEST{"aws4_request"}; +static const std::string CREDENTIAL{"Credential"}; +static const std::string SIGNED_HEADERS{"SignedHeaders"}; +static const std::string SIGNATURE{"Signature"}; +static const std::string HASHED_EMPTY_STRING{ + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"}; + +DateFormatter SignerImpl::LONG_DATE_FORMATTER("%Y%m%dT%H%M00Z"); +DateFormatter SignerImpl::SHORT_DATE_FORMATTER("%Y%m%d"); +const Http::LowerCaseString SignerImpl::X_AMZ_SECURITY_TOKEN{"x-amz-security-token"}; +const Http::LowerCaseString SignerImpl::X_AMZ_DATE{"x-amz-date"}; +const Http::LowerCaseString SignerImpl::X_AMZ_CONTENT_SHA256{"x-amz-content-sha256"}; + +void SignerImpl::sign(Http::Message& message) const { + const auto& credentials = credentials_provider_->getCredentials(); + if (!credentials.accessKeyId() || !credentials.secretAccessKey()) { + return; + } + const auto& region = region_provider_->getRegion(); + if (!region) { + throw EnvoyException("Could not determine AWS region"); + } + auto& headers = message.headers(); + if (credentials.sessionToken()) { + headers.addCopy(X_AMZ_SECURITY_TOKEN, credentials.sessionToken().value()); + } + const auto long_date = LONG_DATE_FORMATTER.now(time_source_); + const auto short_date = SHORT_DATE_FORMATTER.now(time_source_); + headers.addCopy(X_AMZ_DATE, long_date); + const auto content_hash = createContentHash(message); + headers.addCopy(X_AMZ_CONTENT_SHA256, content_hash); + // Phase 1: Create a canonical request + const auto canonical_headers = canonicalizeHeaders(headers); + const auto signing_headers = createSigningHeaders(canonical_headers); + const auto canonical_request = + createCanonicalRequest(message, canonical_headers, signing_headers, content_hash); + ENVOY_LOG(debug, "Canonical request:\n{}", canonical_request); + // Phase 2: Create a string to sign + const auto credential_scope = createCredentialScope(short_date, region.value()); + const auto string_to_sign = createStringToSign(canonical_request, long_date, credential_scope); + ENVOY_LOG(debug, "String to sign:\n{}", string_to_sign); + // Phase 3: Create a signature + const auto signature = createSignature(credentials.secretAccessKey().value(), short_date, + region.value(), string_to_sign); + // Phase 4: Sign request + const auto authorization_header = createAuthorizationHeader( + credentials.accessKeyId().value(), credential_scope, signing_headers, signature); + ENVOY_LOG(debug, "Signing request with: {}", authorization_header); + headers.addCopy(Http::Headers::get().Authorization, authorization_header); +} + +std::string SignerImpl::createContentHash(Http::Message& message) const { + if (!message.body()) { + return HASHED_EMPTY_STRING; + } + return Hex::encode(hash(*message.body())); +} + +std::string SignerImpl::createCanonicalRequest( + Http::Message& message, const std::map& canonical_headers, + const std::string& signing_headers, const std::string& content_hash) const { + const auto& headers = message.headers(); + std::stringstream out; + // Http method + const auto* method_header = headers.Method(); + if (method_header == nullptr || method_header->value().empty()) { + throw EnvoyException("Message is missing :method header"); + } + out << method_header->value().c_str() << "\n"; + // Path + const auto* path_header = headers.Path(); + if (path_header == nullptr || path_header->value().empty()) { + throw EnvoyException("Message is missing :path header"); + } + const auto& path_value = path_header->value(); + const auto path = StringUtil::cropRight(path_value.getStringView(), "?"); + if (path.empty()) { + out << "/"; + } else { + out << path; + } + out << "\n"; + // Query string + const auto query = StringUtil::cropLeft(path_value.getStringView(), "?"); + if (query != path) { + out << query; + } + out << "\n"; + // Headers + for (const auto& header : canonical_headers) { + out << header.first << ":" << header.second << "\n"; + } + out << "\n" << signing_headers << "\n"; + // Content Hash + out << content_hash; + return out.str(); +} + +std::string SignerImpl::createSigningHeaders( + const std::map& canonical_headers) const { + std::vector keys; + keys.reserve(canonical_headers.size()); + for (const auto& header : canonical_headers) { + keys.emplace_back(header.first); + } + return StringUtil::join(keys, ";"); +} + +std::string SignerImpl::createCredentialScope(const std::string& short_date, + const std::string& region) const { + std::stringstream out; + out << short_date << "/" << region << "/" << service_name_ << "/" << AWS4_REQUEST; + return out.str(); +} + +std::string SignerImpl::createStringToSign(const std::string& canonical_request, + const std::string& long_date, + const std::string& credential_scope) const { + std::stringstream out; + out << AWS4_HMAC_SHA256 << "\n"; + out << long_date << "\n"; + out << credential_scope << "\n"; + out << Hex::encode(hash(Buffer::OwnedImpl(canonical_request))); + return out.str(); +} + +std::string SignerImpl::createSignature(const std::string& secret_access_key, + const std::string& short_date, const std::string& region, + const std::string& string_to_sign) const { + const auto k_secret = AWS4 + secret_access_key; + const auto k_date = hmac(std::vector(k_secret.begin(), k_secret.end()), short_date); + const auto k_region = hmac(k_date, region); + const auto k_service = hmac(k_region, service_name_); + const auto k_signing = hmac(k_service, AWS4_REQUEST); + return Hex::encode(hmac(k_signing, string_to_sign)); +} + +std::string SignerImpl::createAuthorizationHeader(const std::string& access_key_id, + const std::string& credential_scope, + const std::string& signing_headers, + const std::string& signature) const { + std::stringstream out; + out << AWS4_HMAC_SHA256 << " "; + out << CREDENTIAL << "=" << access_key_id << "/" << credential_scope << ", "; + out << SIGNED_HEADERS << "=" << signing_headers << ", "; + out << SIGNATURE << "=" << signature; + return out.str(); +} + +std::map +SignerImpl::canonicalizeHeaders(const Http::HeaderMap& headers) const { + std::map out; + headers.iterate( + [](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate { + auto* map = static_cast*>(context); + const auto& key = entry.key().getStringView(); + // Pseudo-headers should not be canonicalized + if (key.empty() || key[0] == ':') { + return Http::HeaderMap::Iterate::Continue; + } + // Join multi-line headers with commas + std::vector lines; + for (const auto& line : StringUtil::splitToken(entry.value().getStringView(), "\n")) { + lines.emplace_back(StringUtil::trim(line)); + } + auto value = StringUtil::join(lines, ","); + // Remove duplicate spaces + const auto end = std::unique(value.begin(), value.end(), [](char lhs, char rhs) { + return (lhs == rhs) && (lhs == ' '); + }); + value.erase(end, value.end()); + map->emplace(entry.key().c_str(), value); + return Http::HeaderMap::Iterate::Continue; + }, + &out); + // The AWS SDK has a quirk where it removes "default ports" (80, 443) from the host headers + // Additionally, we canonicalize the :authority header as "host" + const auto* authority_header = headers.Host(); + if (authority_header != nullptr && !authority_header->value().empty()) { + const auto& value = authority_header->value().getStringView(); + const auto parts = StringUtil::splitToken(value, ":"); + if (parts.size() > 1 && (parts[1] == "80" || parts[1] == "443")) { + out.emplace(Http::Headers::get().HostLegacy.get(), + std::string(parts[0].data(), parts[0].size())); + } else { + out.emplace(Http::Headers::get().HostLegacy.get(), std::string(value.data(), value.size())); + } + } + return out; +} + +std::vector SignerImpl::hash(const Buffer::Instance& buffer) const { + std::vector digest(SHA256_DIGEST_LENGTH); + EVP_MD_CTX ctx; + auto code = EVP_DigestInit(&ctx, EVP_sha256()); + RELEASE_ASSERT(code == 1, "Failed to init digest context"); + const auto num_slices = buffer.getRawSlices(nullptr, 0); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer.getRawSlices(slices.begin(), num_slices); + for (const auto& slice : slices) { + code = EVP_DigestUpdate(&ctx, slice.mem_, slice.len_); + RELEASE_ASSERT(code == 1, "Failed to update digest"); + } + unsigned int digest_length; + code = EVP_DigestFinal(&ctx, digest.data(), &digest_length); + RELEASE_ASSERT(code == 1, "Failed to finalize digest"); + RELEASE_ASSERT(digest_length == SHA256_DIGEST_LENGTH, "Digest length mismatch"); + return digest; +} + +std::vector SignerImpl::hmac(const std::vector& key, + const std::string& string) const { + std::vector mac(EVP_MAX_MD_SIZE); + HMAC_CTX ctx; + RELEASE_ASSERT(key.size() < std::numeric_limits::max(), "Hmac key is too long"); + HMAC_CTX_init(&ctx); + auto code = HMAC_Init_ex(&ctx, key.data(), static_cast(key.size()), EVP_sha256(), nullptr); + RELEASE_ASSERT(code == 1, "Failed to init hmac context"); + code = HMAC_Update(&ctx, reinterpret_cast(string.data()), string.size()); + RELEASE_ASSERT(code == 1, "Failed to update hmac"); + unsigned int len; + code = HMAC_Final(&ctx, mac.data(), &len); + RELEASE_ASSERT(code == 1, "Failed to finalize hmac"); + RELEASE_ASSERT(len <= EVP_MAX_MD_SIZE, "Hmac length too large"); + HMAC_CTX_cleanup(&ctx); + mac.resize(len); + return mac; +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy diff --git a/source/common/aws/signer_impl.h b/source/common/aws/signer_impl.h new file mode 100644 index 0000000000000..f1282c74923a7 --- /dev/null +++ b/source/common/aws/signer_impl.h @@ -0,0 +1,73 @@ +#pragma once + +#include "common/aws/credentials_provider.h" +#include "common/aws/region_provider.h" +#include "common/aws/signer.h" +#include "common/common/logger.h" +#include "common/common/utility.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +/** + * Implementation of the Signature V4 signing process. + * See https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html + */ +class SignerImpl : public Signer, public Logger::Loggable { +public: + SignerImpl(const std::string& service_name, + const CredentialsProviderSharedPtr& credentials_provider, + const RegionProviderSharedPtr& region_provider, TimeSource& time_source) + : service_name_(service_name), credentials_provider_(credentials_provider), + region_provider_(region_provider), time_source_(time_source) {} + + void sign(Http::Message& message) const override; + + static const Http::LowerCaseString X_AMZ_SECURITY_TOKEN; + static const Http::LowerCaseString X_AMZ_DATE; + static const Http::LowerCaseString X_AMZ_CONTENT_SHA256; + +private: + friend class SignerImplTest; + + static DateFormatter LONG_DATE_FORMATTER; + static DateFormatter SHORT_DATE_FORMATTER; + const std::string service_name_; + CredentialsProviderSharedPtr credentials_provider_; + RegionProviderSharedPtr region_provider_; + TimeSource& time_source_; + + std::string createContentHash(Http::Message& message) const; + + std::string createCanonicalRequest(Http::Message& message, + const std::map& canonical_headers, + const std::string& signing_headers, + const std::string& content_hash) const; + + std::string + createSigningHeaders(const std::map& canonical_headers) const; + + std::string createCredentialScope(const std::string& short_date, const std::string& region) const; + + std::string createStringToSign(const std::string& canonical_request, const std::string& long_date, + const std::string& credential_scope) const; + + std::string createSignature(const std::string& secret_access_key, const std::string& short_date, + const std::string& region, const std::string& string_to_sign) const; + + std::string createAuthorizationHeader(const std::string& access_key_id, + const std::string& credential_scope, + const std::string& signing_headers, + const std::string& signature) const; + + std::map canonicalizeHeaders(const Http::HeaderMap& headers) const; + + std::vector hash(const Buffer::Instance& buffer) const; + + std::vector hmac(const std::vector& key, const std::string& string) const; +}; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy diff --git a/source/common/common/logger.h b/source/common/common/logger.h index 341caa20c8c96..f7a09fc9eacf1 100644 --- a/source/common/common/logger.h +++ b/source/common/common/logger.h @@ -22,6 +22,7 @@ namespace Logger { // clang-format off #define ALL_LOGGER_IDS(FUNCTION) \ FUNCTION(admin) \ + FUNCTION(aws) \ FUNCTION(assert) \ FUNCTION(backtrace) \ FUNCTION(client) \ diff --git a/source/common/grpc/async_client_manager_impl.cc b/source/common/grpc/async_client_manager_impl.cc index 3a4d01d4a86fe..851052972ca4c 100644 --- a/source/common/grpc/async_client_manager_impl.cc +++ b/source/common/grpc/async_client_manager_impl.cc @@ -33,7 +33,7 @@ AsyncClientFactoryImpl::AsyncClientFactoryImpl(Upstream::ClusterManager& cm, AsyncClientManagerImpl::AsyncClientManagerImpl(Upstream::ClusterManager& cm, ThreadLocal::Instance& tls, TimeSource& time_source, Api::Api& api) - : cm_(cm), tls_(tls), time_source_(time_source) { + : cm_(cm), tls_(tls), time_source_(time_source), api_(api) { #ifdef ENVOY_GOOGLE_GRPC google_tls_slot_ = tls.allocateSlot(); google_tls_slot_->set( @@ -48,13 +48,14 @@ AsyncClientPtr AsyncClientFactoryImpl::create() { } GoogleAsyncClientFactoryImpl::GoogleAsyncClientFactoryImpl( - ThreadLocal::Instance& tls, ThreadLocal::Slot* google_tls_slot, Stats::Scope& scope, - const envoy::api::v2::core::GrpcService& config) - : tls_(tls), google_tls_slot_(google_tls_slot), + Api::Api& api, ThreadLocal::Instance& tls, ThreadLocal::Slot* google_tls_slot, + Stats::Scope& scope, const envoy::api::v2::core::GrpcService& config) + : api_(api), tls_(tls), google_tls_slot_(google_tls_slot), scope_(scope.createScope(fmt::format("grpc.{}.", config.google_grpc().stat_prefix()))), config_(config) { #ifndef ENVOY_GOOGLE_GRPC + UNREFERENCED_PARAMETER(api_); UNREFERENCED_PARAMETER(tls_); UNREFERENCED_PARAMETER(google_tls_slot_); UNREFERENCED_PARAMETER(scope_); @@ -69,8 +70,8 @@ AsyncClientPtr GoogleAsyncClientFactoryImpl::create() { #ifdef ENVOY_GOOGLE_GRPC GoogleGenericStubFactory stub_factory; return std::make_unique( - tls_.dispatcher(), google_tls_slot_->getTyped(), stub_factory, - scope_, config_); + api_, tls_.dispatcher(), google_tls_slot_->getTyped(), + stub_factory, scope_, config_); #else return nullptr; #endif @@ -83,7 +84,7 @@ AsyncClientManagerImpl::factoryForGrpcService(const envoy::api::v2::core::GrpcSe case envoy::api::v2::core::GrpcService::kEnvoyGrpc: return std::make_unique(cm_, config, skip_cluster_check, time_source_); case envoy::api::v2::core::GrpcService::kGoogleGrpc: - return std::make_unique(tls_, google_tls_slot_.get(), scope, + return std::make_unique(api_, tls_, google_tls_slot_.get(), scope, config); default: NOT_REACHED_GCOVR_EXCL_LINE; diff --git a/source/common/grpc/async_client_manager_impl.h b/source/common/grpc/async_client_manager_impl.h index 0eb9fc128ccf8..e21bcfe93e9ce 100644 --- a/source/common/grpc/async_client_manager_impl.h +++ b/source/common/grpc/async_client_manager_impl.h @@ -26,13 +26,14 @@ class AsyncClientFactoryImpl : public AsyncClientFactory { class GoogleAsyncClientFactoryImpl : public AsyncClientFactory { public: - GoogleAsyncClientFactoryImpl(ThreadLocal::Instance& tls, ThreadLocal::Slot* google_tls_slot, - Stats::Scope& scope, + GoogleAsyncClientFactoryImpl(Api::Api& api, ThreadLocal::Instance& tls, + ThreadLocal::Slot* google_tls_slot, Stats::Scope& scope, const envoy::api::v2::core::GrpcService& config); AsyncClientPtr create() override; private: + Api::Api& api_; ThreadLocal::Instance& tls_; ThreadLocal::Slot* google_tls_slot_; Stats::ScopeSharedPtr scope_; @@ -54,6 +55,7 @@ class AsyncClientManagerImpl : public AsyncClientManager { ThreadLocal::Instance& tls_; ThreadLocal::SlotPtr google_tls_slot_; TimeSource& time_source_; + Api::Api& api_; }; } // namespace Grpc diff --git a/source/common/grpc/google_async_client_impl.cc b/source/common/grpc/google_async_client_impl.cc index 2ea310b49ba3b..6b62a488cc8c1 100644 --- a/source/common/grpc/google_async_client_impl.cc +++ b/source/common/grpc/google_async_client_impl.cc @@ -64,7 +64,7 @@ void GoogleAsyncClientThreadLocal::completionThread() { ENVOY_LOG(debug, "completionThread exiting"); } -GoogleAsyncClientImpl::GoogleAsyncClientImpl(Event::Dispatcher& dispatcher, +GoogleAsyncClientImpl::GoogleAsyncClientImpl(Api::Api& api, Event::Dispatcher& dispatcher, GoogleAsyncClientThreadLocal& tls, GoogleStubFactory& stub_factory, Stats::ScopeSharedPtr scope, @@ -75,7 +75,9 @@ GoogleAsyncClientImpl::GoogleAsyncClientImpl(Event::Dispatcher& dispatcher, // smart enough to do connection pooling and reuse with identical channel args, so this should // have comparable overhead to what we are doing in Grpc::AsyncClientImpl, i.e. no expensive // new connection implied. - std::shared_ptr creds = getGoogleGrpcChannelCredentials(config); + GoogleGrpcCredentialsFactoryContextImpl context(api, dispatcher.timeSystem()); + std::shared_ptr creds = + getGoogleGrpcChannelCredentials(config, context); std::shared_ptr channel = CreateChannel(config.google_grpc().target_uri(), creds); stub_ = stub_factory.createStub(channel); // Initialize client stats. diff --git a/source/common/grpc/google_async_client_impl.h b/source/common/grpc/google_async_client_impl.h index 71a22e2efde64..6cd443bbb4ea4 100644 --- a/source/common/grpc/google_async_client_impl.h +++ b/source/common/grpc/google_async_client_impl.h @@ -154,8 +154,9 @@ class GoogleGenericStubFactory : public GoogleStubFactory { // Google gRPC C++ client library implementation of Grpc::AsyncClient. class GoogleAsyncClientImpl final : public AsyncClient, Logger::Loggable { public: - GoogleAsyncClientImpl(Event::Dispatcher& dispatcher, GoogleAsyncClientThreadLocal& tls, - GoogleStubFactory& stub_factory, Stats::ScopeSharedPtr scope, + GoogleAsyncClientImpl(Api::Api& api, Event::Dispatcher& dispatcher, + GoogleAsyncClientThreadLocal& tls, GoogleStubFactory& stub_factory, + Stats::ScopeSharedPtr scope, const envoy::api::v2::core::GrpcService& config); ~GoogleAsyncClientImpl() override; diff --git a/source/common/grpc/google_grpc_creds_impl.cc b/source/common/grpc/google_grpc_creds_impl.cc index ab3a1ab9486df..f5f3ac84e2336 100644 --- a/source/common/grpc/google_grpc_creds_impl.cc +++ b/source/common/grpc/google_grpc_creds_impl.cc @@ -1,7 +1,6 @@ #include "common/grpc/google_grpc_creds_impl.h" #include "envoy/api/v2/core/grpc_service.pb.h" -#include "envoy/grpc/google_grpc_creds.h" #include "envoy/registry/registry.h" #include "common/config/datasource.h" @@ -112,7 +111,9 @@ class DefaultGoogleGrpcCredentialsFactory : public GoogleGrpcCredentialsFactory public: std::shared_ptr - getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config) override { + getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config, + GoogleGrpcCredentialsFactoryContext& context) override { + UNREFERENCED_PARAMETER(context); return CredsUtility::defaultChannelCredentials(grpc_service_config); } @@ -126,7 +127,8 @@ static Registry::RegisterFactory -getGoogleGrpcChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service) { +getGoogleGrpcChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service, + GoogleGrpcCredentialsFactoryContext& context) { GoogleGrpcCredentialsFactory* credentials_factory = nullptr; const std::string& google_grpc_credentials_factory_name = grpc_service.google_grpc().credentials_factory_name(); @@ -141,7 +143,7 @@ getGoogleGrpcChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_se throw EnvoyException(fmt::format("Unknown google grpc credentials factory: {}", google_grpc_credentials_factory_name)); } - return credentials_factory->getChannelCredentials(grpc_service); + return credentials_factory->getChannelCredentials(grpc_service, context); } } // namespace Grpc diff --git a/source/common/grpc/google_grpc_creds_impl.h b/source/common/grpc/google_grpc_creds_impl.h index 09db13b63fbad..0a3d8c2251d7c 100644 --- a/source/common/grpc/google_grpc_creds_impl.h +++ b/source/common/grpc/google_grpc_creds_impl.h @@ -1,6 +1,7 @@ #pragma once #include "envoy/api/v2/core/grpc_service.pb.h" +#include "envoy/grpc/google_grpc_creds.h" #include "grpcpp/grpcpp.h" @@ -11,7 +12,8 @@ grpc::SslCredentialsOptions buildSslOptionsFromConfig( const envoy::api::v2::core::GrpcService::GoogleGrpc::SslCredentials& ssl_config); std::shared_ptr -getGoogleGrpcChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service); +getGoogleGrpcChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service, + GoogleGrpcCredentialsFactoryContext& context); class CredsUtility { public: @@ -57,5 +59,19 @@ class CredsUtility { defaultChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config); }; +class GoogleGrpcCredentialsFactoryContextImpl : public GoogleGrpcCredentialsFactoryContext { +public: + GoogleGrpcCredentialsFactoryContextImpl(Api::Api& api, Event::TimeSystem& time_system) + : api_(api), time_system_(time_system) {} + + Api::Api& api() override { return api_; } + + Event::TimeSystem& timeSystem() override { return time_system_; } + +private: + Api::Api& api_; + Event::TimeSystem& time_system_; +}; + } // namespace Grpc } // namespace Envoy diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index 23d472d14ab22..02d0a720df91c 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -12,6 +12,7 @@ EXTENSIONS = { # "envoy.grpc_credentials.file_based_metadata": "//source/extensions/grpc_credentials/file_based_metadata:config", + "envoy.grpc_credentials.aws_iam": "//source/extensions/grpc_credentials/aws_iam:config", # # Health checkers diff --git a/source/extensions/grpc_credentials/aws_iam/BUILD b/source/extensions/grpc_credentials/aws_iam/BUILD new file mode 100644 index 0000000000000..f092989684847 --- /dev/null +++ b/source/extensions/grpc_credentials/aws_iam/BUILD @@ -0,0 +1,31 @@ +licenses(["notice"]) # Apache 2 + +# AWS IAM gRPC Credentials + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + external_deps = ["grpc"], + deps = [ + "//include/envoy/grpc:google_grpc_creds_interface", + "//include/envoy/registry", + "//source/common/aws:credentials_provider_impl_lib", + "//source/common/aws:region_provider_impl_lib", + "//source/common/aws:signer_impl_lib", + "//source/common/config:utility_lib", + "//source/common/grpc:common_lib", + "//source/common/grpc:google_grpc_creds_lib", + "//source/common/http:message_lib", + "//source/extensions/grpc_credentials:well_known_names", + "@envoy_api//envoy/config/grpc_credential/v2alpha:aws_iam_cc", + ], +) diff --git a/source/extensions/grpc_credentials/aws_iam/config.cc b/source/extensions/grpc_credentials/aws_iam/config.cc new file mode 100644 index 0000000000000..b7b585b2fda78 --- /dev/null +++ b/source/extensions/grpc_credentials/aws_iam/config.cc @@ -0,0 +1,111 @@ +#include "extensions/grpc_credentials/aws_iam/config.h" + +#include "envoy/api/v2/core/grpc_service.pb.h" +#include "envoy/config/grpc_credential/v2alpha/aws_iam.pb.validate.h" +#include "envoy/grpc/google_grpc_creds.h" +#include "envoy/registry/registry.h" + +#include "common/aws/credentials_provider_impl.h" +#include "common/aws/region_provider_impl.h" +#include "common/aws/signer_impl.h" +#include "common/config/utility.h" +#include "common/grpc/google_grpc_creds_impl.h" +#include "common/http/headers.h" +#include "common/http/message_impl.h" +#include "common/http/utility.h" +#include "common/protobuf/utility.h" + +namespace Envoy { +namespace Extensions { +namespace GrpcCredentials { +namespace AwsIam { + +std::shared_ptr AwsIamGrpcCredentialsFactory::getChannelCredentials( + const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) { + const auto& google_grpc = grpc_service_config.google_grpc(); + std::shared_ptr creds = + Grpc::CredsUtility::defaultSslChannelCredentials(grpc_service_config); + std::shared_ptr call_creds = nullptr; + for (const auto& credential : google_grpc.call_credentials()) { + switch (credential.credential_specifier_case()) { + case envoy::api::v2::core::GrpcService::GoogleGrpc::CallCredentials::kFromPlugin: { + if (credential.from_plugin().name() == GrpcCredentialsNames::get().AwsIam) { + AwsIamGrpcCredentialsFactory credentials_factory; + const Envoy::ProtobufTypes::MessagePtr config_message = + Envoy::Config::Utility::translateToFactoryConfig(credential.from_plugin(), + credentials_factory); + const auto& config = Envoy::MessageUtil::downcastAndValidate< + const envoy::config::grpc_credential::v2alpha::AwsIamConfig&>(*config_message); + Aws::Auth::RegionProviderSharedPtr region_provider; + if (!config.region().empty()) { + region_provider = std::make_shared(config.region()); + } else { + region_provider = std::make_shared(); + } + auto credentials_provider = std::make_shared( + context.api(), context.timeSystem()); + auto auth_signer = std::make_shared( + config.service_name(), credentials_provider, region_provider, context.timeSystem()); + std::shared_ptr new_call_creds = + grpc::MetadataCredentialsFromPlugin(std::make_unique(auth_signer)); + if (call_creds == nullptr) { + call_creds = new_call_creds; + } else { + call_creds = grpc::CompositeCallCredentials(call_creds, new_call_creds); + } + } + break; + } + default: + // unused credential types + continue; + } + } + if (call_creds != nullptr) { + return grpc::CompositeChannelCredentials(creds, call_creds); + } + return creds; +} + +grpc::Status AwsIamAuthenticator::GetMetadata(grpc::string_ref service_url, + grpc::string_ref method_name, + const grpc::AuthContext&, + std::multimap* metadata) { + const std::string uri(std::string(service_url.data(), service_url.size()) + "/" + + std::string(method_name.data(), method_name.size())); + absl::string_view host; + absl::string_view path; + Http::Utility::extractHostPathFromUri(uri, host, path); + Http::RequestMessageImpl message; + message.headers().insertMethod().value().setReference(Http::Headers::get().MethodValues.Post); + message.headers().insertHost().value(host); + message.headers().insertPath().value(path); + try { + signer_->sign(message); + } catch (const EnvoyException& e) { + return grpc::Status(grpc::StatusCode::INTERNAL, e.what()); + } + // Copy back whatever headers were added + message.headers().iterate( + [](const Http::HeaderEntry& entry, void* context) -> Http::HeaderMap::Iterate { + auto* md = static_cast*>(context); + const auto& key = entry.key().getStringView(); + // Skip pseudo-headers + if (key.empty() || key[0] == ':') { + return Http::HeaderMap::Iterate::Continue; + } + md->emplace(entry.key().c_str(), entry.value().c_str()); + return Http::HeaderMap::Iterate::Continue; + }, + metadata); + return grpc::Status::OK; +} + +static Registry::RegisterFactory + aws_iam_google_grpc_credentials_registered_; + +} // namespace AwsIam +} // namespace GrpcCredentials +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/grpc_credentials/aws_iam/config.h b/source/extensions/grpc_credentials/aws_iam/config.h new file mode 100644 index 0000000000000..9ffa5ba638b2c --- /dev/null +++ b/source/extensions/grpc_credentials/aws_iam/config.h @@ -0,0 +1,45 @@ +#pragma once + +#include "envoy/config/grpc_credential/v2alpha/aws_iam.pb.h" +#include "envoy/grpc/google_grpc_creds.h" + +#include "common/aws/signer.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/grpc_credentials/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace GrpcCredentials { +namespace AwsIam { + +class AwsIamGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentialsFactory { +public: + std::shared_ptr + getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) override; + + Envoy::ProtobufTypes::MessagePtr createEmptyConfigProto() { + return std::make_unique(); + } + + std::string name() const override { return GrpcCredentialsNames::get().AwsIam; } +}; + +class AwsIamAuthenticator : public grpc::MetadataCredentialsPlugin { +public: + AwsIamAuthenticator(Aws::Auth::SignerSharedPtr signer) : signer_(signer) {} + + grpc::Status GetMetadata(grpc::string_ref service_url, grpc::string_ref method_name, + const grpc::AuthContext&, + std::multimap* metadata) override; + +private: + Aws::Auth::SignerSharedPtr signer_; + const envoy::config::grpc_credential::v2alpha::AwsIamConfig config_; +}; + +} // namespace AwsIam +} // namespace GrpcCredentials +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/grpc_credentials/example/config.cc b/source/extensions/grpc_credentials/example/config.cc index 4e5ef7f2eb354..0116a6d83c983 100644 --- a/source/extensions/grpc_credentials/example/config.cc +++ b/source/extensions/grpc_credentials/example/config.cc @@ -13,7 +13,9 @@ namespace Example { std::shared_ptr AccessTokenExampleGrpcCredentialsFactory::getChannelCredentials( - const envoy::api::v2::core::GrpcService& grpc_service_config) { + const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) { + UNREFERENCED_PARAMETER(context); const auto& google_grpc = grpc_service_config.google_grpc(); std::shared_ptr creds = Grpc::CredsUtility::defaultSslChannelCredentials(grpc_service_config); diff --git a/source/extensions/grpc_credentials/example/config.h b/source/extensions/grpc_credentials/example/config.h index 053b79335dc38..207cf56d35704 100644 --- a/source/extensions/grpc_credentials/example/config.h +++ b/source/extensions/grpc_credentials/example/config.h @@ -28,7 +28,8 @@ namespace Example { class AccessTokenExampleGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentialsFactory { public: virtual std::shared_ptr - getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config) override; + getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) override; std::string name() const override { return GrpcCredentialsNames::get().AccessTokenExample; } }; diff --git a/source/extensions/grpc_credentials/file_based_metadata/config.cc b/source/extensions/grpc_credentials/file_based_metadata/config.cc index 951f11c6c6ee9..b5eed92a964e6 100644 --- a/source/extensions/grpc_credentials/file_based_metadata/config.cc +++ b/source/extensions/grpc_credentials/file_based_metadata/config.cc @@ -17,7 +17,9 @@ namespace FileBasedMetadata { std::shared_ptr FileBasedMetadataGrpcCredentialsFactory::getChannelCredentials( - const envoy::api::v2::core::GrpcService& grpc_service_config) { + const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) { + UNREFERENCED_PARAMETER(context); const auto& google_grpc = grpc_service_config.google_grpc(); std::shared_ptr creds = Grpc::CredsUtility::defaultSslChannelCredentials(grpc_service_config); diff --git a/source/extensions/grpc_credentials/file_based_metadata/config.h b/source/extensions/grpc_credentials/file_based_metadata/config.h index 9325e0b7d3d0b..1898a4d1e5303 100644 --- a/source/extensions/grpc_credentials/file_based_metadata/config.h +++ b/source/extensions/grpc_credentials/file_based_metadata/config.h @@ -24,7 +24,8 @@ namespace FileBasedMetadata { class FileBasedMetadataGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentialsFactory { public: std::shared_ptr - getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config) override; + getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config, + Grpc::GoogleGrpcCredentialsFactoryContext& context) override; Envoy::ProtobufTypes::MessagePtr createEmptyConfigProto() { return std::make_unique(); diff --git a/source/extensions/grpc_credentials/well_known_names.h b/source/extensions/grpc_credentials/well_known_names.h index 92dbcaf7c226a..7f7ca3f357680 100644 --- a/source/extensions/grpc_credentials/well_known_names.h +++ b/source/extensions/grpc_credentials/well_known_names.h @@ -18,6 +18,8 @@ class GrpcCredentialsNameValues { const std::string AccessTokenExample = "envoy.grpc_credentials.access_token_example"; // File Based Metadata credentials const std::string FileBasedMetadata = "envoy.grpc_credentials.file_based_metadata"; + // AWS IAM + const std::string AwsIam = "envoy.grpc_credentials.aws_iam"; }; typedef ConstSingleton GrpcCredentialsNames; diff --git a/test/common/aws/BUILD b/test/common/aws/BUILD new file mode 100644 index 0000000000000..19a23b883c321 --- /dev/null +++ b/test/common/aws/BUILD @@ -0,0 +1,57 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_test", + "envoy_package", +) + +envoy_package() + +envoy_cc_test( + name = "signer_impl_test", + srcs = ["signer_impl_test.cc"], + deps = [ + "//source/common/aws:signer_impl_lib", + "//source/common/buffer:buffer_lib", + "//source/common/http:message_lib", + "//test/mocks/aws:aws_mocks", + "//test/test_common:simulated_time_system_lib", + "//test/test_common:utility_lib", + ], +) + +envoy_cc_test( + name = "credentials_provider_impl_test", + srcs = ["credentials_provider_impl_test.cc"], + deps = [ + "//source/common/aws:credentials_provider_impl_lib", + "//test/mocks/api:api_mocks", + "//test/mocks/aws:aws_mocks", + "//test/mocks/event:event_mocks", + "//test/test_common:environment_lib", + "//test/test_common:simulated_time_system_lib", + ], +) + +envoy_cc_test( + name = "metadata_fetcher_impl_test", + srcs = ["metadata_fetcher_impl_test.cc"], + deps = [ + "//source/common/aws:metadata_fetcher_impl_lib", + "//source/common/http:header_map_lib", + "//test/mocks/event:event_mocks", + "//test/mocks/http:http_mocks", + "//test/mocks/network:connection_mocks", + "//test/test_common:simulated_time_system_lib", + ], +) + +envoy_cc_test( + name = "region_provider_impl_test", + srcs = ["region_provider_impl_test.cc"], + deps = [ + "//source/common/aws:region_provider_impl_lib", + "//test/test_common:environment_lib", + ], +) diff --git a/test/common/aws/credentials_provider_impl_test.cc b/test/common/aws/credentials_provider_impl_test.cc new file mode 100644 index 0000000000000..b1449cc56dce8 --- /dev/null +++ b/test/common/aws/credentials_provider_impl_test.cc @@ -0,0 +1,398 @@ +#include "common/aws/credentials_provider_impl.h" + +#include "test/mocks/api/mocks.h" +#include "test/mocks/aws/mocks.h" +#include "test/mocks/event/mocks.h" +#include "test/test_common/environment.h" +#include "test/test_common/simulated_time_system.h" + +using testing::_; +using testing::InSequence; +using testing::NiceMock; +using testing::Ref; +using testing::Return; + +namespace Envoy { +namespace Aws { +namespace Auth { + +class EvironmentCredentialsProviderTest : public testing::Test { +public: + ~EvironmentCredentialsProviderTest() { + TestEnvironment::unsetEnvVar("AWS_ACCESS_KEY_ID"); + TestEnvironment::unsetEnvVar("AWS_SECRET_ACCESS_KEY"); + TestEnvironment::unsetEnvVar("AWS_SESSION_TOKEN"); + } + + EnvironmentCredentialsProvider provider_; +}; + +TEST_F(EvironmentCredentialsProviderTest, AllEnvironmentVars) { + TestEnvironment::setEnvVar("AWS_ACCESS_KEY_ID", "akid", 1); + TestEnvironment::setEnvVar("AWS_SECRET_ACCESS_KEY", "secret", 1); + TestEnvironment::setEnvVar("AWS_SESSION_TOKEN", "token", 1); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); +} + +TEST_F(EvironmentCredentialsProviderTest, NoEnvironmentVars) { + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(EvironmentCredentialsProviderTest, MissingAccessKeyId) { + TestEnvironment::setEnvVar("AWS_SECRET_ACCESS_KEY", "secret", 1); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(EvironmentCredentialsProviderTest, NoSessionToken) { + TestEnvironment::setEnvVar("AWS_ACCESS_KEY_ID", "akid", 1); + TestEnvironment::setEnvVar("AWS_SECRET_ACCESS_KEY", "secret", 1); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +class InstanceProfileCredentialsProviderTest : public testing::Test { +public: + InstanceProfileCredentialsProviderTest() + : fetcher_(new NiceMock()), + provider_(api_, time_system_, MetadataFetcherPtr{fetcher_}) {} + + void expectCredentialListing(const absl::optional& listing) { + EXPECT_CALL(api_, allocateDispatcher_(Ref(time_system_))); + EXPECT_CALL(*fetcher_, getMetadata(_, "169.254.169.254:80", + "/latest/meta-data/iam/security-credentials", _)) + .WillOnce(Return(listing)); + } + + void expectDocument(const absl::optional& document) { + EXPECT_CALL(*fetcher_, getMetadata(_, "169.254.169.254:80", + "/latest/meta-data/iam/security-credentials/doc1", _)) + .WillOnce(Return(document)); + } + + NiceMock api_; + NiceMock* dispatcher_; + Event::SimulatedTimeSystem time_system_; + NiceMock* fetcher_; + InstanceProfileCredentialsProvider provider_; +}; + +TEST_F(InstanceProfileCredentialsProviderTest, FailedCredentailListing) { + expectCredentialListing(absl::optional()); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, EmptyCredentialListing) { + expectCredentialListing(""); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, MissingDocument) { + expectCredentialListing("doc1\ndoc2\ndoc3"); + expectDocument(absl::optional()); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, MalformedDocumenet) { + expectCredentialListing("doc1"); + expectDocument(R"EOF( +not json +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, EmptyValues) { + expectCredentialListing("doc1"); + expectDocument(R"EOF( +{ + "AccessKeyId": "", + "SecretAccessKey": "", + "Token": "" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, FullCachedCredentials) { + expectCredentialListing("doc1"); + expectDocument(R"EOF( +{ + "AccessKeyId": "akid", + "SecretAccessKey": "secret", + "Token": "token" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); + const auto cached_credentials = provider_.getCredentials(); + EXPECT_EQ("akid", cached_credentials.accessKeyId().value()); + EXPECT_EQ("secret", cached_credentials.secretAccessKey().value()); + EXPECT_EQ("token", cached_credentials.sessionToken().value()); +} + +TEST_F(InstanceProfileCredentialsProviderTest, CredentialExpiration) { + InSequence sequence; + expectCredentialListing("doc1"); + expectDocument(R"EOF( +{ + "AccessKeyId": "akid", + "SecretAccessKey": "secret", + "Token": "token" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); + time_system_.sleep(std::chrono::hours(2)); + expectCredentialListing("doc1"); + expectDocument(R"EOF( +{ + "AccessKeyId": "new_akid", + "SecretAccessKey": "new_secret", + "Token": "new_token" +} +)EOF"); + const auto new_credentials = provider_.getCredentials(); + EXPECT_EQ("new_akid", new_credentials.accessKeyId().value()); + EXPECT_EQ("new_secret", new_credentials.secretAccessKey().value()); + EXPECT_EQ("new_token", new_credentials.sessionToken().value()); +} + +class TaskRoleCredentialsProviderTest : public testing::Test { +public: + TaskRoleCredentialsProviderTest() + : fetcher_(new NiceMock()), + provider_(api_, time_system_, MetadataFetcherPtr{fetcher_}, "169.254.170.2:80/path/to/doc", + "auth_token") { + // 20180102T030405Z + time_system_.setSystemTime(std::chrono::milliseconds(1514862245000)); + } + + void expectDocument(const absl::optional& document) { + EXPECT_CALL(api_, allocateDispatcher_(Ref(time_system_))); + EXPECT_CALL(*fetcher_, getMetadata(_, "169.254.170.2:80", "/path/to/doc", _)) + .WillOnce(Return(document)); + } + + NiceMock api_; + Event::SimulatedTimeSystem time_system_; + NiceMock* fetcher_; + TaskRoleCredentialsProvider provider_; +}; + +TEST_F(TaskRoleCredentialsProviderTest, FailedFetchingDocument) { + expectDocument(absl::optional()); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(TaskRoleCredentialsProviderTest, MalformedDocumenet) { + expectDocument(R"EOF( +not json +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(TaskRoleCredentialsProviderTest, EmptyValues) { + expectDocument(R"EOF( +{ + "AccessKeyId": "", + "SecretAccessKey": "", + "Token": "", + "Expiration": "" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_FALSE(credentials.accessKeyId().has_value()); + EXPECT_FALSE(credentials.secretAccessKey().has_value()); + EXPECT_FALSE(credentials.sessionToken().has_value()); +} + +TEST_F(TaskRoleCredentialsProviderTest, FullCachedCredentials) { + expectDocument(R"EOF( +{ + "AccessKeyId": "akid", + "SecretAccessKey": "secret", + "Token": "token", + "Expiration": "20180102T030500Z" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); + const auto cached_credentials = provider_.getCredentials(); + EXPECT_EQ("akid", cached_credentials.accessKeyId().value()); + EXPECT_EQ("secret", cached_credentials.secretAccessKey().value()); + EXPECT_EQ("token", cached_credentials.sessionToken().value()); +} + +TEST_F(TaskRoleCredentialsProviderTest, NormalCredentialExpiration) { + InSequence sequence; + expectDocument(R"EOF( +{ + "AccessKeyId": "akid", + "SecretAccessKey": "secret", + "Token": "token", + "Expiration": "20190102T030405Z" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); + time_system_.sleep(std::chrono::hours(2)); + expectDocument(R"EOF( +{ + "AccessKeyId": "new_akid", + "SecretAccessKey": "new_secret", + "Token": "new_token", + "Expiration": "20190102T030405Z" +} +)EOF"); + const auto cached_credentials = provider_.getCredentials(); + EXPECT_EQ("new_akid", cached_credentials.accessKeyId().value()); + EXPECT_EQ("new_secret", cached_credentials.secretAccessKey().value()); + EXPECT_EQ("new_token", cached_credentials.sessionToken().value()); +} + +TEST_F(TaskRoleCredentialsProviderTest, TimestampCredentialExpiration) { + InSequence sequence; + expectDocument(R"EOF( +{ + "AccessKeyId": "akid", + "SecretAccessKey": "secret", + "Token": "token", + "Expiration": "20180102T030405Z" +} +)EOF"); + const auto credentials = provider_.getCredentials(); + EXPECT_EQ("akid", credentials.accessKeyId().value()); + EXPECT_EQ("secret", credentials.secretAccessKey().value()); + EXPECT_EQ("token", credentials.sessionToken().value()); + expectDocument(R"EOF( +{ + "AccessKeyId": "new_akid", + "SecretAccessKey": "new_secret", + "Token": "new_token", + "Expiration": "20190102T030405Z" +} +)EOF"); + const auto cached_credentials = provider_.getCredentials(); + EXPECT_EQ("new_akid", cached_credentials.accessKeyId().value()); + EXPECT_EQ("new_secret", cached_credentials.secretAccessKey().value()); + EXPECT_EQ("new_token", cached_credentials.sessionToken().value()); +} + +class DefaultCredentialsProviderChainTest : public testing::Test { +public: + DefaultCredentialsProviderChainTest() { + EXPECT_CALL(factories_, createEnvironmentCredentialsProvider()); + } + + ~DefaultCredentialsProviderChainTest() { + TestEnvironment::unsetEnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); + TestEnvironment::unsetEnvVar("AWS_CONTAINER_CREDENTIALS_FULL_URI"); + TestEnvironment::unsetEnvVar("AWS_CONTAINER_AUTHORIZATION_TOKEN"); + TestEnvironment::unsetEnvVar("AWS_EC2_METADATA_DISABLED"); + } + + class MockCredentialsProviderChainFactories : public CredentialsProviderChainFactories { + public: + MOCK_CONST_METHOD0(createMetadataFetcher, MetadataFetcherPtr()); + MOCK_CONST_METHOD0(createEnvironmentCredentialsProvider, CredentialsProviderSharedPtr()); + MOCK_CONST_METHOD5(createTaskRoleCredentialsProvider, + CredentialsProviderSharedPtr(Api::Api&, Event::TimeSystem&, + MetadataFetcherPtr&&, const std::string&, + const absl::optional&)); + MOCK_CONST_METHOD3(createInstanceProfileCredentialsProvider, + CredentialsProviderSharedPtr(Api::Api&, Event::TimeSystem&, + MetadataFetcherPtr&& fetcher)); + }; + + NiceMock api_; + Event::SimulatedTimeSystem time_system_; + NiceMock factories_; +}; + +TEST_F(DefaultCredentialsProviderChainTest, NoEnvironmentVars) { + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(api_), Ref(time_system_), _)); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, MetadataDisabled) { + TestEnvironment::setEnvVar("AWS_EC2_METADATA_DISABLED", "true", 1); + EXPECT_CALL(factories_, createInstanceProfileCredentialsProvider(Ref(api_), Ref(time_system_), _)) + .Times(0); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, MetadataNotDisabled) { + TestEnvironment::setEnvVar("AWS_EC2_METADATA_DISABLED", "false", 1); + EXPECT_CALL(factories_, + createInstanceProfileCredentialsProvider(Ref(api_), Ref(time_system_), _)); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, RelativeUri) { + TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/path/to/creds", 1); + EXPECT_CALL(factories_, createTaskRoleCredentialsProvider(Ref(api_), Ref(time_system_), _, + "169.254.170.2:80/path/to/creds", + absl::optional())); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, FullUriNoAuthorizationToken) { + TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_FULL_URI", "http://host/path/to/creds", 1); + EXPECT_CALL(factories_, createTaskRoleCredentialsProvider(Ref(api_), Ref(time_system_), _, + "http://host/path/to/creds", + absl::optional())); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +TEST_F(DefaultCredentialsProviderChainTest, FullUriWithAuthorizationToken) { + TestEnvironment::setEnvVar("AWS_CONTAINER_CREDENTIALS_FULL_URI", "http://host/path/to/creds", 1); + TestEnvironment::setEnvVar("AWS_CONTAINER_AUTHORIZATION_TOKEN", "auth_token", 1); + EXPECT_CALL(factories_, createTaskRoleCredentialsProvider( + Ref(api_), Ref(time_system_), _, "http://host/path/to/creds", + absl::optional("auth_token"))); + DefaultCredentialsProviderChain chain(api_, time_system_, factories_); +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/test/common/aws/metadata_fetcher_impl_test.cc b/test/common/aws/metadata_fetcher_impl_test.cc new file mode 100644 index 0000000000000..ebafd374cbd54 --- /dev/null +++ b/test/common/aws/metadata_fetcher_impl_test.cc @@ -0,0 +1,210 @@ +#include "common/aws/metadata_fetcher_impl.h" +#include "common/event/dispatcher_impl.h" +#include "common/http/header_map_impl.h" +#include "common/http/headers.h" + +#include "test/mocks/event/mocks.h" +#include "test/mocks/http/mocks.h" +#include "test/mocks/network/connection.h" +#include "test/test_common/simulated_time_system.h" + +using testing::_; +using testing::InSequence; +using testing::Invoke; +using testing::InvokeWithoutArgs; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; + +namespace Envoy { +namespace Aws { +namespace Auth { + +class MetadataSessionTest : public testing::Test { +public: + MetadataSessionTest() + : connection_(new NiceMock()), + codec_(new NiceMock()) {} + + std::unique_ptr createSession() { + EXPECT_CALL(*connection_, addReadFilter(_)); + EXPECT_CALL(*connection_, addConnectionCallbacks(_)); + EXPECT_CALL(*connection_, connect()); + EXPECT_CALL(*connection_, noDelay(true)); + EXPECT_CALL(*codec_, newStream(_)).WillOnce(ReturnRef(encoder_)); + EXPECT_CALL(encoder_, getStream()).WillRepeatedly(ReturnRef(stream_)); + EXPECT_CALL(stream_, addCallbacks(Ref(decoder_))); + EXPECT_CALL(encoder_, encodeHeaders(Ref(headers_), true)); + EXPECT_CALL(dispatcher_, createTimer_(_)).WillOnce(InvokeWithoutArgs([]() { + auto timer = new NiceMock(); + EXPECT_CALL(*timer, enableTimer(std::chrono::milliseconds(5000))); + return timer; + })); + return std::make_unique( + Network::ClientConnectionPtr{connection_}, dispatcher_, decoder_, decoder_, headers_, + [this](Network::Connection&, Http::ConnectionCallbacks&) { return codec_; }); + } + + NiceMock* connection_; + NiceMock dispatcher_; + NiceMock* codec_; + MetadataFetcherImpl::StringBufferDecoder decoder_; + Http::HeaderMapImpl headers_; + NiceMock encoder_; + NiceMock stream_; +}; + +TEST_F(MetadataSessionTest, RemoteCloseConnected) { + auto session = createSession(); + session->onEvent(Network::ConnectionEvent::Connected); + EXPECT_CALL(*codec_, dispatch(_)); + EXPECT_CALL(stream_, resetStream(Http::StreamResetReason::ConnectionTermination)); + session->onEvent(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(MetadataSessionTest, RemoteCloseNotConnected) { + auto session = createSession(); + EXPECT_CALL(*codec_, dispatch(_)); + EXPECT_CALL(stream_, resetStream(Http::StreamResetReason::ConnectionFailure)); + session->onEvent(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(MetadataSessionTest, LocalCloseConnected) { + auto session = createSession(); + session->onEvent(Network::ConnectionEvent::Connected); + EXPECT_CALL(stream_, resetStream(Http::StreamResetReason::ConnectionTermination)); + session->onEvent(Network::ConnectionEvent::LocalClose); +} + +TEST_F(MetadataSessionTest, LocalCloseNotConnected) { + auto session = createSession(); + EXPECT_CALL(stream_, resetStream(Http::StreamResetReason::ConnectionFailure)); + session->onEvent(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(MetadataSessionTest, CloseNoFlush) { + auto session = createSession(); + EXPECT_CALL(*connection_, close(Network::ConnectionCloseType::NoFlush)); + session->close(); +} + +class MetadataFetcherImplTest : public testing::Test { +public: + class MockStringBufferDecoder : public MetadataFetcherImpl::StringBufferDecoder { + public: + MOCK_CONST_METHOD0(body, const std::string&()); + }; + + class MockMetadataSession : public MetadataFetcherImpl::MetadataSession { + public: + MOCK_METHOD0(close, void()); + }; + + void expectConnection() { + EXPECT_CALL(dispatcher_, createClientConnection_(_, _, _, _)) + .WillOnce(Invoke([](Network::Address::InstanceConstSharedPtr address, + Network::Address::InstanceConstSharedPtr, Network::TransportSocketPtr&, + const Network::ConnectionSocket::OptionsSharedPtr&) { + EXPECT_EQ("127.0.0.1:80", address->asString()); + return nullptr; + })); + } + + void expectTimerRun(const std::chrono::milliseconds& delay) { + EXPECT_CALL(dispatcher_, createTimer_(_)).WillOnce(Invoke([this, delay](Event::TimerCb cb) { + auto timer = new NiceMock(); + EXPECT_CALL(*timer, enableTimer(delay)); + expectConnection(); + cb(); + return timer; + })); + EXPECT_CALL(dispatcher_, exit()); + EXPECT_CALL(dispatcher_, run(Event::Dispatcher::RunType::Block)); + } + + MetadataFetcherPtr + expectMetadata(const std::string& data, + const absl::optional& auth_token = absl::optional()) { + return std::make_unique( + [&auth_token](Network::ClientConnectionPtr&&, Event::Dispatcher&, Http::StreamDecoder&, + Http::StreamCallbacks&, const Http::HeaderMap& headers, + MetadataFetcherImpl::HttpCodecFactory) { + EXPECT_EQ(Http::Headers::get().MethodValues.Get, + headers.Method()->value().getStringView()); + EXPECT_EQ("127.0.0.1:80", headers.Host()->value().getStringView()); + EXPECT_EQ("/path", headers.Path()->value().getStringView()); + auto session = new NiceMock(); + EXPECT_CALL(*session, close()); + if (auth_token) { + EXPECT_EQ(auth_token.value(), headers.Authorization()->value().getStringView()); + } + return session; + }, + [this, data]() { + auto decoder = new NiceMock(); + if (num_failures_ == 0) { + expectTimerRun(std::chrono::milliseconds(0)); + } else { + expectTimerRun(std::chrono::milliseconds(1000)); + } + EXPECT_CALL(*decoder, body) + .WillOnce(Invoke([this, &data, decoder]() -> const std::string& { + Buffer::OwnedImpl buffer; + if (num_failures_ >= required_failures_ && !data.empty()) { + buffer.add(data); + } else { + num_failures_++; + } + decoder->decodeData(buffer, true); + return decoder->body_; + })); + return decoder; + }); + } + + NiceMock dispatcher_; + int num_failures_{}; + int required_failures_{}; +}; + +TEST_F(MetadataFetcherImplTest, SuccessfulRequest) { + const auto fetcher = expectMetadata("test"); + const auto data = fetcher->getMetadata(dispatcher_, "127.0.0.1", "/path"); + EXPECT_EQ("test", data); + EXPECT_EQ(0, num_failures_); +} + +TEST_F(MetadataFetcherImplTest, SuccessfulAfter2ndRequest) { + required_failures_ = 1; + const auto fetcher = expectMetadata("test"); + const auto data = fetcher->getMetadata(dispatcher_, "127.0.0.1", "/path"); + EXPECT_EQ("test", data.value()); + EXPECT_EQ(1, num_failures_); +} + +TEST_F(MetadataFetcherImplTest, SuccessfulAfter3rdRequest) { + required_failures_ = 2; + const auto fetcher = expectMetadata("test"); + const auto data = fetcher->getMetadata(dispatcher_, "127.0.0.1", "/path"); + EXPECT_EQ("test", data.value()); + EXPECT_EQ(2, num_failures_); +} + +TEST_F(MetadataFetcherImplTest, AuthToken) { + const auto fetcher = expectMetadata("test", "auth_token"); + const auto data = fetcher->getMetadata(dispatcher_, "127.0.0.1", "/path", "auth_token"); + EXPECT_EQ("test", data); + EXPECT_EQ(0, num_failures_); +} + +TEST_F(MetadataFetcherImplTest, FailureEveryRetry) { + const auto fetcher = expectMetadata(""); + const auto data = fetcher->getMetadata(dispatcher_, "127.0.0.1", "/path"); + EXPECT_FALSE(data.has_value()); + EXPECT_EQ(4, num_failures_); +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/test/common/aws/region_provider_impl_test.cc b/test/common/aws/region_provider_impl_test.cc new file mode 100644 index 0000000000000..51a72bc272bf8 --- /dev/null +++ b/test/common/aws/region_provider_impl_test.cc @@ -0,0 +1,27 @@ +#include "common/aws/region_provider_impl.h" + +#include "test/test_common/environment.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class EnvironmentRegionProviderTest : public testing::Test { +public: + virtual ~EnvironmentRegionProviderTest() { TestEnvironment::unsetEnvVar("AWS_REGION"); } + + EnvironmentRegionProvider provider_; +}; + +TEST_F(EnvironmentRegionProviderTest, SomeRegion) { + TestEnvironment::setEnvVar("AWS_REGION", "test-region", 1); + EXPECT_EQ("test-region", provider_.getRegion().value()); +} + +TEST_F(EnvironmentRegionProviderTest, NoRegion) { EXPECT_FALSE(provider_.getRegion().has_value()); } + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/test/common/aws/signer_impl_test.cc b/test/common/aws/signer_impl_test.cc new file mode 100644 index 0000000000000..6c815670f9392 --- /dev/null +++ b/test/common/aws/signer_impl_test.cc @@ -0,0 +1,188 @@ +#include "common/aws/signer_impl.h" +#include "common/buffer/buffer_impl.h" +#include "common/http/message_impl.h" + +#include "test/mocks/aws/mocks.h" +#include "test/test_common/simulated_time_system.h" +#include "test/test_common/utility.h" + +using testing::NiceMock; +using testing::Return; + +namespace Envoy { +namespace Aws { +namespace Auth { + +class SignerImplTest : public testing::Test { +public: + SignerImplTest() + : credentials_provider_(new NiceMock()), + region_provider_(new NiceMock()), + message_(new Http::RequestMessageImpl()), + signer_("service", CredentialsProviderSharedPtr{credentials_provider_}, + RegionProviderSharedPtr{region_provider_}, time_system_), + credentials_("akid", "secret"), token_credentials_("akid", "secret", "token"), + region_("region") { + // 20180102T030405Z + time_system_.setSystemTime(std::chrono::milliseconds(1514862245000)); + } + + void addMethod(const std::string& method) { message_->headers().insertMethod().value(method); } + + void addPath(const std::string& path) { message_->headers().insertPath().value(path); } + + void addHeader(const std::string& key, const std::string& value) { + message_->headers().addCopy(Http::LowerCaseString(key), value); + } + + void setBody(const std::string& body) { + message_->body() = std::make_unique(body); + } + + std::string canonicalRequest() { + const auto canonical_headers = signer_.canonicalizeHeaders(message_->headers()); + const auto signing_headers = signer_.createSigningHeaders(canonical_headers); + const auto content_hash = signer_.createContentHash(*message_); + return signer_.createCanonicalRequest(*message_, canonical_headers, signing_headers, + content_hash); + } + + NiceMock* credentials_provider_; + NiceMock* region_provider_; + Event::SimulatedTimeSystem time_system_; + Http::MessagePtr message_; + SignerImpl signer_; + Credentials credentials_; + Credentials token_credentials_; + absl::optional region_; +}; + +TEST_F(SignerImplTest, AnonymousCredentials) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(Credentials())); + EXPECT_CALL(*region_provider_, getRegion()).Times(0); + signer_.sign(*message_); + EXPECT_EQ(nullptr, message_->headers().Authorization()); +} + +TEST_F(SignerImplTest, MissingRegionException) { + absl::optional no_region; + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(no_region)); + EXPECT_THROW_WITH_MESSAGE(signer_.sign(*message_), EnvoyException, + "Could not determine AWS region"); + EXPECT_EQ(nullptr, message_->headers().Authorization()); +} + +TEST_F(SignerImplTest, SecurityTokenHeader) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(token_credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + addMethod("GET"); + addPath("/"); + signer_.sign(*message_); + EXPECT_STREQ("token", message_->headers().get(SignerImpl::X_AMZ_SECURITY_TOKEN)->value().c_str()); + EXPECT_STREQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " + "SignedHeaders=x-amz-content-sha256;x-amz-date;x-amz-security-token, " + "Signature=1d42526aabf7d8b6d7d33d9db43b03537300cc7e6bb2817e349749e0a08f5b5e", + message_->headers().Authorization()->value().c_str()); +} + +TEST_F(SignerImplTest, DateHeader) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + addMethod("GET"); + addPath("/"); + signer_.sign(*message_); + EXPECT_STREQ("20180102T030400Z", + message_->headers().get(SignerImpl::X_AMZ_DATE)->value().c_str()); + EXPECT_STREQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " + "SignedHeaders=x-amz-content-sha256;x-amz-date, " + "Signature=4ee6aa9355259c18133f150b139ea9aeb7969c9408ad361b2151f50a516afe42", + message_->headers().Authorization()->value().c_str()); +} + +TEST_F(SignerImplTest, EmptyContentHeader) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + addMethod("GET"); + addPath("/empty?content=none"); + signer_.sign(*message_); + EXPECT_STREQ("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + message_->headers().get(SignerImpl::X_AMZ_CONTENT_SHA256)->value().c_str()); + EXPECT_STREQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " + "SignedHeaders=x-amz-content-sha256;x-amz-date, " + "Signature=999e211bc7134cc685f830a332cf4d871b6d8bb8ced9367c1a0b59b95a03ee7d", + message_->headers().Authorization()->value().c_str()); +} + +TEST_F(SignerImplTest, ContentHeader) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + addMethod("POST"); + addPath("/"); + setBody("test1234"); + signer_.sign(*message_); + EXPECT_STREQ("937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244", + message_->headers().get(SignerImpl::X_AMZ_CONTENT_SHA256)->value().c_str()); + EXPECT_STREQ("AWS4-HMAC-SHA256 Credential=akid/20180102/region/service/aws4_request, " + "SignedHeaders=x-amz-content-sha256;x-amz-date, " + "Signature=4eab89c36f45f2032d6010ba1adab93f8510ddd6afe540821f3a05bb0253e27b", + message_->headers().Authorization()->value().c_str()); +} + +TEST_F(SignerImplTest, MissingMethodException) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + EXPECT_THROW_WITH_MESSAGE(signer_.sign(*message_), EnvoyException, + "Message is missing :method header"); + EXPECT_EQ(nullptr, message_->headers().Authorization()); +} + +TEST_F(SignerImplTest, MissingPathException) { + EXPECT_CALL(*credentials_provider_, getCredentials()).WillOnce(Return(credentials_)); + EXPECT_CALL(*region_provider_, getRegion()).WillOnce(Return(region_)); + addMethod("GET"); + EXPECT_THROW_WITH_MESSAGE(signer_.sign(*message_), EnvoyException, + "Message is missing :path header"); + EXPECT_EQ(nullptr, message_->headers().Authorization()); +} + +TEST_F(SignerImplTest, ComplexCanonicalRequest) { + addMethod("POST"); + addPath("/path/foo?bar=baz"); + setBody("test1234"); + addHeader(":authority", "example.com:80"); + addHeader("UpperCase", "uppercasevalue"); + addHeader("MultiLine", "hello\n\nworld\n\nline3\n"); + addHeader("Trimmable", " trim me "); + addHeader("EmptyOne", ""); + addHeader("Alphabetic", "abcd"); + EXPECT_EQ(R"(POST +/path/foo +bar=baz +alphabetic:abcd +emptyone: +host:example.com +multiline:hello,world,line3 +trimmable:trim me +uppercase:uppercasevalue + +alphabetic;emptyone;host;multiline;trimmable;uppercase +937e8d5fbb48bd4949536cd65b8d35c426b80d2f830c5c308e2cdec422ae2244)", + canonicalRequest()); +} + +TEST_F(SignerImplTest, EmptyCanonicalRequest) { + addMethod("POST"); + addPath("/hello"); + EXPECT_EQ(R"(POST +/hello + + + +e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855)", + canonicalRequest()); +} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/test/common/grpc/google_async_client_impl_test.cc b/test/common/grpc/google_async_client_impl_test.cc index 5b42d4229c64c..2e34debaeb15f 100644 --- a/test/common/grpc/google_async_client_impl_test.cc +++ b/test/common/grpc/google_async_client_impl_test.cc @@ -55,8 +55,8 @@ class EnvoyGoogleAsyncClientImplTest : public testing::Test { google_grpc->set_target_uri("fake_address"); google_grpc->set_stat_prefix("test_cluster"); tls_ = std::make_unique(*api_); - grpc_client_ = - std::make_unique(dispatcher_, *tls_, stub_factory_, scope_, config); + grpc_client_ = std::make_unique(*api_, dispatcher_, *tls_, stub_factory_, + scope_, config); } DangerousDeprecatedTestTime test_time_; diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index 16006fa4c17fd..2a0883efc6fb5 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -305,7 +305,7 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { #ifdef ENVOY_GOOGLE_GRPC google_tls_ = std::make_unique(*api_); GoogleGenericStubFactory stub_factory; - return std::make_unique(dispatcher_, *google_tls_, stub_factory, + return std::make_unique(*api_, dispatcher_, *google_tls_, stub_factory, stats_scope_, createGoogleGrpcConfig()); #else NOT_REACHED_GCOVR_EXCL_LINE; diff --git a/test/mocks/aws/BUILD b/test/mocks/aws/BUILD new file mode 100644 index 0000000000000..e410f4325b5d9 --- /dev/null +++ b/test/mocks/aws/BUILD @@ -0,0 +1,21 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_mock", + "envoy_package", +) + +envoy_package() + +envoy_cc_mock( + name = "aws_mocks", + srcs = ["mocks.cc"], + hdrs = ["mocks.h"], + deps = [ + "//source/common/aws:credentials_provider_lib", + "//source/common/aws:metadata_fetcher_lib", + "//source/common/aws:region_provider_lib", + "//source/common/aws:signer_lib", + ], +) diff --git a/test/mocks/aws/mocks.cc b/test/mocks/aws/mocks.cc new file mode 100644 index 0000000000000..4c35f65be539b --- /dev/null +++ b/test/mocks/aws/mocks.cc @@ -0,0 +1,21 @@ +#include "test/mocks/aws/mocks.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +MockCredentialsProvider::MockCredentialsProvider() {} +MockCredentialsProvider::~MockCredentialsProvider() {} + +MockRegionProvider::MockRegionProvider() {} +MockRegionProvider::~MockRegionProvider() {} + +MockSigner::MockSigner() {} +MockSigner::~MockSigner() {} + +MockMetadataFetcher::MockMetadataFetcher() {} +MockMetadataFetcher::~MockMetadataFetcher() {} + +} // namespace Auth +} // namespace Aws +} // namespace Envoy \ No newline at end of file diff --git a/test/mocks/aws/mocks.h b/test/mocks/aws/mocks.h new file mode 100644 index 0000000000000..4a129c8edc843 --- /dev/null +++ b/test/mocks/aws/mocks.h @@ -0,0 +1,51 @@ +#pragma once + +#include "common/aws/credentials_provider.h" +#include "common/aws/metadata_fetcher.h" +#include "common/aws/region_provider.h" +#include "common/aws/signer.h" + +#include "gmock/gmock.h" + +namespace Envoy { +namespace Aws { +namespace Auth { + +class MockCredentialsProvider : public CredentialsProvider { +public: + MockCredentialsProvider(); + ~MockCredentialsProvider(); + + MOCK_METHOD0(getCredentials, Credentials()); +}; + +class MockRegionProvider : public RegionProvider { +public: + MockRegionProvider(); + ~MockRegionProvider(); + + MOCK_METHOD0(getRegion, absl::optional()); +}; + +class MockSigner : public Signer { +public: + MockSigner(); + ~MockSigner(); + + MOCK_CONST_METHOD1(sign, void(Http::Message&)); +}; + +class MockMetadataFetcher : public MetadataFetcher { +public: + MockMetadataFetcher(); + ~MockMetadataFetcher(); + + MOCK_CONST_METHOD4(getMetadata, + absl::optional(Event::Dispatcher&, const std::string&, + const std::string&, + const absl::optional&)); +}; + +} // namespace Auth +} // namespace Aws +} // namespace Envoy diff --git a/test/test_common/environment.cc b/test/test_common/environment.cc index 87cdadacb427c..12254f7ae4c81 100644 --- a/test/test_common/environment.cc +++ b/test/test_common/environment.cc @@ -323,4 +323,14 @@ void TestEnvironment::setEnvVar(const std::string& name, const std::string& valu #endif } +void TestEnvironment::unsetEnvVar(const std::string& name) { +#ifdef WIN32 + const int rc = ::_putenv_s(name.c_str(), ""); + ASSERT_EQ(0, rc); +#else + const int rc = ::unsetenv(name.c_str()); + ASSERT_EQ(rc, 0); +#endif +} + } // namespace Envoy diff --git a/test/test_common/environment.h b/test/test_common/environment.h index fe94b9ea870b0..ed0ffa438ce74 100644 --- a/test/test_common/environment.h +++ b/test/test_common/environment.h @@ -198,6 +198,11 @@ class TestEnvironment { * Set environment variable. Same args as setenv(2). */ static void setEnvVar(const std::string& name, const std::string& value, int overwrite); + + /** + * Removes environment variable. Same args as unsetenv(3). + */ + static void unsetEnvVar(const std::string& name); }; } // namespace Envoy