diff --git a/.circleci/config.yml b/.circleci/config.yml index ca346d0343373..5abbabeff266a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,6 +1,6 @@ references: envoy-build-image: &envoy-build-image - envoyproxy/envoy-build:7f7f5666c72e00ac7c1909b4fc9a2121d772c859 + envoyproxy/envoy-build:1ef23d481a4701ad4a414d1ef98036bd2ed322e7 version: 2 jobs: diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000..3f3fbbb8dc6c8 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +/docs/root/intro/version_history.rst merge=union diff --git a/.github/stale.yml b/.github/stale.yml index dc297cc57d898..31ea115101448 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -41,3 +41,4 @@ issues: Thank you for your contributions. exemptLabels: - help wanted + - no stalebot diff --git a/.gitignore b/.gitignore index db2be52a8b129..7f4c28b379488 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ SOURCE_VERSION .cache .vimrc .vscode +.vs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bb7c9efe3a1ee..41ee9af1e7402 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,7 +90,7 @@ maximize the chances of your PR being merged. * We expect that once a PR is opened, it will be actively worked on until it is merged or closed. We reserve the right to close PRs that are not making progress. This is generally defined as no changes for 7 days. Obviously PRs that are closed due to lack of activity can be reopened later. - Closing stale PRs helps us keep on top of all of the work currently in flight. + Closing stale PRs helps us to keep on top of all of the work currently in flight. * If a commit deprecates a feature, the commit message must mention what has been deprecated. Additionally, [DEPRECATED.md](DEPRECATED.md) must be updated as part of the commit. * Please consider joining the [envoy-dev](https://groups.google.com/forum/#!forum/envoy-dev) diff --git a/DEPRECATED.md b/DEPRECATED.md index 43456db808c50..955494c23b6e5 100644 --- a/DEPRECATED.md +++ b/DEPRECATED.md @@ -8,13 +8,24 @@ A logged warning is expected for each deprecated item that is in deprecation win ## Version 1.8.0 (pending) -* Use of the legacy +* Use of the v1 API is deprecated. See envoy-announce + [email](https://groups.google.com/forum/#!topic/envoy-announce/oPnYMZw8H4U). +* Use of the legacy [ratelimit.proto](https://github.com/envoyproxy/envoy/blob/b0a518d064c8255e0e20557a8f909b6ff457558f/source/common/ratelimit/ratelimit.proto) is deprecated, in favor of the proto defined in [date-plane-api](https://github.com/envoyproxy/envoy/blob/master/api/envoy/service/ratelimit/v2/rls.proto) Prior to 1.8.0, Envoy can use either proto to send client requests to a ratelimit server with the use of the `use_data_plane_proto` boolean flag in the [ratelimit configuration](https://github.com/envoyproxy/envoy/blob/master/api/envoy/config/ratelimit/v2/rls.proto). However, when using the deprecated client a warning is logged. +* Use of the --v2-config-only flag. +* Use of both `use_websocket` and `websocket_config` in + [route.proto](https://github.com/envoyproxy/envoy/blob/master/api/envoy/api/v2/route/route.proto) + is deprecated. Please use the new `upgrade_configs` in the + [HttpConnectionManager](https://github.com/envoyproxy/envoy/blob/master/api/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto) + instead. +* Setting hosts via `hosts` field in `Cluster` is deprecated. Use `load_assignment` instead. +* Use of `response_headers_to_*` and `request_headers_to_add` are deprecated at the `RouteAction` + level. Please use the configuration options at the `Route` level. ## Version 1.7.0 diff --git a/README.md b/README.md index abfda00435a8c..f01a5e5ce56b8 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ involved and how Envoy plays a role, read the CNCF * [Official documentation](https://www.envoyproxy.io/) * [FAQ](https://www.envoyproxy.io/docs/envoy/latest/faq/overview) -* [Unofficial Chinese documentation](https://github.com/lixiangyun/envoyproxy_doc_ZH_CN) +* [Unofficial Chinese documentation](https://github.com/servicemesher/envoy/) * Watch [a video overview of Envoy](https://www.youtube.com/watch?v=RVZX4CwKhGE) ([transcript](https://www.microservices.com/talks/lyfts-envoy-monolith-service-mesh-matt-klein/)) to find out more about the origin story and design philosophy of Envoy diff --git a/REPO_LAYOUT.md b/REPO_LAYOUT.md index bc2bb7660a810..cd87e015ac5b8 100644 --- a/REPO_LAYOUT.md +++ b/REPO_LAYOUT.md @@ -28,7 +28,7 @@ are: ## [source/](source/) * [common/](source/common/): Core Envoy code (not specific to extensions) that is also not - specific to a standalone server implementation. I.e., this is code that could be used if Envoy + specific to a standalone server implementation. I.e., this is the code that could be used if Envoy were eventually embedded as a library. * [docs/](source/docs/): Miscellaneous developer/design documentation that is not relevant for the public user documentation. diff --git a/SECURITY_RELEASE_PROCESS.md b/SECURITY_RELEASE_PROCESS.md index ca8d16e276781..ac49c34c04745 100644 --- a/SECURITY_RELEASE_PROCESS.md +++ b/SECURITY_RELEASE_PROCESS.md @@ -1,7 +1,7 @@ # Security Release Process Envoy is a large growing community of volunteers, users, and vendors. The Envoy community has -adopted this security disclosures and response policy to ensure we responsibly handle critical +adopted this security disclosure and response policy to ensure we responsibly handle critical issues. ## Product Security Team (PST) @@ -73,7 +73,7 @@ These steps should be completed within the 1-7 days of Disclosure. - The Fix Lead and the Fix Team will create a [CVSS](https://www.first.org/cvss/specification-document) using the [CVSS Calculator](https://www.first.org/cvss/calculator/3.0). The Fix Lead makes the final call on the - calculated CVSS; it is better to move quickly than make the CVSS perfect. + calculated CVSS; it is better to move quickly than making the CVSS perfect. - The Fix Team will notify the Fix Lead that work on the fix branch is complete once there are LGTMs on all commits in the private repo from one or more maintainers. @@ -160,7 +160,7 @@ said issue, they must agree to the same terms and only find out information on a In the unfortunate event you share the information beyond what is allowed by this policy, you _must_ urgently inform the envoy-security@googlegroups.com mailing list of exactly what information leaked -and to whom. A retrospective will take place after the leak so we can assess how to not make the +and to whom. A retrospective will take place after the leak so we can assess how to prevent making the same mistake in the future. If you continue to leak information and break the policy outlined here, you will be removed from the diff --git a/STYLE.md b/STYLE.md index 2b641607e23de..72a5411ae48cb 100644 --- a/STYLE.md +++ b/STYLE.md @@ -1,6 +1,6 @@ # C++ coding style -* The Envoy source code is formatted using clang-format. Thus all white space, etc. +* The Envoy source code is formatted using clang-format. Thus all white spaces, etc. issues are taken care of automatically. The Travis tests will automatically check the code format and fail. There are make targets that can both check the format (check_format) as well as fix the code format for you (fix_format). @@ -96,7 +96,7 @@ A few general notes on our error handling philosophy: silently be ignored and should crash the process either via the C++ allocation error exception, an explicit `RELEASE_ASSERT` following a third party library call, or an obvious crash on a subsequent line via null pointer dereference. This rule is again based on the philosophy that the engineering - costs of properly handling these cases is not worth it. Time is better spent designing proper system + costs of properly handling these cases are not worth it. Time is better spent designing proper system controls that shed load if resource usage becomes too high, etc. * The "less is more" error handling philosophy described in the previous two points is primarily based on the fact that restarts are designed to be fast, reliable and cheap. diff --git a/api/STYLE.md b/api/STYLE.md index d932c3a3b17cd..92592d4aac2e1 100644 --- a/api/STYLE.md +++ b/api/STYLE.md @@ -131,3 +131,6 @@ the build system to prevent circular dependency formation. Package group `//envoy/api/v2:friends` selects consumers of the core API package (services and configs) and is the default visibility for the core API packages. The default visibility for services and configs should be `//docs` (proto documentation tool). + +Extensions should use the regular hierarchy. For example, configuration for network filters belongs +in a package under `envoy.config.filter.network`. diff --git a/api/XDS_PROTOCOL.md b/api/XDS_PROTOCOL.md index 2021c68334bd5..67c7cc1a7bfe6 100644 --- a/api/XDS_PROTOCOL.md +++ b/api/XDS_PROTOCOL.md @@ -147,7 +147,7 @@ management server will provide the complete state of the LDS/CDS resources in each response. An absent `Listener` or `Cluster` will be deleted. For EDS/RDS, the management server does not need to supply every requested -resource and may also supply additional, unrequested resources, `resource_names` +resource and may also supply additional, unrequested resources. `resource_names` is only a hint. Envoy will silently ignore any superfluous resources. When a requested resource is missing in a RDS or EDS update, Envoy will retain the last known value for this resource. The management server may be able to infer all @@ -166,7 +166,7 @@ For EDS/RDS, Envoy may either generate a distinct stream for each resource of a given type (e.g. if each `ConfigSource` has its own distinct upstream cluster for a management server), or may combine together multiple resource requests for a given resource type when they are destined for the same management server. -This is left to implementation specifics, management servers should be capable +While this is left to implementation specifics, management servers should be capable of handling one or more `resource_names` for a given resource type in each request. Both sequence diagrams below are valid for fetching two EDS resources `{foo, bar}`: @@ -285,6 +285,51 @@ admin: ``` +### Incremental xDS + +Incremental xDS is a separate xDS endpoint available for ADS, CDS and RDS that +allows: + + * Incremental updates of the list of tracked resources by the xDS client. + This supports Envoy on-demand / lazily requesting additional resources. For + example, this may occur when a request corresponding to an unknown cluster + arrives. + * The xDS server can incremetally update the resources on the client. + This supports the goal of scalability of xDS resources. Rather than deliver + all 100k clusters when a single cluster is modified, the management server + only needs to deliver the single cluster that changed. + +An xDS incremental session is always in the context of a gRPC bidirectional +stream. This allows the xDS server to keep track of the state of xDS clients +connected to it. There is no REST version of Incremental xDS. + +In incremental xDS the nonce field is required and used to pair a +[`IncrementalDiscoveryResponse`](https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/discovery.proto#discoveryrequest) +to a [`IncrementalDiscoveryRequest`](https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/discovery.proto#discoveryrequest) +ACK or NACK. +Optionally, a response message level system_version_info is present for +debugging purposes only. + +`IncrementalDiscoveryRequest` can be sent in 3 situations: + 1. Initial message in a xDS bidirectional gRPC stream. + 2. As an ACK or NACK response to a previous `IncrementalDiscoveryResponse`. + In this case the `response_nonce` is set to the nonce value in the Response. + ACK or NACK is determined by the absence or presence of `error_detail`. + 3. Spontaneous `IncrementalDiscoveryRequest` from the client. + This can be done to dynamically add or remove elements from the tracked + `resource_names` set. In this case `response_nonce` must be omitted. + +In this first example the client connects and receives a first update that it +ACKs. The second update fails and the client NACKs the update. Later the xDS +client spontaneously requests the "wc" resource. + +![Incremental session example](diagrams/incremental.svg) + +On reconnect the xDS Incremental client may tell the server of its known resources +to avoid resending them over the network. + +![Incremental reconnect example](diagrams/incremental-reconnect.svg) + ## REST-JSON polling subscriptions Synchronous (long) polling via REST endpoints is also available for the xDS diff --git a/api/bazel/api_build_system.bzl b/api/bazel/api_build_system.bzl index 875df406bdc41..497d82c5ccc07 100644 --- a/api/bazel/api_build_system.bzl +++ b/api/bazel/api_build_system.bzl @@ -1,23 +1,22 @@ load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") load("@com_lyft_protoc_gen_validate//bazel:pgv_proto_library.bzl", "pgv_cc_proto_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library", "go_grpc_library") +load("@io_bazel_rules_go//proto:def.bzl", "go_grpc_library", "go_proto_library") load("@io_bazel_rules_go//go:def.bzl", "go_test") -_PY_SUFFIX="_py" -_CC_SUFFIX="_cc" -_GO_PROTO_SUFFIX="_go_proto" -_GO_GRPC_SUFFIX="_go_grpc" -_GO_IMPORTPATH_PREFIX="github.com/envoyproxy/data-plane-api/api/" +_PY_SUFFIX = "_py" +_CC_SUFFIX = "_cc" +_GO_PROTO_SUFFIX = "_go_proto" +_GO_GRPC_SUFFIX = "_go_grpc" +_GO_IMPORTPATH_PREFIX = "github.com/envoyproxy/data-plane-api/api/" def _Suffix(d, suffix): - return d + suffix + return d + suffix def _LibrarySuffix(library_name, suffix): - # Transform //a/b/c to //a/b/c:c in preparation for suffix operation below. - if library_name.startswith("//") and ":" not in library_name: - library_name += ":" + Label(library_name).name - return _Suffix(library_name, suffix) - + # Transform //a/b/c to //a/b/c:c in preparation for suffix operation below. + if library_name.startswith("//") and ":" not in library_name: + library_name += ":" + Label(library_name).name + return _Suffix(library_name, suffix) # TODO(htuch): has_services is currently ignored but will in future support # gRPC stub generation. @@ -32,6 +31,7 @@ def api_py_proto_library(name, srcs = [], deps = [], has_services = 0): protoc = "@com_google_protobuf//:protoc", deps = [_LibrarySuffix(d, _PY_SUFFIX) for d in deps] + [ "@com_lyft_protoc_gen_validate//validate:validate_py", + "@googleapis//:api_httpbody_protos_py", "@googleapis//:http_api_protos_py", "@googleapis//:rpc_status_protos_py", "@com_github_gogo_protobuf//:gogo_proto_py", @@ -54,7 +54,7 @@ def api_go_proto_library(name, proto, deps = []): "@com_github_golang_protobuf//ptypes/any:go_default_library", "@com_lyft_protoc_gen_validate//validate:go_default_library", "@googleapis//:rpc_status_go_proto", - ] + ], ) def api_go_grpc_library(name, proto, deps = []): @@ -71,9 +71,19 @@ def api_go_grpc_library(name, proto, deps = []): "@com_github_golang_protobuf//ptypes/any:go_default_library", "@com_lyft_protoc_gen_validate//validate:go_default_library", "@googleapis//:http_api_go_proto", - ] + ], ) +# This is api_proto_library plus some logic internal to //envoy/api. +def api_proto_library_internal(visibility = ["//visibility:private"], **kwargs): + # //envoy/docs/build.sh needs visibility in order to generate documents. + if visibility == ["//visibility:private"]: + visibility = ["//docs"] + elif visibility != ["//visibility:public"]: + visibility = visibility + ["//docs"] + + api_proto_library(visibility = visibility, **kwargs) + # TODO(htuch): has_services is currently ignored but will in future support # gRPC stub generation. # TODO(htuch): Automatically generate go_proto_library and go_grpc_library @@ -86,11 +96,6 @@ def api_proto_library(name, visibility = ["//visibility:private"], srcs = [], de # it can play well with the PGV plugin and (2) other language support that # can make use of native proto_library. - if visibility == ["//visibility:private"]: - visibility = ["//docs"] - elif visibility != ["//visibility:public"]: - visibility = visibility + ["//docs"] - native.proto_library( name = name, srcs = srcs, @@ -102,6 +107,7 @@ def api_proto_library(name, visibility = ["//visibility:private"], srcs = [], de "@com_google_protobuf//:struct_proto", "@com_google_protobuf//:timestamp_proto", "@com_google_protobuf//:wrappers_proto", + "@googleapis//:api_httpbody_protos_proto", "@googleapis//:http_api_protos_proto", "@googleapis//:rpc_status_protos_lib", "@com_github_gogo_protobuf//:gogo_proto", @@ -109,6 +115,7 @@ def api_proto_library(name, visibility = ["//visibility:private"], srcs = [], de ], visibility = visibility, ) + # Under the hood, this is just an extension of the Protobuf library's # bespoke cc_proto_library. It doesn't consume proto_library as a proto # provider. Hopefully one day we can move to a model where this target and @@ -126,7 +133,7 @@ def api_proto_library(name, visibility = ["//visibility:private"], srcs = [], de visibility = ["//visibility:public"], ) if (require_py == 1): - api_py_proto_library(name, srcs, deps, has_services) + api_py_proto_library(name, srcs, deps, has_services) def api_cc_test(name, srcs, proto_deps): native.cc_test( diff --git a/api/bazel/repositories.bzl b/api/bazel/repositories.bzl index 840b2c6625c42..2e497d712fc87 100644 --- a/api/bazel/repositories.bzl +++ b/api/bazel/repositories.bzl @@ -1,9 +1,9 @@ -GOOGLEAPIS_SHA = "d642131a6e6582fc226caf9893cb7fe7885b3411" # May 23, 2018 -GOGOPROTO_SHA = "1adfc126b41513cc696b209667c8656ea7aac67c" # v1.0.0 -PROMETHEUS_SHA = "99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c" # Nov 17, 2017 -OPENCENSUS_SHA = "ab82e5fdec8267dc2a726544b10af97675970847" # May 23, 2018 +GOOGLEAPIS_SHA = "d642131a6e6582fc226caf9893cb7fe7885b3411" # May 23, 2018 +GOGOPROTO_SHA = "1adfc126b41513cc696b209667c8656ea7aac67c" # v1.0.0 +PROMETHEUS_SHA = "99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c" # Nov 17, 2017 +OPENCENSUS_SHA = "ab82e5fdec8267dc2a726544b10af97675970847" # May 23, 2018 -PGV_GIT_SHA = "345b6b478ef955ad31382955d21fb504e95f38c7" +PGV_GIT_SHA = "f9d2b11e44149635b23a002693b76512b01ae515" load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") @@ -17,10 +17,59 @@ def api_dependencies(): name = "googleapis", strip_prefix = "googleapis-" + GOOGLEAPIS_SHA, url = "https://github.com/googleapis/googleapis/archive/" + GOOGLEAPIS_SHA + ".tar.gz", + # TODO(dio): Consider writing a Skylark macro for importing Google API proto. build_file_content = """ load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +filegroup( + name = "api_httpbody_protos_src", + srcs = [ + "google/api/httpbody.proto", + ], + visibility = ["//visibility:public"], +) + +proto_library( + name = "api_httpbody_protos_proto", + srcs = [":api_httpbody_protos_src"], + deps = ["@com_google_protobuf//:descriptor_proto"], + visibility = ["//visibility:public"], +) + +cc_proto_library( + name = "api_httpbody_protos", + srcs = [ + "google/api/httpbody.proto", + ], + default_runtime = "@com_google_protobuf//:protobuf", + protoc = "@com_google_protobuf//:protoc", + deps = ["@com_google_protobuf//:cc_wkt_protos"], + visibility = ["//visibility:public"], +) + +py_proto_library( + name = "api_httpbody_protos_py", + srcs = [ + "google/api/httpbody.proto", + ], + include = ".", + default_runtime = "@com_google_protobuf//:protobuf_python", + protoc = "@com_google_protobuf//:protoc", + visibility = ["//visibility:public"], + deps = ["@com_google_protobuf//:protobuf_python"], +) + +go_proto_library( + name = "api_httpbody_go_proto", + importpath = "google.golang.org/genproto/googleapis/api/httpbody", + proto = ":api_httpbody_protos_proto", + visibility = ["//visibility:public"], + deps = [ + ":descriptor_go_proto", + ], +) + filegroup( name = "http_api_protos_src", srcs = [ @@ -28,7 +77,7 @@ filegroup( "google/api/http.proto", ], visibility = ["//visibility:public"], - ) +) go_proto_library( name = "descriptor_go_proto", @@ -93,6 +142,7 @@ proto_library( deps = ["@com_google_protobuf//:any_proto"], visibility = ["//visibility:public"], ) + cc_proto_library( name = "rpc_status_protos", srcs = ["google/rpc/status.proto"], @@ -189,7 +239,7 @@ py_proto_library( ) native.new_http_archive( - name = "promotheus_metrics_model", + name = "prometheus_metrics_model", strip_prefix = "client_model-" + PROMETHEUS_SHA, url = "https://github.com/prometheus/client_model/archive/" + PROMETHEUS_SHA + ".tar.gz", build_file_content = """ diff --git a/api/diagrams/incremental-reconnect.svg b/api/diagrams/incremental-reconnect.svg new file mode 100644 index 0000000000000..ef8472340ab5d --- /dev/null +++ b/api/diagrams/incremental-reconnect.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0EnvoyEnvoyManagement ServerManagement Server{T=CDS}{R={(foo, v0), (bar, v0),nonce=n0, T=CDS}{response_nonce=n0, T=CDS} (ACK)Session is interrupted here. Reconnect.{initial_resource_versions={(foo, v0), (bar, v0)}, T=CDS} \ No newline at end of file diff --git a/api/diagrams/incremental.svg b/api/diagrams/incremental.svg new file mode 100644 index 0000000000000..e0e93b8a56725 --- /dev/null +++ b/api/diagrams/incremental.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0EnvoyEnvoyManagement ServerManagement Server{T=CDS}{R={(foo, v0)},nonce=n0, T=CDS}{response_nonce=n0, T=CDS} (ACK)spontaneous update of server{R={(bar, v0)},nonce=n1, T=CDS}{response_nonce=n1, error_detail="could not apply", T=CDS} (NACK)spontaneous resource list update{resource_list_subscribe=wc, T=CDS}{R={(wc, v0)},nonce=n2, T=CDS}{response_nonce=n2, T=CDS} (ACK) \ No newline at end of file diff --git a/api/docs/BUILD b/api/docs/BUILD index ffd68728a1e8a..54a7b87eea4f5 100644 --- a/api/docs/BUILD +++ b/api/docs/BUILD @@ -12,6 +12,7 @@ package_group( proto_library( name = "protos", deps = [ + "//envoy/admin/v2alpha:clusters", "//envoy/admin/v2alpha:config_dump", "//envoy/api/v2:cds", "//envoy/api/v2:discovery", @@ -57,6 +58,7 @@ proto_library( "//envoy/config/trace/v2:trace", "//envoy/config/transport_socket/capture/v2alpha:capture", "//envoy/data/accesslog/v2:accesslog", + "//envoy/data/core/v2alpha:health_check_event", "//envoy/data/tap/v2alpha:capture", "//envoy/service/accesslog/v2:als", "//envoy/service/auth/v2alpha:attribute_context", @@ -66,5 +68,8 @@ proto_library( "//envoy/service/metrics/v2:metrics_service", "//envoy/type:percent", "//envoy/type:range", + "//envoy/type/matcher:metadata", + "//envoy/type/matcher:number", + "//envoy/type/matcher:string", ], ) diff --git a/api/envoy/admin/v2alpha/BUILD b/api/envoy/admin/v2alpha/BUILD index 9d2875da2a443..98696461bd1d1 100644 --- a/api/envoy/admin/v2alpha/BUILD +++ b/api/envoy/admin/v2alpha/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "config_dump", srcs = ["config_dump.proto"], visibility = ["//visibility:public"], @@ -13,3 +13,21 @@ api_proto_library( "//envoy/config/bootstrap/v2:bootstrap", ], ) + +api_proto_library_internal( + name = "clusters", + srcs = ["clusters.proto"], + visibility = ["//visibility:public"], + deps = [ + ":metrics", + "//envoy/api/v2/core:address", + "//envoy/api/v2/core:health_check", + "//envoy/type:percent", + ], +) + +api_proto_library_internal( + name = "metrics", + srcs = ["metrics.proto"], + visibility = ["//visibility:public"], +) diff --git a/api/envoy/admin/v2alpha/clusters.proto b/api/envoy/admin/v2alpha/clusters.proto new file mode 100644 index 0000000000000..fc8d91eac3075 --- /dev/null +++ b/api/envoy/admin/v2alpha/clusters.proto @@ -0,0 +1,75 @@ +syntax = "proto3"; + +package envoy.admin.v2alpha; + +import "envoy/admin/v2alpha/metrics.proto"; +import "envoy/api/v2/core/address.proto"; +import "envoy/api/v2/core/health_check.proto"; +import "envoy/type/percent.proto"; + +// [#protodoc-title: Clusters] + +// Admin endpoint uses this wrapper for `/clusters` to display cluster status information. +// See :ref:`/clusters ` for more information. +message Clusters { + // Mapping from cluster name to each cluster's status. + repeated ClusterStatus cluster_statuses = 1; +} + +// Details an individual cluster's current status. +message ClusterStatus { + // Name of the cluster. + string name = 1; + + // Denotes whether this cluster was added via API or configured statically. + bool added_via_api = 2; + + // The success rate threshold used in the last interval. The threshold is used to eject hosts + // based on their success rate. See + // :ref:`Cluster outlier detection ` statistics + // + // Note: this field may be omitted in any of the three following cases: + // + // 1. There were not enough hosts with enough request volume to proceed with success rate based + // outlier ejection. + // 2. The threshold is computed to be < 0 because a negative value implies that there was no + // threshold for that interval. + // 3. Outlier detection is not enabled for this cluster. + envoy.type.Percent success_rate_ejection_threshold = 3; + + // Mapping from host address to the host's current status. + repeated HostStatus host_statuses = 4; +} + +// Current state of a particular host. +message HostStatus { + // Address of this host. + envoy.api.v2.core.Address address = 1; + + // Mapping from the name of the statistic to the current value. + map stats = 2; + + // The host's current health status. + HostHealthStatus health_status = 3; + + // Request success rate for this host over the last calculated interval. + // + // Note: the message will not be present if host did not have enough request volume to calculate + // success rate or the cluster did not have enough hosts to run through success rate outlier + // ejection. + envoy.type.Percent success_rate = 4; +} + +// Health status for a host. +message HostHealthStatus { + // The host is currently failing active health checks. + bool failed_active_health_check = 1; + + // The host is currently considered an outlier and has been ejected. + bool failed_outlier_check = 2; + + // Health status as reported by EDS. Note: only HEALTHY and UNHEALTHY are currently supported + // here. + // TODO(mrice32): pipe through remaining EDS health status possibilities. + envoy.api.v2.core.HealthStatus eds_health_status = 3; +} diff --git a/api/envoy/admin/v2alpha/metrics.proto b/api/envoy/admin/v2alpha/metrics.proto new file mode 100644 index 0000000000000..93927157c1ef6 --- /dev/null +++ b/api/envoy/admin/v2alpha/metrics.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package envoy.admin.v2alpha; + +// [#protodoc-title: Metrics] + +// Proto representation of an Envoy Counter or Gauge value. +message SimpleMetric { + enum Type { + COUNTER = 0; + GAUGE = 1; + } + + // Type of metric represented. + Type type = 1; + + // Current metric value. + uint64 value = 2; +} diff --git a/api/envoy/api/v2/BUILD b/api/envoy/api/v2/BUILD index 3e557a10239ed..261d140819985 100644 --- a/api/envoy/api/v2/BUILD +++ b/api/envoy/api/v2/BUILD @@ -1,4 +1,4 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 @@ -16,7 +16,7 @@ package_group( ], ) -api_proto_library( +api_proto_library_internal( name = "discovery", srcs = ["discovery.proto"], visibility = [":friends"], @@ -29,7 +29,7 @@ api_go_proto_library( deps = ["//envoy/api/v2/core:base_go_proto"], ) -api_proto_library( +api_proto_library_internal( name = "eds", srcs = ["eds.proto"], has_services = 1, @@ -40,6 +40,7 @@ api_proto_library( "//envoy/api/v2/core:base", "//envoy/api/v2/core:health_check", "//envoy/api/v2/endpoint", + "//envoy/type:percent", ], ) @@ -52,10 +53,11 @@ api_go_grpc_library( "//envoy/api/v2/core:base_go_proto", "//envoy/api/v2/core:health_check_go_proto", "//envoy/api/v2/endpoint:endpoint_go_proto", + "//envoy/type:percent_go_proto", ], ) -api_proto_library( +api_proto_library_internal( name = "cds", srcs = ["cds.proto"], has_services = 1, @@ -95,7 +97,7 @@ api_go_grpc_library( ], ) -api_proto_library( +api_proto_library_internal( name = "lds", srcs = ["lds.proto"], has_services = 1, @@ -119,7 +121,7 @@ api_go_grpc_library( ], ) -api_proto_library( +api_proto_library_internal( name = "rds", srcs = ["rds.proto"], has_services = 1, diff --git a/api/envoy/api/v2/auth/BUILD b/api/envoy/api/v2/auth/BUILD index c3ea89de1ff1d..55f522c0085a9 100644 --- a/api/envoy/api/v2/auth/BUILD +++ b/api/envoy/api/v2/auth/BUILD @@ -1,4 +1,4 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 @@ -15,7 +15,7 @@ package_group( ], ) -api_proto_library( +api_proto_library_internal( name = "cert", srcs = ["cert.proto"], visibility = [":friends"], diff --git a/api/envoy/api/v2/cds.proto b/api/envoy/api/v2/cds.proto index 8359cd51964b0..e23cda0c8f8f1 100644 --- a/api/envoy/api/v2/cds.proto +++ b/api/envoy/api/v2/cds.proto @@ -31,6 +31,10 @@ service ClusterDiscoveryService { rpc StreamClusters(stream DiscoveryRequest) returns (stream DiscoveryResponse) { } + rpc IncrementalClusters(stream IncrementalDiscoveryRequest) + returns (stream IncrementalDiscoveryResponse) { + } + rpc FetchClusters(DiscoveryRequest) returns (DiscoveryResponse) { option (google.api.http) = { post: "/v2/discovery:clusters" @@ -156,7 +160,13 @@ message Cluster { // :ref:`STRICT_DNS` // or :ref:`LOGICAL_DNS`, // then hosts is required. - repeated core.Address hosts = 7; + // + // .. attention:: + // + // **This field is deprecated**. Set the + // :ref:`load_assignment` field instead. + // + repeated core.Address hosts = 7 [deprecated = true]; // Setting this is required for specifying members of // :ref:`STATIC`, @@ -172,7 +182,6 @@ message Cluster { // :ref:`endpoint assignments`. // Setting this overrides :ref:`hosts` values. // - // [#not-implemented-hide:] ClusterLoadAssignment load_assignment = 33; // Optional :ref:`active health checking ` @@ -341,6 +350,18 @@ message Cluster { // weighted cluster contains the same keys and values as the subset's // metadata. The same host may appear in multiple subsets. repeated LbSubsetSelector subset_selectors = 3; + + // If true, routing to subsets will take into account the localities and locality weights of the + // endpoints when making the routing decision. + // + // There are some potential pitfalls associated with enabling this feature, as the resulting + // traffic split after applying both a subset match and locality weights might be undesirable. + // + // Consider for example a situation in which you have 50/50 split across two localities X/Y + // which have 100 hosts each without subsetting. If the subset LB results in X having only 1 + // host selected but Y having 100, then a lot more load is being dumped on the single host in X + // than originally anticipated in the load balancing assignment delivered via EDS. + bool locality_weight_aware = 4; } // Configuration for load balancing subsetting. @@ -416,6 +437,17 @@ message Cluster { ZoneAwareLbConfig zone_aware_lb_config = 2; LocalityWeightedLbConfig locality_weighted_lb_config = 3; } + // If set, all health check/weight/metadata updates that happen within this duration will be + // merged and delivered in one shot when the duration expires. The start of the duration is when + // the first update happens. This is useful for big clusters, with potentially noisy deploys + // that might trigger excessive CPU usage due to a constant stream of healthcheck state changes + // or metadata updates. By default, this is not configured and updates apply immediately. Also, + // the first set of updates to be seen apply immediately as well (e.g.: a new cluster). + // + // Note: merging does not apply to cluster membership changes (e.g.: adds/removes); this is + // because merging those updates isn't currently safe. See + // https://github.com/envoyproxy/envoy/pull/3941. + google.protobuf.Duration update_merge_window = 4; } // Common configuration for all load balancer implementations. diff --git a/api/envoy/api/v2/cluster/BUILD b/api/envoy/api/v2/cluster/BUILD index 16e759069359e..a3b091dea5f28 100644 --- a/api/envoy/api/v2/cluster/BUILD +++ b/api/envoy/api/v2/cluster/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "circuit_breaker", srcs = ["circuit_breaker.proto"], visibility = [ @@ -21,7 +21,7 @@ api_go_proto_library( ], ) -api_proto_library( +api_proto_library_internal( name = "outlier_detection", srcs = ["outlier_detection.proto"], visibility = [ diff --git a/api/envoy/api/v2/cluster/circuit_breaker.proto b/api/envoy/api/v2/cluster/circuit_breaker.proto index 19e378d779de2..1d574311d2009 100644 --- a/api/envoy/api/v2/cluster/circuit_breaker.proto +++ b/api/envoy/api/v2/cluster/circuit_breaker.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package envoy.api.v2.cluster; option go_package = "cluster"; +option csharp_namespace = "Envoy.Api.V2.ClusterNS"; import "envoy/api/v2/core/base.proto"; diff --git a/api/envoy/api/v2/cluster/outlier_detection.proto b/api/envoy/api/v2/cluster/outlier_detection.proto index 8fc873cbd08df..3ef961928d5b1 100644 --- a/api/envoy/api/v2/cluster/outlier_detection.proto +++ b/api/envoy/api/v2/cluster/outlier_detection.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package envoy.api.v2.cluster; +option csharp_namespace = "Envoy.Api.V2.ClusterNS"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; diff --git a/api/envoy/api/v2/core/BUILD b/api/envoy/api/v2/core/BUILD index 666315758b826..71a8d33f59d35 100644 --- a/api/envoy/api/v2/core/BUILD +++ b/api/envoy/api/v2/core/BUILD @@ -1,4 +1,4 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 @@ -16,7 +16,7 @@ package_group( ], ) -api_proto_library( +api_proto_library_internal( name = "address", srcs = ["address.proto"], visibility = [ @@ -31,7 +31,7 @@ api_go_proto_library( deps = [":base_go_proto"], ) -api_proto_library( +api_proto_library_internal( name = "base", srcs = ["base.proto"], visibility = [ @@ -44,7 +44,7 @@ api_go_proto_library( proto = ":base", ) -api_proto_library( +api_proto_library_internal( name = "health_check", srcs = ["health_check.proto"], visibility = [ @@ -59,7 +59,7 @@ api_go_proto_library( deps = [":base_go_proto"], ) -api_proto_library( +api_proto_library_internal( name = "config_source", srcs = ["config_source.proto"], visibility = [ @@ -80,7 +80,12 @@ api_go_proto_library( ], ) -api_proto_library( +api_go_proto_library( + name = "http_uri", + proto = ":http_uri", +) + +api_proto_library_internal( name = "http_uri", srcs = ["http_uri.proto"], visibility = [ @@ -88,12 +93,7 @@ api_proto_library( ], ) -api_go_proto_library( - name = "http_uri", - proto = ":http_uri", -) - -api_proto_library( +api_proto_library_internal( name = "grpc_service", srcs = ["grpc_service.proto"], visibility = [ @@ -108,7 +108,7 @@ api_go_proto_library( deps = [":base_go_proto"], ) -api_proto_library( +api_proto_library_internal( name = "protocol", srcs = ["protocol.proto"], visibility = [ diff --git a/api/envoy/api/v2/core/base.proto b/api/envoy/api/v2/core/base.proto index ffdd03fdfe595..1e86c529c46e2 100644 --- a/api/envoy/api/v2/core/base.proto +++ b/api/envoy/api/v2/core/base.proto @@ -133,7 +133,7 @@ enum RequestMethod { // Header name/value pair. message HeaderValue { // Header name. - string key = 1; + string key = 1 [(validate.rules).string.min_bytes = 1]; // Header value. // diff --git a/api/envoy/api/v2/core/config_source.proto b/api/envoy/api/v2/core/config_source.proto index 17bdbbeb28d9e..1ebb265da2ae4 100644 --- a/api/envoy/api/v2/core/config_source.proto +++ b/api/envoy/api/v2/core/config_source.proto @@ -28,7 +28,7 @@ message ApiConfigSource { GRPC = 2; } ApiType api_type = 1 [(validate.rules).enum.defined_only = true]; - // Multiple cluster names may be provided for REST_LEGACY/REST. If > 1 + // Cluster names should be used only with REST_LEGACY/REST. If > 1 // cluster is defined, clusters will be cycled through if any kind of failure // occurs. // @@ -40,11 +40,6 @@ message ApiConfigSource { // Multiple gRPC services be provided for GRPC. If > 1 cluster is defined, // services will be cycled through if any kind of failure occurs. - // - // .. note:: - // - // If a gRPC service points to a ``cluster_name``, it must be statically - // defined and its type must not be ``EDS``. repeated GrpcService grpc_services = 4; // For REST APIs, the delay between successive polls. diff --git a/api/envoy/api/v2/core/health_check.proto b/api/envoy/api/v2/core/health_check.proto index 55df7947e9f5f..ad35ca61536a0 100644 --- a/api/envoy/api/v2/core/health_check.proto +++ b/api/envoy/api/v2/core/health_check.proto @@ -21,15 +21,29 @@ option (gogoproto.equal_all) = true; message HealthCheck { // The time to wait for a health check response. If the timeout is reached the // health check attempt will be considered a failure. - google.protobuf.Duration timeout = 1 [(validate.rules).duration.required = true]; + google.protobuf.Duration timeout = 1 [ + (validate.rules).duration = {required: true, gt: {seconds: 0}}, + (gogoproto.stdduration) = true + ]; // The interval between health checks. - google.protobuf.Duration interval = 2 [(validate.rules).duration.required = true]; + google.protobuf.Duration interval = 2 [ + (validate.rules).duration = {required: true, gt: {seconds: 0}}, + (gogoproto.stdduration) = true + ]; // An optional jitter amount in millseconds. If specified, during every - // internal Envoy will add 0 to interval_jitter to the wait time. + // interval Envoy will add 0 to interval_jitter to the wait time. google.protobuf.Duration interval_jitter = 3; + // An optional jitter amount as a percentage of interval_ms. If specified, + // during every interval Envoy will add 0 to interval_ms * + // interval_jitter_percent / 100 to the wait time. + // + // If interval_jitter_ms and interval_jitter_percent are both set, both of + // them will be used to increase the wait time. + uint32 interval_jitter_percent = 18; + // The number of unhealthy health checks required before a host is marked // unhealthy. Note that for *http* health checking if a host responds with 503 // this threshold is ignored and the host is considered unhealthy immediately. @@ -81,7 +95,9 @@ message HealthCheck { string service_name = 5; // Specifies a list of HTTP headers that should be added to each request that is sent to the - // health checked cluster. + // health checked cluster. For more information, including details on header value syntax, see + // the documentation on :ref:`custom request headers + // `. repeated core.HeaderValueOption request_headers_to_add = 6; // If set, health checks will be made using http/2. @@ -138,9 +154,6 @@ message HealthCheck { // TCP health check. TcpHealthCheck tcp_health_check = 9; - // Redis health check. - RedisHealthCheck redis_health_check = 10; - // gRPC health check. GrpcHealthCheck grpc_health_check = 11; @@ -148,6 +161,10 @@ message HealthCheck { CustomHealthCheck custom_health_check = 13; } + reserved 10; // redis_health_check is deprecated by :ref:`custom_health_check + // ` + reserved "redis_health_check"; + // The "no traffic interval" is a special health check interval that is used when a cluster has // never had traffic routed to it. This lower interval allows cluster information to be kept up to // date, without sending a potentially large amount of active health checking traffic for no @@ -179,6 +196,10 @@ message HealthCheck { // // The default value for "healthy edge interval" is the same as the default interval. google.protobuf.Duration healthy_edge_interval = 16; + + // Specifies the path to the :ref:`health check event log `. + // If empty, no event log will be written. + string event_log_path = 17; } // Endpoint health status. diff --git a/api/envoy/api/v2/discovery.proto b/api/envoy/api/v2/discovery.proto index 74e7c5a2be965..f3ab1913d9146 100644 --- a/api/envoy/api/v2/discovery.proto +++ b/api/envoy/api/v2/discovery.proto @@ -93,3 +93,103 @@ message DiscoveryResponse { // required for non-stream based xDS implementations. string nonce = 5; } + +// IncrementalDiscoveryRequest and IncrementalDiscoveryResponse are used in a +// new gRPC endpoint for Incremental xDS. The feature is not supported for REST +// management servers. +// +// With Incremental xDS, the IncrementalDiscoveryResponses do not need to +// include a full snapshot of the tracked resources. Instead +// IncrementalDiscoveryResponses are a diff to the state of a xDS client. +// In Incremental XDS there are per resource versions which allows to track +// state at the resource granularity. +// An xDS Incremental session is always in the context of a gRPC bidirectional +// stream. This allows the xDS server to keep track of the state of xDS clients +// connected to it. +// +// In Incremental xDS the nonce field is required and used to pair +// IncrementalDiscoveryResponse to a IncrementalDiscoveryRequest ACK or NACK. +// Optionaly, a response message level system_version_info is present for +// debugging purposes only. +// +// IncrementalDiscoveryRequest can be sent in 3 situations: +// 1. Initial message in a xDS bidirectional gRPC stream. +// 2. As a ACK or NACK response to a previous IncrementalDiscoveryResponse. +// In this case the response_nonce is set to the nonce value in the Response. +// ACK or NACK is determined by the absence or presence of error_detail. +// 3. Spontaneous IncrementalDiscoveryRequest from the client. +// This can be done to dynamically add or remove elements from the tracked +// resource_names set. In this case response_nonce must be omitted. +message IncrementalDiscoveryRequest { + // The node making the request. + core.Node node = 1; + + // Type of the resource that is being requested, e.g. + // "type.googleapis.com/envoy.api.v2.ClusterLoadAssignment". This is implicit + // in requests made via singleton xDS APIs such as CDS, LDS, etc. but is + // required for ADS. + string type_url = 2; + + // IncrementalDiscoveryRequests allow the client to add or remove individual + // resources to the set of tracked resources in the context of a stream. + // All resource names in the resource_names_subscribe list are added to the + // set of tracked resources and all resource names in the resource_names_unsubscribe + // list are removed from the set of tracked resources. + // Unlike in non incremental xDS, an empty resource_names_subscribe or + // resource_names_unsubscribe list simply means that no resources are to be + // added or removed to the resource list. + // The xDS server must send updates for all tracked resources but can also + // send updates for resources the client has not subscribed to. This behavior + // is similar to non incremental xDS. + // These two fields can be set for all types of IncrementalDiscoveryRequests + // (initial, ACK/NACK or spontaneous). + // + // A list of Resource names to add to the list of tracked resources. + repeated string resource_names_subscribe = 3; + + // A list of Resource names to remove from the list of tracked resources. + repeated string resource_names_unsubscribe = 4; + + // This map must be populated when the IncrementalDiscoveryRequest is the + // first in a stream. The keys are the resources names of the xDS resources + // known to the xDS client. The values in the map are the associated resource + // level version info. + map initial_resource_versions = 5; + + // When the IncrementalDiscoveryRequest is a ACK or NACK message in response + // to a previous IncrementalDiscoveryResponse, the response_nonce must be the + // nonce in the IncrementalDiscoveryResponse. + // Otherwise response_nonce must be omitted. + string response_nonce = 6; + + // This is populated when the previous :ref:`DiscoveryResponse ` + // failed to update configuration. The *message* field in *error_details* + // provides the Envoy internal exception related to the failure. + google.rpc.Status error_detail = 7; +} + +message IncrementalDiscoveryResponse { + // The version of the response data (used for debugging). + string system_version_info = 1; + + // The response resources. These are typed resources that match the type url + // in the IncrementalDiscoveryRequest. + repeated Resource resources = 2 [(gogoproto.nullable) = false]; + + // Resources names of resources that have be deleted and to be removed from the xDS Client. + // Removed resources for missing resources can be ignored. + repeated string removed_resources = 6; + + // The nonce provides a way for IncrementalDiscoveryRequests to uniquely + // reference a IncrementalDiscoveryResponse. The nonce is required. + string nonce = 5; +} + +message Resource { + // The resource level version. It allows xDS to track the state of individual + // resources. + string version = 1; + + // The resource being tracked. + google.protobuf.Any resource = 2; +} diff --git a/api/envoy/api/v2/eds.proto b/api/envoy/api/v2/eds.proto index 0c63fbaa58484..44a1fe0a97496 100644 --- a/api/envoy/api/v2/eds.proto +++ b/api/envoy/api/v2/eds.proto @@ -6,6 +6,7 @@ option java_generic_services = true; import "envoy/api/v2/discovery.proto"; import "envoy/api/v2/endpoint/endpoint.proto"; +import "envoy/type/percent.proto"; import "google/api/annotations.proto"; @@ -50,12 +51,35 @@ message ClusterLoadAssignment { // Load balancing policy settings. message Policy { - // Percentage of traffic (0-100) that should be dropped. This - // action allows protection of upstream hosts should they unable to - // recover from an outage or should they be unable to autoscale and hence - // overall incoming traffic volume need to be trimmed to protect them. - // [#v2-api-diff: This is known as maintenance mode in v1.] - double drop_overload = 1 [(validate.rules).double = {gte: 0, lte: 100}]; + reserved 1; + + message DropOverload { + // Identifier for the policy specifying the drop. + string category = 1 [(validate.rules).string.min_bytes = 1]; + + // Percentage of traffic that should be dropped for the category. + envoy.type.FractionalPercent drop_percentage = 2; + } + // Action to trim the overall incoming traffic to protect the upstream + // hosts. This action allows protection in case the hosts are unable to + // recover from an outage, or unable to autoscale or unable to handle + // incoming traffic volume for any reason. + // + // At the client each category is applied one after the other to generate + // the 'actual' drop percentage on all outgoing traffic. For example: + // + // .. code-block:: json + // + // { "drop_overloads": [ + // { "category": "throttle", "drop_percentage": 60 } + // { "category": "lb", "drop_percentage": 50 } + // ]} + // + // The actual drop percentages applied to the traffic at the clients will be + // "throttle"_drop = 60% + // "lb"_drop = 20% // 50% of the remaining 'actual' load, which is 40%. + // actual_outgoing_load = 20% // remaining after applying all categories. + repeated DropOverload drop_overloads = 2; } // Load balancing policy settings. diff --git a/api/envoy/api/v2/endpoint/BUILD b/api/envoy/api/v2/endpoint/BUILD index 14808743df530..87884fe33342b 100644 --- a/api/envoy/api/v2/endpoint/BUILD +++ b/api/envoy/api/v2/endpoint/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "endpoint", srcs = ["endpoint.proto"], visibility = ["//envoy/api/v2:friends"], @@ -29,7 +29,7 @@ api_go_proto_library( ], ) -api_proto_library( +api_proto_library_internal( name = "load_report", srcs = ["load_report.proto"], visibility = ["//envoy/api/v2:friends"], diff --git a/api/envoy/api/v2/endpoint/endpoint.proto b/api/envoy/api/v2/endpoint/endpoint.proto index 6f4cad1ce66e9..c1983f175440e 100644 --- a/api/envoy/api/v2/endpoint/endpoint.proto +++ b/api/envoy/api/v2/endpoint/endpoint.proto @@ -29,7 +29,7 @@ message Endpoint { // and will be resolved via DNS. core.Address address = 1; - // [#not-implemented-hide:] The optional health check configuration. + // The optional health check configuration. message HealthCheckConfig { // Optional alternative health check port value. // @@ -37,11 +37,11 @@ message Endpoint { // as the host's serving address port. This provides an alternative health // check port. Setting this with a non-zero value allows an upstream host // to have different health check address port. - uint32 port_value = 1; + uint32 port_value = 1 [(validate.rules).uint32.lte = 65535]; } - // [#not-implemented-hide:] The optional health check configuration is used as - // configuration for the health checker to contact the health checked host. + // The optional health check configuration is used as configuration for the + // health checker to contact the health checked host. // // .. attention:: // @@ -123,5 +123,5 @@ message LocalityLbEndpoints { // next highest priority group. // // Priorities should range from 0 (highest) to N (lowest) without skipping. - uint32 priority = 5; + uint32 priority = 5 [(validate.rules).uint32 = {lte: 128}]; } diff --git a/api/envoy/api/v2/endpoint/load_report.proto b/api/envoy/api/v2/endpoint/load_report.proto index b61a0025a7a05..45ca3a168bbaf 100644 --- a/api/envoy/api/v2/endpoint/load_report.proto +++ b/api/envoy/api/v2/endpoint/load_report.proto @@ -4,6 +4,8 @@ package envoy.api.v2.endpoint; import "envoy/api/v2/core/base.proto"; +import "google/protobuf/duration.proto"; + import "validate/validate.proto"; import "gogoproto/gogo.proto"; @@ -93,4 +95,20 @@ message ClusterStats { // The total number of dropped requests. This covers requests // deliberately dropped by the drop_overload policy and circuit breaking. uint64 total_dropped_requests = 3; + + message DroppedRequests { + // Identifier for the policy specifying the drop. + string category = 1 [(validate.rules).string.min_bytes = 1]; + // Total number of deliberately dropped requests for the category. + uint64 dropped_count = 2; + } + // Information about deliberately dropped requests for each category specified + // in the DropOverload policy. + repeated DroppedRequests dropped_requests = 5; + + // Period over which the actual load report occurred. This will be guaranteed to include every + // request reported. Due to system load and delays between the *LoadStatsRequest* sent from Envoy + // and the *LoadStatsResponse* message sent from the management server, this may be longer than + // the requested load reporting interval in the *LoadStatsResponse*. + google.protobuf.Duration load_report_interval = 4; } diff --git a/api/envoy/api/v2/listener/BUILD b/api/envoy/api/v2/listener/BUILD index e8f48a10b29f2..bfa6a1407107f 100644 --- a/api/envoy/api/v2/listener/BUILD +++ b/api/envoy/api/v2/listener/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "listener", srcs = ["listener.proto"], visibility = ["//envoy/api/v2:friends"], diff --git a/api/envoy/api/v2/listener/listener.proto b/api/envoy/api/v2/listener/listener.proto index f9436c24d1c32..1e8015dbb2446 100644 --- a/api/envoy/api/v2/listener/listener.proto +++ b/api/envoy/api/v2/listener/listener.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package envoy.api.v2.listener; option go_package = "listener"; +option csharp_namespace = "Envoy.Api.V2.ListenerNS"; import "envoy/api/v2/core/address.proto"; import "envoy/api/v2/auth/cert.proto"; @@ -54,10 +55,11 @@ message Filter { // // The following order applies: // -// [#comment:TODO(PiotrSikora): destination IP / ranges are going to be 1.] -// 1. Server name (e.g. SNI for TLS protocol), -// 2. Transport protocol. -// 3. Application protocols (e.g. ALPN for TLS protocol). +// 1. Destination port. +// 2. Destination IP address. +// 3. Server name (e.g. SNI for TLS protocol), +// 4. Transport protocol. +// 5. Application protocols (e.g. ALPN for TLS protocol). // // For criterias that allow ranges or wildcards, the most specific value in any // of the configured filter chains that matches the incoming connection is going @@ -70,9 +72,12 @@ message Filter { // // [#comment:TODO(PiotrSikora): Add support for configurable precedence of the rules] message FilterChainMatch { + // Optional destination port to consider when use_original_dst is set on the + // listener in determining a filter chain match. + google.protobuf.UInt32Value destination_port = 8 [(validate.rules).uint32 = {gte: 1, lte: 65535}]; + // If non-empty, an IP address and prefix length to match addresses when the // listener is bound to 0.0.0.0/:: or when use_original_dst is specified. - // [#not-implemented-hide:] repeated core.CidrRange prefix_ranges = 3; // If non-empty, an IP address and suffix length to match addresses when the @@ -96,11 +101,6 @@ message FilterChainMatch { // [#not-implemented-hide:] repeated google.protobuf.UInt32Value source_ports = 7; - // Optional destination port to consider when use_original_dst is set on the - // listener in determining a filter chain match. - // [#not-implemented-hide:] - google.protobuf.UInt32Value destination_port = 8; - // If non-empty, a list of server names (e.g. SNI for TLS protocol) to consider when determining // a filter chain match. Those values will be compared against the server names of a new // connection, when detected by one of the listener filters. @@ -148,20 +148,8 @@ message FilterChainMatch { // unless all connecting clients are known to use ALPN. repeated string application_protocols = 10; - // If non-empty, a list of server names (e.g. SNI for TLS protocol) to consider when determining - // a filter chain match. Those values will be compared against the server names of a new - // connection, when detected by one of the listener filters. - // - // The server name will be matched against all wildcard domains, i.e. ``www.example.com`` - // will be first matched against ``www.example.com``, then ``*.example.com``, then ``*.com``. - // - // Note that partial wildcards are not supported, and values like ``*w.example.com`` are invalid. - // - // .. attention:: - // - // Deprecated. Use :ref:`server_names ` - // instead. - repeated string sni_domains = 1 [deprecated = true]; + reserved 1; + reserved "sni_domains"; } // A filter chain wraps a set of match criteria, an option TLS context, a set of filters, and diff --git a/api/envoy/api/v2/ratelimit/BUILD b/api/envoy/api/v2/ratelimit/BUILD index 0c6497e63a1fa..6e640b04986c6 100644 --- a/api/envoy/api/v2/ratelimit/BUILD +++ b/api/envoy/api/v2/ratelimit/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "ratelimit", srcs = ["ratelimit.proto"], visibility = ["//envoy/api/v2:friends"], diff --git a/api/envoy/api/v2/rds.proto b/api/envoy/api/v2/rds.proto index e820852defc4a..00ac0145b301a 100644 --- a/api/envoy/api/v2/rds.proto +++ b/api/envoy/api/v2/rds.proto @@ -28,6 +28,10 @@ service RouteDiscoveryService { rpc StreamRoutes(stream DiscoveryRequest) returns (stream DiscoveryResponse) { } + rpc IncrementalRoutes(stream IncrementalDiscoveryRequest) + returns (stream IncrementalDiscoveryResponse) { + } + rpc FetchRoutes(DiscoveryRequest) returns (DiscoveryResponse) { option (google.api.http) = { post: "/v2/discovery:routes" diff --git a/api/envoy/api/v2/route/BUILD b/api/envoy/api/v2/route/BUILD index 09c6b2dd553e3..5bc60102532e4 100644 --- a/api/envoy/api/v2/route/BUILD +++ b/api/envoy/api/v2/route/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "route", srcs = ["route.proto"], visibility = ["//envoy/api/v2:friends"], diff --git a/api/envoy/api/v2/route/route.proto b/api/envoy/api/v2/route/route.proto index 9d972f9d9af36..a497a1871f9e9 100644 --- a/api/envoy/api/v2/route/route.proto +++ b/api/envoy/api/v2/route/route.proto @@ -75,7 +75,7 @@ message VirtualHost { // Specifies a list of HTTP headers that should be added to each request // handled by this virtual host. Headers specified at this level are applied - // after headers from enclosed :ref:`envoy_api_msg_route.RouteAction` and before headers from the + // after headers from enclosed :ref:`envoy_api_msg_route.Route` and before headers from the // enclosing :ref:`envoy_api_msg_RouteConfiguration`. For more information, including // details on header value syntax, see the documentation on :ref:`custom request headers // `. @@ -83,7 +83,7 @@ message VirtualHost { // Specifies a list of HTTP headers that should be added to each response // handled by this virtual host. Headers specified at this level are applied - // after headers from enclosed :ref:`envoy_api_msg_route.RouteAction` and before headers from the + // after headers from enclosed :ref:`envoy_api_msg_route.Route` and before headers from the // enclosing :ref:`envoy_api_msg_RouteConfiguration`. For more information, including // details on header value syntax, see the documentation on :ref:`custom request headers // `. @@ -148,6 +148,26 @@ message Route { // specific; see the :ref:`HTTP filter documentation ` for // if and how it is utilized. map per_filter_config = 8; + + // Specifies a set of headers that will be added to requests matching this + // route. Headers specified at this level are applied before headers from the + // enclosing :ref:`envoy_api_msg_route.VirtualHost` and + // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including details on + // header value syntax, see the documentation on :ref:`custom request headers + // `. + repeated core.HeaderValueOption request_headers_to_add = 9; + + // Specifies a set of headers that will be added to responses to requests + // matching this route. Headers specified at this level are applied before + // headers from the enclosing :ref:`envoy_api_msg_route.VirtualHost` and + // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including + // details on header value syntax, see the documentation on + // :ref:`custom request headers `. + repeated core.HeaderValueOption response_headers_to_add = 10; + + // Specifies a list of HTTP headers that should be removed from each response + // to requests matching this route. + repeated string response_headers_to_remove = 11; } // Compared to the :ref:`cluster ` field that specifies a @@ -176,8 +196,7 @@ message WeightedCluster { // Specifies a list of headers to be added to requests when this cluster is selected // through the enclosing :ref:`envoy_api_msg_route.RouteAction`. // Headers specified at this level are applied before headers from the enclosing - // :ref:`envoy_api_msg_route.RouteAction`, - // :ref:`envoy_api_msg_route.VirtualHost`, and + // :ref:`envoy_api_msg_route.Route`, :ref:`envoy_api_msg_route.VirtualHost`, and // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including details on // header value syntax, see the documentation on :ref:`custom request headers // `. @@ -186,8 +205,7 @@ message WeightedCluster { // Specifies a list of headers to be added to responses when this cluster is selected // through the enclosing :ref:`envoy_api_msg_route.RouteAction`. // Headers specified at this level are applied before headers from the enclosing - // :ref:`envoy_api_msg_route.RouteAction`, - // :ref:`envoy_api_msg_route.VirtualHost`, and + // :ref:`envoy_api_msg_route.Route`, :ref:`envoy_api_msg_route.VirtualHost`, and // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including details on // header value syntax, see the documentation on :ref:`custom request headers // `. @@ -282,10 +300,18 @@ message RouteMatch { repeated QueryParameterMatcher query_parameters = 7; } +// [#comment:next free field: 9] message CorsPolicy { // Specifies the origins that will be allowed to do CORS requests. + // + // An origin is allowed if either allow_origin or allow_origin_regex match. repeated string allow_origin = 1; + // Specifies regex patterns that match allowed origins. + // + // An origin is allowed if either allow_origin or allow_origin_regex match. + repeated string allow_origin_regex = 8; + // Specifies the content for the *access-control-allow-methods* header. string allow_methods = 2; @@ -305,7 +331,7 @@ message CorsPolicy { google.protobuf.BoolValue enabled = 7; } -// [#comment:next free field: 24] +// [#comment:next free field: 25] message RouteAction { oneof cluster_specifier { option (validate.required) = true; @@ -393,7 +419,9 @@ message RouteAction { google.protobuf.BoolValue auto_host_rewrite = 7; } - // Specifies the timeout for the route. If not specified, the default is 15s. + // Specifies the upstream timeout for the route. If not specified, the default is 15s. This + // spans between the point at which the entire downstream request (i.e. end-of-stream) has been + // processed and when the upstream response has been completely processed. // // .. note:: // @@ -415,8 +443,8 @@ message RouteAction { // :ref:`config_http_filters_router_x-envoy-max-retries`. google.protobuf.UInt32Value num_retries = 2; - // Specifies a non-zero timeout per retry attempt. This parameter is optional. - // The same conditions documented for + // Specifies a non-zero upstream timeout per retry attempt. This parameter is optional. The + // same conditions documented for // :ref:`config_http_filters_router_x-envoy-upstream-rq-per-try-timeout-ms` apply. // // .. note:: @@ -429,6 +457,27 @@ message RouteAction { google.protobuf.Duration per_try_timeout = 3 [(gogoproto.stdduration) = true]; } + // Specifies the idle timeout for the route. If not specified, there is no per-route idle timeout + // specified, although the connection manager wide :ref:`stream_idle_timeout + // ` + // will still apply. A value of 0 will completely disable the route's idle timeout, even if a + // connection manager stream idle timeout is configured. + // + // The idle timeout is distinct to :ref:`timeout + // `, which provides an upper bound + // on the upstream response time; :ref:`idle_timeout + // ` instead bounds the amount + // of time the request's stream may be idle. + // + // After header decoding, the idle timeout will apply on downstream and + // upstream request events. Each time an encode/decode event for headers or + // data is processed for the stream, the timer will be reset. If the timeout + // fires, the stream is terminated with a 408 Request Timeout error code if no + // upstream response header has been received, otherwise a stream reset + // occurs. + google.protobuf.Duration idle_timeout = 24 + [(validate.rules).duration.gt = {}, (gogoproto.stdduration) = true]; + // Indicates that the route has a retry policy. RetryPolicy retry_policy = 9; @@ -461,25 +510,14 @@ message RouteAction { // https://github.com/lyft/protoc-gen-validate/issues/42 is resolved.] core.RoutingPriority priority = 11; - // Specifies a set of headers that will be added to requests matching this - // route. Headers specified at this level are applied before headers from the - // enclosing :ref:`envoy_api_msg_route.VirtualHost` and - // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including details on - // header value syntax, see the documentation on :ref:`custom request headers - // `. - repeated core.HeaderValueOption request_headers_to_add = 12; + // [#not-implemented-hide:] + repeated core.HeaderValueOption request_headers_to_add = 12 [deprecated = true]; - // Specifies a set of headers that will be added to responses to requests - // matching this route. Headers specified at this level are applied before - // headers from the enclosing :ref:`envoy_api_msg_route.VirtualHost` and - // :ref:`envoy_api_msg_RouteConfiguration`. For more information, including - // details on header value syntax, see the documentation on - // :ref:`custom request headers `. - repeated core.HeaderValueOption response_headers_to_add = 18; + // [#not-implemented-hide:] + repeated core.HeaderValueOption response_headers_to_add = 18 [deprecated = true]; - // Specifies a list of HTTP headers that should be removed from each response - // to requests matching this route. - repeated string response_headers_to_remove = 19; + // [#not-implemented-hide:] + repeated string response_headers_to_remove = 19 [deprecated = true]; // Specifies a set of rate limit configurations that could be applied to the // route. @@ -573,7 +611,7 @@ message RouteAction { // proxy data from the client to the upstream server. // // Redirects are not supported on routes where WebSocket upgrades are allowed. - google.protobuf.BoolValue use_websocket = 16; + google.protobuf.BoolValue use_websocket = 16 [deprecated = true]; message WebSocketProxyConfig { // See :ref:`stat_prefix @@ -603,7 +641,7 @@ message RouteAction { // Proxy configuration used for WebSocket connections. If unset, the default values as specified // in :ref:`TcpProxy ` are used. - WebSocketProxyConfig websocket_config = 22; + WebSocketProxyConfig websocket_config = 22 [deprecated = true]; // Indicates that the route has a CORS policy. CorsPolicy cors = 17; @@ -680,8 +718,9 @@ message DirectResponseAction { // // .. note:: // - // Headers can be specified using *response_headers_to_add* in - // :ref:`envoy_api_msg_RouteConfiguration`. + // Headers can be specified using *response_headers_to_add* in the enclosing + // :ref:`envoy_api_msg_route.Route`, :ref:`envoy_api_msg_RouteConfiguration` or + // :ref:`envoy_api_msg_route.VirtualHost`. core.DataSource body = 2; } @@ -888,37 +927,24 @@ message RateLimit { // "name": ":method", // "value": "POST" // } +// +// .. attention:: +// In the absence of any header match specifier, match will default to :ref:`present_match +// `. i.e, a request that has the :ref:`name +// ` header will match, regardless of the header's +// value. +// message HeaderMatcher { // Specifies the name of the header in the request. string name = 1 [(validate.rules).string.min_bytes = 1]; - // Specifies the value of the header. If the value is absent a request that - // has the name header will match, regardless of the header’s value. - // - // .. attention:: - // Deprecated. Use :ref:`exact_match ` instead. - string value = 2 [deprecated = true]; - - // Specifies whether the header value is a regular - // expression or not. Defaults to false. The entire request header value must match the regex. The - // rule will not match if only a subsequence of the request header value matches the regex. The - // regex grammar used in the value field is defined - // `here `_. - // - // Examples: - // - // * The regex *\d{3}* matches the value *123* - // * The regex *\d{3}* does not match the value *1234* - // * The regex *\d{3}* does not match the value *123.456* - // - // .. attention:: - // Deprecated. Use :ref:`regex_match ` instead. - google.protobuf.BoolValue regex = 3 [deprecated = true]; + reserved 2; // value deprecated by :ref:`exact_match + // ` + + reserved 3; // regex deprecated by :ref:`regex_match + // ` // Specifies how the header match will be performed to route the request. - // If header_match_specifier is absent, a request that has the - // :ref:`envoy_api_msg_route.HeaderMatcher.name` header will match, regardless of the header's - // value. oneof header_match_specifier { // If specified, header match will be performed based on the value of the header. string exact_match = 4; diff --git a/api/envoy/config/accesslog/v2/BUILD b/api/envoy/config/accesslog/v2/BUILD index 63bdc5c5283c6..63ef7b0ae8057 100644 --- a/api/envoy/config/accesslog/v2/BUILD +++ b/api/envoy/config/accesslog/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "als", srcs = ["als.proto"], deps = [ @@ -10,7 +10,7 @@ api_proto_library( ], ) -api_proto_library( +api_proto_library_internal( name = "file", srcs = ["file.proto"], ) diff --git a/api/envoy/config/bootstrap/v2/BUILD b/api/envoy/config/bootstrap/v2/BUILD index 9b97ffab07d1b..4024b11a13c56 100644 --- a/api/envoy/config/bootstrap/v2/BUILD +++ b/api/envoy/config/bootstrap/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "bootstrap", srcs = ["bootstrap.proto"], visibility = ["//visibility:public"], diff --git a/api/envoy/config/filter/accesslog/v2/BUILD b/api/envoy/config/filter/accesslog/v2/BUILD index fbab9f76ba4a2..3eedcf397000e 100644 --- a/api/envoy/config/filter/accesslog/v2/BUILD +++ b/api/envoy/config/filter/accesslog/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "accesslog", srcs = ["accesslog.proto"], visibility = [ diff --git a/api/envoy/config/filter/accesslog/v2/accesslog.proto b/api/envoy/config/filter/accesslog/v2/accesslog.proto index af78123e36091..6642560694eaf 100644 --- a/api/envoy/config/filter/accesslog/v2/accesslog.proto +++ b/api/envoy/config/filter/accesslog/v2/accesslog.proto @@ -61,6 +61,9 @@ message AccessLogFilter { // Header filter. HeaderFilter header_filter = 8; + + // Response flag filter. + ResponseFlagFilter response_flag_filter = 9; } } @@ -150,3 +153,15 @@ message HeaderFilter { // check. envoy.api.v2.route.HeaderMatcher header = 1 [(validate.rules).message.required = true]; } + +// Filters requests that received responses with an Envoy response flag set. +// A list of the response flags can be found +// in the access log formatter :ref:`documentation`. +message ResponseFlagFilter { + // Only responses with the any of the flags listed in this field will be logged. + // This field is optional. If it is not specified, then any response flag will pass + // the filter check. + repeated string flags = 1 [(validate.rules).repeated .items.string = { + in: ["LH", "UH", "UT", "LR", "UR", "UF", "UC", "UO", "NR", "DI", "FI", "RL", "UAEX"] + }]; +} diff --git a/api/envoy/config/filter/fault/v2/BUILD b/api/envoy/config/filter/fault/v2/BUILD index 0b4310f48e36e..9fba2fbed3e17 100644 --- a/api/envoy/config/filter/fault/v2/BUILD +++ b/api/envoy/config/filter/fault/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "fault", srcs = ["fault.proto"], visibility = [ diff --git a/api/envoy/config/filter/http/buffer/v2/BUILD b/api/envoy/config/filter/http/buffer/v2/BUILD index d2be36c572c4d..0460c2d43e3ef 100644 --- a/api/envoy/config/filter/http/buffer/v2/BUILD +++ b/api/envoy/config/filter/http/buffer/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "buffer", srcs = ["buffer.proto"], ) diff --git a/api/envoy/config/filter/http/ext_authz/v2alpha/BUILD b/api/envoy/config/filter/http/ext_authz/v2alpha/BUILD index 62e7fc3d64641..8ab214517f914 100644 --- a/api/envoy/config/filter/http/ext_authz/v2alpha/BUILD +++ b/api/envoy/config/filter/http/ext_authz/v2alpha/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "ext_authz", srcs = ["ext_authz.proto"], deps = [ diff --git a/api/envoy/config/filter/http/ext_authz/v2alpha/ext_authz.proto b/api/envoy/config/filter/http/ext_authz/v2alpha/ext_authz.proto index 9d602298ce170..c85a799cddcae 100644 --- a/api/envoy/config/filter/http/ext_authz/v2alpha/ext_authz.proto +++ b/api/envoy/config/filter/http/ext_authz/v2alpha/ext_authz.proto @@ -6,24 +6,18 @@ option go_package = "v2alpha"; import "envoy/api/v2/core/grpc_service.proto"; import "envoy/api/v2/core/http_uri.proto"; -// [#protodoc-title: HTTP External Authorization ] -// The external authorization HTTP service configuration +// [#protodoc-title: External Authorization ] +// The external authorization service configuration // :ref:`configuration overview `. -// [#not-implemented-hide:] -// [#comment: The HttpService is under development and will be supported soon.] -message HttpService { - // Sets the HTTP server URI which the authorization requests must be sent to. - envoy.api.v2.core.HttpUri server_uri = 1; - - // Sets an optional prefix to the value of authorization request header `path`. - string path_prefix = 2; -} - -// External Authorization filter calls out to an external service over the -// gRPC Authorization API defined by -// :ref:`CheckRequest `. -// A failed check will cause this filter to close the HTTP request with 403(Forbidden). +// External Authorization filter calls out to an external service over either: +// +// 1. gRPC Authorization API defined by :ref:`CheckRequest +// `. +// 2. Raw HTTP Authorization server by passing the request headers to the service. +// +// A failed check will cause this filter to close the HTTP request normally with 403 (Forbidden), +// unless a different status code has been indicated in the authorization response. message ExtAuthz { oneof services { @@ -32,7 +26,7 @@ message ExtAuthz { envoy.api.v2.core.GrpcService grpc_service = 1; // The external authorization HTTP service configuration. - // [#not-implemented-hide:] + // The default timeout is set to 200ms by this filter. HttpService http_service = 3; } @@ -42,3 +36,30 @@ message ExtAuthz { // Defaults to false. bool failure_mode_allow = 2; } + +// External Authorization filter calls out to an upstream authorization server by passing the raw +// HTTP request headers to the server. This allows the authorization service to take a decision +// whether the request is authorized or not. +// +// A successful check allows the authorization service adding or overriding headers from the +// original request before dispatching it to the upstream. This is done by including the headers in +// the response sent back from the authorization service to the filter. Note that `Status`, +// `Method`, `Path` and `Content Length` response headers are automatically removed from this +// response by the filter. If other headers need be deleted, they should be specified in +// `response_headers_to_remove` field. +// +// A failed check will cause this filter to close the HTTP request normally with 403 (Forbidden), +// unless a different status code has been indicated by the authorization service via response +// headers. The HTTP service also allows the authorization filter to also pass data from the +// response body to the downstream client in case of a denied request. +message HttpService { + // Sets the HTTP server URI which the authorization requests must be sent to. + envoy.api.v2.core.HttpUri server_uri = 1; + + // Sets an optional prefix to the value of authorization request header `path`. + string path_prefix = 2; + + // Sets a list of headers that should be not be sent *from the authorization server* to the + // upstream. + repeated string response_headers_to_remove = 3; +} diff --git a/api/envoy/config/filter/http/fault/v2/BUILD b/api/envoy/config/filter/http/fault/v2/BUILD index 0c517c3e666db..7b414c48af121 100644 --- a/api/envoy/config/filter/http/fault/v2/BUILD +++ b/api/envoy/config/filter/http/fault/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "fault", srcs = ["fault.proto"], deps = [ diff --git a/api/envoy/config/filter/http/gzip/v2/BUILD b/api/envoy/config/filter/http/gzip/v2/BUILD index e1b592f4aee79..79c1076d7c77e 100644 --- a/api/envoy/config/filter/http/gzip/v2/BUILD +++ b/api/envoy/config/filter/http/gzip/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "gzip", srcs = ["gzip.proto"], ) diff --git a/api/envoy/config/filter/http/header_to_metadata/v2/BUILD b/api/envoy/config/filter/http/header_to_metadata/v2/BUILD index 102dd076346dc..67b45090a654f 100644 --- a/api/envoy/config/filter/http/header_to_metadata/v2/BUILD +++ b/api/envoy/config/filter/http/header_to_metadata/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "header_to_metadata", srcs = ["header_to_metadata.proto"], deps = [], diff --git a/api/envoy/config/filter/http/health_check/v2/BUILD b/api/envoy/config/filter/http/health_check/v2/BUILD index 1616f046e1fa6..37152bde6f020 100644 --- a/api/envoy/config/filter/http/health_check/v2/BUILD +++ b/api/envoy/config/filter/http/health_check/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "health_check", srcs = ["health_check.proto"], deps = [ diff --git a/api/envoy/config/filter/http/health_check/v2/health_check.proto b/api/envoy/config/filter/http/health_check/v2/health_check.proto index 88106e93136cd..0f584b451f68a 100644 --- a/api/envoy/config/filter/http/health_check/v2/health_check.proto +++ b/api/envoy/config/filter/http/health_check/v2/health_check.proto @@ -19,11 +19,8 @@ message HealthCheck { // Specifies whether the filter operates in pass through mode or not. google.protobuf.BoolValue pass_through_mode = 1 [(validate.rules).message.required = true]; - // Specifies the incoming HTTP endpoint that should be considered the - // health check endpoint. For example */healthcheck*. - // Note that this field is deprecated in favor of - // :ref:`headers `. - string endpoint = 2 [deprecated = true]; + reserved 2; + reserved "endpoint"; // If operating in pass through mode, the amount of time in milliseconds // that the filter should cache the upstream response. @@ -36,8 +33,6 @@ message HealthCheck { // Specifies a set of health check request headers to match on. The health check filter will // check a request’s headers against all the specified headers. To specify the health check - // endpoint, set the ``:path`` header to match on. Note that if the - // :ref:`endpoint ` - // field is set, it will overwrite any ``:path`` header to match. + // endpoint, set the ``:path`` header to match on. repeated envoy.api.v2.route.HeaderMatcher headers = 5; } diff --git a/api/envoy/config/filter/http/ip_tagging/v2/BUILD b/api/envoy/config/filter/http/ip_tagging/v2/BUILD index 147693b86c088..8a6c0ee5be259 100644 --- a/api/envoy/config/filter/http/ip_tagging/v2/BUILD +++ b/api/envoy/config/filter/http/ip_tagging/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "ip_tagging", srcs = ["ip_tagging.proto"], deps = ["//envoy/api/v2/core:address"], diff --git a/api/envoy/config/filter/http/jwt_authn/v2alpha/BUILD b/api/envoy/config/filter/http/jwt_authn/v2alpha/BUILD index cc07bd29bddaa..90863e3f5bed2 100644 --- a/api/envoy/config/filter/http/jwt_authn/v2alpha/BUILD +++ b/api/envoy/config/filter/http/jwt_authn/v2alpha/BUILD @@ -1,8 +1,8 @@ licenses(["notice"]) # Apache 2 -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") -api_proto_library( +api_proto_library_internal( name = "jwt_authn", srcs = ["config.proto"], deps = [ diff --git a/api/envoy/config/filter/http/jwt_authn/v2alpha/config.proto b/api/envoy/config/filter/http/jwt_authn/v2alpha/config.proto index d58e960c37f68..1350070f6806b 100644 --- a/api/envoy/config/filter/http/jwt_authn/v2alpha/config.proto +++ b/api/envoy/config/filter/http/jwt_authn/v2alpha/config.proto @@ -22,13 +22,13 @@ import "validate/validate.proto"; // issuer: https://example.com // audiences: // - bookstore_android.apps.googleusercontent.com -// bookstore_web.apps.googleusercontent.com +// - bookstore_web.apps.googleusercontent.com // remote_jwks: -// - http_uri: -// - uri: https://example.com/.well-known/jwks.json +// http_uri: +// uri: https://example.com/.well-known/jwks.json // cluster: example_jwks_cluster // cache_duration: -// - seconds: 300 +// seconds: 300 // // [#not-implemented-hide:] message JwtProvider { @@ -50,7 +50,7 @@ message JwtProvider { // // audiences: // - bookstore_android.apps.googleusercontent.com - // bookstore_web.apps.googleusercontent.com + // - bookstore_web.apps.googleusercontent.com // repeated string audiences = 2; @@ -67,11 +67,11 @@ message JwtProvider { // .. code-block:: yaml // // remote_jwks: - // - http_uri: - // - uri: https://www.googleapis.com/oauth2/v1/certs + // http_uri: + // uri: https://www.googleapis.com/oauth2/v1/certs // cluster: jwt.www.googleapis.com|443 // cache_duration: - // - seconds: 300 + // seconds: 300 // RemoteJwks remote_jwks = 3; @@ -83,14 +83,14 @@ message JwtProvider { // .. code-block:: yaml // // local_jwks: - // - filename: /etc/envoy/jwks/jwks1.txt + // filename: /etc/envoy/jwks/jwks1.txt // // Example: inline_string // // .. code-block:: yaml // // local_jwks: - // - inline_string: "ACADADADADA" + // inline_string: "ACADADADADA" // envoy.api.v2.core.DataSource local_jwks = 4; } @@ -163,7 +163,7 @@ message RemoteJwks { // .. code-block:: yaml // // http_uri: - // - uri: https://www.googleapis.com/oauth2/v1/certs + // uri: https://www.googleapis.com/oauth2/v1/certs // cluster: jwt.www.googleapis.com|443 // envoy.api.v2.core.HttpUri http_uri = 1; diff --git a/api/envoy/config/filter/http/lua/v2/BUILD b/api/envoy/config/filter/http/lua/v2/BUILD index ce571d9720db6..d399bc5b066be 100644 --- a/api/envoy/config/filter/http/lua/v2/BUILD +++ b/api/envoy/config/filter/http/lua/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "lua", srcs = ["lua.proto"], ) diff --git a/api/envoy/config/filter/http/rate_limit/v2/BUILD b/api/envoy/config/filter/http/rate_limit/v2/BUILD index 484e19c40d322..3b90a57c80ae3 100644 --- a/api/envoy/config/filter/http/rate_limit/v2/BUILD +++ b/api/envoy/config/filter/http/rate_limit/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "rate_limit", srcs = ["rate_limit.proto"], ) diff --git a/api/envoy/config/filter/http/rbac/v2/BUILD b/api/envoy/config/filter/http/rbac/v2/BUILD index e96a01d560f74..d325e3bcde2d7 100644 --- a/api/envoy/config/filter/http/rbac/v2/BUILD +++ b/api/envoy/config/filter/http/rbac/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "rbac", srcs = ["rbac.proto"], deps = ["//envoy/config/rbac/v2alpha:rbac"], diff --git a/api/envoy/config/filter/http/router/v2/BUILD b/api/envoy/config/filter/http/router/v2/BUILD index 00392ac7f98a5..38697ac806806 100644 --- a/api/envoy/config/filter/http/router/v2/BUILD +++ b/api/envoy/config/filter/http/router/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "router", srcs = ["router.proto"], deps = ["//envoy/config/filter/accesslog/v2:accesslog"], diff --git a/api/envoy/config/filter/http/squash/v2/BUILD b/api/envoy/config/filter/http/squash/v2/BUILD index ea5e9c6c4c158..8cf2c80dde1e7 100644 --- a/api/envoy/config/filter/http/squash/v2/BUILD +++ b/api/envoy/config/filter/http/squash/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "squash", srcs = ["squash.proto"], ) diff --git a/api/envoy/config/filter/http/transcoder/v2/BUILD b/api/envoy/config/filter/http/transcoder/v2/BUILD index 087f8ce8cefb3..eddef5a7ebd03 100644 --- a/api/envoy/config/filter/http/transcoder/v2/BUILD +++ b/api/envoy/config/filter/http/transcoder/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "transcoder", srcs = ["transcoder.proto"], ) diff --git a/api/envoy/config/filter/network/client_ssl_auth/v2/BUILD b/api/envoy/config/filter/network/client_ssl_auth/v2/BUILD index d382848c92393..a6d31d6396111 100644 --- a/api/envoy/config/filter/network/client_ssl_auth/v2/BUILD +++ b/api/envoy/config/filter/network/client_ssl_auth/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "client_ssl_auth", srcs = ["client_ssl_auth.proto"], deps = ["//envoy/api/v2/core:address"], diff --git a/api/envoy/config/filter/network/ext_authz/v2/BUILD b/api/envoy/config/filter/network/ext_authz/v2/BUILD index 22dc891526f98..4d716dee9744a 100644 --- a/api/envoy/config/filter/network/ext_authz/v2/BUILD +++ b/api/envoy/config/filter/network/ext_authz/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "ext_authz", srcs = ["ext_authz.proto"], deps = ["//envoy/api/v2/core:grpc_service"], diff --git a/api/envoy/config/filter/network/http_connection_manager/v2/BUILD b/api/envoy/config/filter/network/http_connection_manager/v2/BUILD index da2f4ddabc10e..c89ea09ad2909 100644 --- a/api/envoy/config/filter/network/http_connection_manager/v2/BUILD +++ b/api/envoy/config/filter/network/http_connection_manager/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "http_connection_manager", srcs = ["http_connection_manager.proto"], deps = [ diff --git a/api/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto b/api/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto index 5087298e8c6fa..1e6999cf2077a 100644 --- a/api/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto +++ b/api/envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.proto @@ -19,7 +19,7 @@ import "gogoproto/gogo.proto"; // [#protodoc-title: HTTP connection manager] // HTTP connection manager :ref:`configuration overview `. -// [#comment:next free field: 24] +// [#comment:next free field: 25] message HttpConnectionManager { enum CodecType { option (gogoproto.goproto_enum_prefix) = false; @@ -137,6 +137,33 @@ message HttpConnectionManager { // `. google.protobuf.Duration idle_timeout = 11 [(gogoproto.stdduration) = true]; + // The stream idle timeout for connections managed by the connection manager. + // If not specified, this defaults to 5 minutes. The default value was selected + // so as not to interfere with any smaller configured timeouts that may have + // existed in configurations prior to the introduction of this feature, while + // introducing robustness to TCP connections that terminate without a FIN. + // + // This idle timeout applies to new streams and is overridable by the + // :ref:`route-level idle_timeout + // `. Even on a stream in + // which the override applies, prior to receipt of the initial request + // headers, the :ref:`stream_idle_timeout + // ` + // applies. Each time an encode/decode event for headers or data is processed + // for the stream, the timer will be reset. If the timeout fires, the stream + // is terminated with a 408 Request Timeout error code if no upstream response + // header has been received, otherwise a stream reset occurs. + // + // Note that it is possible to idle timeout even if the wire traffic for a stream is non-idle, due + // to the granularity of events presented to the connection manager. For example, while receiving + // very large request headers, it may be the case that there is traffic regularly arriving on the + // wire while the connection manage is only able to observe the end-of-headers event, hence the + // stream may still idle timeout. + // + // A value of 0 will completely disable the connection manager stream idle + // timeout, although per-route idle timeout overrides will continue to apply. + google.protobuf.Duration stream_idle_timeout = 24 [(gogoproto.stdduration) = true]; + // The time that Envoy will wait between sending an HTTP/2 “shutdown // notification” (GOAWAY frame with max stream ID) and a final GOAWAY frame. // This is used so that Envoy provides a grace period for new streams that @@ -221,9 +248,7 @@ message HttpConnectionManager { // Whether to forward the subject of the client cert. Defaults to false. google.protobuf.BoolValue subject = 1; - // Whether to forward the URI type Subject Alternative Name of the client cert. Defaults to - // false. This field is deprecated, use URI field instead. - google.protobuf.BoolValue san = 2 [deprecated = true]; + reserved 2; // san deprecated by uri // Whether to forward the entire client cert in URL encoded PEM format. This will appear in the // XFCC header comma separated from other values with the value Cert="PEM". @@ -268,7 +293,6 @@ message HttpConnectionManager { // control. bool represent_ipv4_remote_address_as_ipv4_mapped_ipv6 = 20; - // [#not-implemented-hide:] // The configuration for HTTP upgrades. // For each upgrade type desired, an UpgradeConfig must be added. // @@ -277,6 +301,10 @@ message HttpConnectionManager { // The current implementation of upgrade headers does not handle // multi-valued upgrade headers. Support for multi-valued headers may be // added in the future if needed. + // + // .. warning:: + // The current implementation of upgrade headers does not work with HTTP/2 + // upstreams. message UpgradeConfig { // The case-insensitive name of this upgrade, e.g. "websocket". // For each upgrade type present in upgrade_configs, requests with @@ -288,7 +316,6 @@ message HttpConnectionManager { // HTTP connections will be used for this upgrade type. repeated HttpFilter filters = 2; }; - // [#not-implemented-hide:] repeated UpgradeConfig upgrade_configs = 23; } diff --git a/api/envoy/config/filter/network/mongo_proxy/v2/BUILD b/api/envoy/config/filter/network/mongo_proxy/v2/BUILD index 03bc303476752..69b0f85e156d0 100644 --- a/api/envoy/config/filter/network/mongo_proxy/v2/BUILD +++ b/api/envoy/config/filter/network/mongo_proxy/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "mongo_proxy", srcs = ["mongo_proxy.proto"], deps = ["//envoy/config/filter/fault/v2:fault"], diff --git a/api/envoy/config/filter/network/rate_limit/v2/BUILD b/api/envoy/config/filter/network/rate_limit/v2/BUILD index b1936e3bb2c2f..2cda26cfde99e 100644 --- a/api/envoy/config/filter/network/rate_limit/v2/BUILD +++ b/api/envoy/config/filter/network/rate_limit/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "rate_limit", srcs = ["rate_limit.proto"], deps = ["//envoy/api/v2/ratelimit"], diff --git a/api/envoy/config/filter/network/redis_proxy/v2/BUILD b/api/envoy/config/filter/network/redis_proxy/v2/BUILD index 78f269301fe05..c35e219b44659 100644 --- a/api/envoy/config/filter/network/redis_proxy/v2/BUILD +++ b/api/envoy/config/filter/network/redis_proxy/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "redis_proxy", srcs = ["redis_proxy.proto"], ) diff --git a/api/envoy/config/filter/network/tcp_proxy/v2/BUILD b/api/envoy/config/filter/network/tcp_proxy/v2/BUILD index 2e7296fa3f969..7cb467d6fb10d 100644 --- a/api/envoy/config/filter/network/tcp_proxy/v2/BUILD +++ b/api/envoy/config/filter/network/tcp_proxy/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "tcp_proxy", srcs = ["tcp_proxy.proto"], deps = [ diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/BUILD b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/BUILD new file mode 100644 index 0000000000000..da39334babd17 --- /dev/null +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/BUILD @@ -0,0 +1,11 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "thrift_proxy", + srcs = [ + "route.proto", + "thrift_proxy.proto", + ], +) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/README.md b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/README.md similarity index 100% rename from api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/README.md rename to api/envoy/config/filter/network/thrift_proxy/v2alpha1/README.md diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/route.proto b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/route.proto new file mode 100644 index 0000000000000..f70523e57f212 --- /dev/null +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/route.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +package envoy.config.filter.network.thrift_proxy.v2alpha1; +option go_package = "v2"; + +import "validate/validate.proto"; +import "gogoproto/gogo.proto"; + +// [#protodoc-title: Thrift route configuration] + +// [#comment:next free field: 3] +message RouteConfiguration { + // The name of the route configuration. Reserved for future use in asynchronous route discovery. + string name = 1; + + // The list of routes that will be matched, in order, against incoming requests. The first route + // that matches will be used. + repeated Route routes = 2 [(gogoproto.nullable) = false]; +} + +// [#comment:next free field: 3] +message Route { + // Route matching prarameters. + RouteMatch match = 1 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; + + // Route request to some upstream cluster. + RouteAction route = 2 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; +} + +// [#comment:next free field: 2] +message RouteMatch { + // If specified, the route must exactly match the request method name. As a special case, an + // empty string matches any request method name. + string method = 1; +} + +// [#comment:next free field: 2] +message RouteAction { + // Indicates the upstream cluster to which the request should be routed. + string cluster = 1 [(validate.rules).string.min_bytes = 1]; +} diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/BUILD b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/BUILD new file mode 100644 index 0000000000000..ce0ad0e254f03 --- /dev/null +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/BUILD @@ -0,0 +1,8 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "router", + srcs = ["router.proto"], +) diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/router.proto b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/router.proto new file mode 100644 index 0000000000000..5ad9863b07dec --- /dev/null +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/router/router.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package envoy.config.filter.network.thrift_proxy.v2alpha1.router; +option go_package = "router"; + +// [#protodoc-title: Thrift Router] +// Thrift Router configuration. +message Router { +} diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto new file mode 100644 index 0000000000000..1a7176dc33031 --- /dev/null +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +package envoy.config.filter.network.thrift_proxy.v2alpha1; +option go_package = "v2"; + +import "envoy/config/filter/network/thrift_proxy/v2alpha1/route.proto"; + +import "validate/validate.proto"; +import "gogoproto/gogo.proto"; + +// [#protodoc-title: Extensions Thrift Proxy] +// Thrift Proxy filter configuration. +// [#comment:next free field: 5] +message ThriftProxy { + enum TransportType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which transport to use. + AUTO_TRANSPORT = 0; + + // The Thrift proxy will assume the client is using the Thrift framed transport. + FRAMED = 1; + + // The Thrift proxy will assume the client is using the Thrift unframed transport. + UNFRAMED = 2; + } + + // Supplies the type of transport that the Thrift proxy should use. Defaults to `AUTO_TRANSPORT`. + TransportType transport = 2 [(validate.rules).enum.defined_only = true]; + + enum ProtocolType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which protocol to use. + // N.B. The older, non-strict binary protocol is not included in automatic protocol + // detection. + AUTO_PROTOCOL = 0; + + // The Thrift proxy will assume the client is using the Thrift binary protocol. + BINARY = 1; + + // The Thrift proxy will assume the client is using the Thrift non-strict binary protocol. + LAX_BINARY = 2; + + // The Thrift proxy will assume the client is using the Thrift compact protocol. + COMPACT = 3; + } + + // Supplies the type of protocol that the Thrift proxy should use. Defaults to `AUTO_PROTOCOL`. + ProtocolType protocol = 3 [(validate.rules).enum.defined_only = true]; + + // The human readable prefix to use when emitting statistics. + string stat_prefix = 1 [(validate.rules).string.min_bytes = 1]; + + // The route table for the connection manager is static and is specified in this property. + RouteConfiguration route_config = 4; +} diff --git a/api/envoy/config/grpc_credential/v2alpha/BUILD b/api/envoy/config/grpc_credential/v2alpha/BUILD index 09f3e691f63a6..ca0a71eaef6cc 100644 --- a/api/envoy/config/grpc_credential/v2alpha/BUILD +++ b/api/envoy/config/grpc_credential/v2alpha/BUILD @@ -1,8 +1,8 @@ licenses(["notice"]) # Apache 2 -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") -api_proto_library( +api_proto_library_internal( name = "file_based_metadata", srcs = ["file_based_metadata.proto"], deps = ["//envoy/api/v2/core:base"], diff --git a/api/envoy/config/health_checker/redis/v2/BUILD b/api/envoy/config/health_checker/redis/v2/BUILD index 7d217c54dda8c..b784e8d150621 100644 --- a/api/envoy/config/health_checker/redis/v2/BUILD +++ b/api/envoy/config/health_checker/redis/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "redis", srcs = ["redis.proto"], ) diff --git a/api/envoy/config/metrics/v2/BUILD b/api/envoy/config/metrics/v2/BUILD index 1c682a133a420..9d061aeb918e6 100644 --- a/api/envoy/config/metrics/v2/BUILD +++ b/api/envoy/config/metrics/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "metrics_service", srcs = ["metrics_service.proto"], visibility = [ @@ -21,7 +21,7 @@ api_go_proto_library( ], ) -api_proto_library( +api_proto_library_internal( name = "stats", srcs = ["stats.proto"], visibility = [ diff --git a/api/envoy/config/metrics/v2/stats.proto b/api/envoy/config/metrics/v2/stats.proto index 8d128611cedde..121f59a9f85a2 100644 --- a/api/envoy/config/metrics/v2/stats.proto +++ b/api/envoy/config/metrics/v2/stats.proto @@ -21,6 +21,7 @@ message StatsSink { // * :ref:`envoy.statsd ` // * :ref:`envoy.dog_statsd ` // * :ref:`envoy.metrics_service ` + // * :ref:`envoy.stat_sinks.hystrix ` // // Sinks optionally support tagged/multiple dimensional metrics. string name = 1; @@ -200,3 +201,27 @@ message DogStatsdSink { reserved 2; } + +// Stats configuration proto schema for built-in *envoy.stat_sinks.hystrix* sink. +// The sink emits stats in `text/event-stream +// `_ +// formatted stream for use by `Hystrix dashboard +// `_. +// +// Note that only a single HystrixSink should be configured. +// +// Streaming is started through an admin endpoint :http:get:`/hystrix_event_stream`. +message HystrixSink { + // The number of buckets the rolling statistical window is divided into. + // + // Each time the sink is flushed, all relevant Envoy statistics are sampled and + // added to the rolling window (removing the oldest samples in the window + // in the process). The sink then outputs the aggregate statistics across the + // current rolling window to the event stream(s). + // + // rolling_window(ms) = stats_flush_interval(ms) * num_of_buckets + // + // More detailed explanation can be found in `Hystix wiki + // `_. + int64 num_buckets = 1; +} diff --git a/api/envoy/config/overload/v2alpha/BUILD b/api/envoy/config/overload/v2alpha/BUILD new file mode 100644 index 0000000000000..ef06407fb9ea0 --- /dev/null +++ b/api/envoy/config/overload/v2alpha/BUILD @@ -0,0 +1,8 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "overload", + srcs = ["overload.proto"], +) diff --git a/api/envoy/config/overload/v2alpha/overload.proto b/api/envoy/config/overload/v2alpha/overload.proto new file mode 100644 index 0000000000000..6b70d11d3243f --- /dev/null +++ b/api/envoy/config/overload/v2alpha/overload.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package envoy.config.overload.v2alpha; +option go_package = "v2alpha"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/struct.proto"; + +import "validate/validate.proto"; + +// The Overload Manager provides an extensible framework to protect Envoy instances +// from overload of various resources (memory, cpu, file descriptors, etc) + +message EmptyConfig { +} + +message ResourceMonitor { + // The name of the resource monitor to instantiate. Must match a registered + // resource monitor type. + string name = 1 [(validate.rules).string.min_bytes = 1]; + + // Configuration for the resource monitor being instantiated. + google.protobuf.Struct config = 2; +} + +message ThresholdTrigger { + // If the resource pressure is greater than or equal to this value, the trigger + // will fire. + double value = 1 [(validate.rules).double = {gte: 0, lte: 1}]; +} + +message Trigger { + // The name of the resource this is a trigger for. + string name = 1 [(validate.rules).string.min_bytes = 1]; + + oneof trigger_oneof { + option (validate.required) = true; + ThresholdTrigger threshold = 2; + } +} + +message OverloadAction { + // The name of the overload action. This is just a well-known string that listeners can + // use for registering callbacks. Custom overload actions should be named using reverse + // DNS to ensure uniqueness. + string name = 1 [(validate.rules).string.min_bytes = 1]; + + // A set of triggers for this action. If any of these triggers fires the overload action + // is activated. Listeners are notified when the overload action transitions from + // inactivated to activated, or vice versa. + repeated Trigger triggers = 2 [(validate.rules).repeated .min_items = 1]; +} + +message OverloadManager { + // The interval for refreshing resource usage. + google.protobuf.Duration refresh_interval = 1; + + // The set of resources to monitor. + repeated ResourceMonitor resource_monitors = 2 [(validate.rules).repeated .min_items = 1]; + + // The set of overload actions. + repeated OverloadAction actions = 3 [(validate.rules).repeated .min_items = 1]; +} diff --git a/api/envoy/config/ratelimit/v2/BUILD b/api/envoy/config/ratelimit/v2/BUILD index 08bad146ed4a7..2e69326aa3b1d 100644 --- a/api/envoy/config/ratelimit/v2/BUILD +++ b/api/envoy/config/ratelimit/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "rls", srcs = ["rls.proto"], visibility = [ diff --git a/api/envoy/config/ratelimit/v2/rls.proto b/api/envoy/config/ratelimit/v2/rls.proto index b9ffe80f79dc4..3a0f5dbedb35a 100644 --- a/api/envoy/config/ratelimit/v2/rls.proto +++ b/api/envoy/config/ratelimit/v2/rls.proto @@ -29,9 +29,9 @@ message RateLimitServiceConfig { } // Specifies if Envoy should use the data-plane-api client - // :repo:`api/envoy/service/ratelimit/v2/rls.proto` or the legacy + // :repo:`api/envoy/service/ratelimit/v2/rls.proto` or the legacy // client :repo:`source/common/ratelimit/ratelimit.proto` when - // making requests to the rate limit service. + // making requests to the rate limit service. // // .. note:: // diff --git a/api/envoy/config/rbac/v2alpha/BUILD b/api/envoy/config/rbac/v2alpha/BUILD index 396982264e3a3..f24c8594ad2eb 100644 --- a/api/envoy/config/rbac/v2alpha/BUILD +++ b/api/envoy/config/rbac/v2alpha/BUILD @@ -1,15 +1,15 @@ licenses(["notice"]) # Apache 2 -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") -api_proto_library( +api_proto_library_internal( name = "rbac", srcs = ["rbac.proto"], visibility = ["//visibility:public"], deps = [ "//envoy/api/v2/core:address", "//envoy/api/v2/route", - "//envoy/type:string_match", + "//envoy/type/matcher:metadata", ], ) @@ -19,6 +19,6 @@ api_go_proto_library( deps = [ "//envoy/api/v2/core:address_go_proto", "//envoy/api/v2/route:route_go_proto", - "//envoy/type:string_match_go_proto", + "//envoy/type/matcher:metadata_go_proto", ], ) diff --git a/api/envoy/config/rbac/v2alpha/rbac.proto b/api/envoy/config/rbac/v2alpha/rbac.proto index cb9e53b5d9b12..ab32aaf475fd8 100644 --- a/api/envoy/config/rbac/v2alpha/rbac.proto +++ b/api/envoy/config/rbac/v2alpha/rbac.proto @@ -3,6 +3,7 @@ syntax = "proto3"; import "validate/validate.proto"; import "envoy/api/v2/core/address.proto"; import "envoy/api/v2/route/route.proto"; +import "envoy/type/matcher/metadata.proto"; package envoy.config.rbac.v2alpha; option go_package = "v2alpha"; @@ -15,12 +16,11 @@ option go_package = "v2alpha"; // // Here is an example of RBAC configuration. It has two policies: // -// * Service account "cluster.local/ns/default/sa/admin" has full access (empty permission entry -// means full access) to the service. +// * Service account "cluster.local/ns/default/sa/admin" has full access to the service, and so +// does "cluster.local/ns/default/sa/superuser". // -// * Any user (empty principal entry means any user) can read ("GET") the service at paths with -// prefix "/products" or suffix "/reviews" when request header "version" set to either "v1" or -// "v2". +// * Any user can read ("GET") the service at paths with prefix "/products", so long as the +// destination port is either 80 or 443. // // .. code-block:: yaml // @@ -111,6 +111,14 @@ message Permission { // A port number that describes the destination port connecting to. uint32 destination_port = 6 [(validate.rules).uint32.lte = 65535]; + + // Metadata that describes additional information about the action. + envoy.type.matcher.MetadataMatcher metadata = 7; + + // Negates matching the provided permission. For instance, if the value of `not_rule` would + // match, this permission would not match. Conversely, if the value of `not_rule` would not + // match, this permission would match. + Permission not_rule = 8; } } @@ -150,5 +158,13 @@ message Principal { // A header (or psuedo-header such as :path or :method) on the incoming HTTP request. envoy.api.v2.route.HeaderMatcher header = 6; + + // Metadata that describes additional information about the principal. + envoy.type.matcher.MetadataMatcher metadata = 7; + + // Negates matching the provided principal. For instance, if the value of `not_id` would match, + // this principal would not match. Conversely, if the value of `not_id` would not match, this + // principal would match. + Principal not_id = 8; } } diff --git a/api/envoy/config/resource_monitor/fixed_heap/v2alpha/BUILD b/api/envoy/config/resource_monitor/fixed_heap/v2alpha/BUILD new file mode 100644 index 0000000000000..adc77e5b5e0d3 --- /dev/null +++ b/api/envoy/config/resource_monitor/fixed_heap/v2alpha/BUILD @@ -0,0 +1,8 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "fixed_heap", + srcs = ["fixed_heap.proto"], +) diff --git a/api/envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.proto b/api/envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.proto new file mode 100644 index 0000000000000..08e3c6536f5d3 --- /dev/null +++ b/api/envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package envoy.config.resource_monitor.fixed_heap.v2alpha; +option go_package = "v2alpha"; + +message FixedHeapConfig { + // Limit of the Envoy process heap size. This is used to calculate heap memory pressure which + // is defined as (current heap size)/max_heap_size_bytes. + uint64 max_heap_size_bytes = 1; +} diff --git a/api/envoy/config/trace/v2/BUILD b/api/envoy/config/trace/v2/BUILD index b888bd1b8e400..518395f230707 100644 --- a/api/envoy/config/trace/v2/BUILD +++ b/api/envoy/config/trace/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "trace", srcs = ["trace.proto"], visibility = [ diff --git a/api/envoy/config/transport_socket/capture/v2alpha/BUILD b/api/envoy/config/transport_socket/capture/v2alpha/BUILD index 1786d008b9e7d..bd25da3e6c7ea 100644 --- a/api/envoy/config/transport_socket/capture/v2alpha/BUILD +++ b/api/envoy/config/transport_socket/capture/v2alpha/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "capture", srcs = ["capture.proto"], deps = [ diff --git a/api/envoy/data/accesslog/v2/BUILD b/api/envoy/data/accesslog/v2/BUILD index 21c1ea449e775..8ecfdd5b6d119 100644 --- a/api/envoy/data/accesslog/v2/BUILD +++ b/api/envoy/data/accesslog/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "accesslog", srcs = ["accesslog.proto"], visibility = [ diff --git a/api/envoy/data/core/v2alpha/BUILD b/api/envoy/data/core/v2alpha/BUILD new file mode 100644 index 0000000000000..740e4304cca72 --- /dev/null +++ b/api/envoy/data/core/v2alpha/BUILD @@ -0,0 +1,15 @@ +load("//bazel:api_build_system.bzl", "api_proto_library") + +licenses(["notice"]) # Apache 2 + +api_proto_library( + name = "health_check_event", + srcs = ["health_check_event.proto"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//envoy/api/v2/core:address", + "//envoy/api/v2/core:base", + ], +) diff --git a/api/envoy/data/core/v2alpha/health_check_event.proto b/api/envoy/data/core/v2alpha/health_check_event.proto new file mode 100644 index 0000000000000..5c9e28f6846dd --- /dev/null +++ b/api/envoy/data/core/v2alpha/health_check_event.proto @@ -0,0 +1,58 @@ +syntax = "proto3"; + +package envoy.data.core.v2alpha; + +import "envoy/api/v2/core/address.proto"; +import "envoy/api/v2/core/base.proto"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/wrappers.proto"; + +import "validate/validate.proto"; +import "gogoproto/gogo.proto"; + +option (gogoproto.equal_all) = true; + +// [#protodoc-title: Health check logging events] +// :ref:`Health check logging `. + +message HealthCheckEvent { + HealthCheckerType health_checker_type = 1 [(validate.rules).enum.defined_only = true]; + envoy.api.v2.core.Address host = 2; + string cluster_name = 3 [(validate.rules).string.min_bytes = 1]; + + oneof event { + option (validate.required) = true; + + // Host ejection. + HealthCheckEjectUnhealthy eject_unhealthy_event = 4; + + // Host addition. + HealthCheckAddHealthy add_healthy_event = 5; + } +} + +enum HealthCheckFailureType { + ACTIVE = 0; + PASSIVE = 1; + NETWORK = 2; +} + +enum HealthCheckerType { + HTTP = 0; + TCP = 1; + GRPC = 2; + REDIS = 3; +} + +message HealthCheckEjectUnhealthy { + // The type of failure that caused this ejection. + HealthCheckFailureType failure_type = 1 [(validate.rules).enum.defined_only = true]; +} + +message HealthCheckAddHealthy { + // Whether this addition is the result of the first ever health check on a host, in which case + // the configured :ref:`healthy threshold ` + // is bypassed and the host is immediately added. + bool first_check = 1; +} diff --git a/api/envoy/data/tap/v2alpha/BUILD b/api/envoy/data/tap/v2alpha/BUILD index 2211bb37ca5bc..46de68e3a825b 100644 --- a/api/envoy/data/tap/v2alpha/BUILD +++ b/api/envoy/data/tap/v2alpha/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "capture", srcs = ["capture.proto"], deps = ["//envoy/api/v2/core:address"], diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD deleted file mode 100644 index 19eea2ec3bfa2..0000000000000 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD +++ /dev/null @@ -1,8 +0,0 @@ -load("//bazel:api_build_system.bzl", "api_proto_library") - -licenses(["notice"]) # Apache 2 - -api_proto_library( - name = "thrift_proxy", - srcs = ["thrift_proxy.proto"], -) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto deleted file mode 100644 index e2d6bd02cb261..0000000000000 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; - -package envoy.extensions.filters.network.thrift_proxy.v2alpha1; -option go_package = "v2"; - -import "validate/validate.proto"; - -// [#protodoc-title: Extensions Thrift Proxy] -// Thrift Proxy filter configuration. -message ThriftProxy { - // The human readable prefix to use when emitting statistics. - string stat_prefix = 1 [(validate.rules).string.min_bytes = 1]; -} diff --git a/api/envoy/service/accesslog/v2/BUILD b/api/envoy/service/accesslog/v2/BUILD index c5073996f7189..e6e389e22a02f 100644 --- a/api/envoy/service/accesslog/v2/BUILD +++ b/api/envoy/service/accesslog/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "als", srcs = ["als.proto"], has_services = 1, diff --git a/api/envoy/service/auth/v2alpha/BUILD b/api/envoy/service/auth/v2alpha/BUILD index 323a49eee7dff..5faba48ac3dbc 100644 --- a/api/envoy/service/auth/v2alpha/BUILD +++ b/api/envoy/service/auth/v2alpha/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "attribute_context", srcs = [ "attribute_context.proto", @@ -12,7 +12,7 @@ api_proto_library( ], ) -api_proto_library( +api_proto_library_internal( name = "external_auth", srcs = [ "external_auth.proto", @@ -20,5 +20,7 @@ api_proto_library( has_services = 1, deps = [ ":attribute_context", + "//envoy/api/v2/core:base", + "//envoy/type:http_status", ], ) diff --git a/api/envoy/service/auth/v2alpha/external_auth.proto b/api/envoy/service/auth/v2alpha/external_auth.proto index 601c4dea6c218..caa5e3089573b 100644 --- a/api/envoy/service/auth/v2alpha/external_auth.proto +++ b/api/envoy/service/auth/v2alpha/external_auth.proto @@ -4,6 +4,8 @@ package envoy.service.auth.v2alpha; option go_package = "v2alpha"; option java_generic_services = true; +import "envoy/api/v2/core/base.proto"; +import "envoy/type/http_status.proto"; import "envoy/service/auth/v2alpha/attribute_context.proto"; import "google/rpc/status.proto"; @@ -27,21 +29,45 @@ message CheckRequest { AttributeContext attributes = 1; } +// HTTP attributes for a denied response. +message DeniedHttpResponse { + // This field allows the authorization service to send a HTTP response status + // code to the downstream client other than 403 (Forbidden). + envoy.type.HttpStatus status = 1 [(validate.rules).message.required = true]; + + // This field allows the authorization service to send HTTP response headers + // to the the downstream client. + repeated envoy.api.v2.core.HeaderValueOption headers = 2; + + // This field allows the authorization service to send a response body data + // to the the downstream client. + string body = 3; +} + +// HTTP attributes for an ok response. +message OkHttpResponse { + // HTTP entity headers in addition to the original request headers. This allows the authorization + // service to append, to add or to override headers from the original request before + // dispatching it to the upstream. By setting `append` field to `true` in the `HeaderValueOption`, + // the filter will append the correspondent header value to the matched request header. Note that + // by Leaving `append` as false, the filter will either add a new header, or override an existing + // one if there is a match. + repeated envoy.api.v2.core.HeaderValueOption headers = 2; +} + +// Intended for gRPC and Network Authorization servers `only`. message CheckResponse { // Status `OK` allows the request. Any other status indicates the request should be denied. google.rpc.Status status = 1; - // An optional message that contains HTTP response attributes. This message is + // An message that contains HTTP response attributes. This message is // used when the authorization service needs to send custom responses to the // downstream client or, to modify/add request headers being dispatched to the upstream. - message HttpResponse { - // Http status code. - uint32 status_code = 1 [(validate.rules).uint32 = {gte: 100, lt: 600}]; - - // Http entity headers. - map headers = 2; + oneof http_response { + // Supplies http attributes for a denied response. + DeniedHttpResponse denied_response = 2; - // Http entity body. - string body = 3; + // Supplies http attributes for an ok response. + OkHttpResponse ok_response = 3; } } diff --git a/api/envoy/service/discovery/v2/BUILD b/api/envoy/service/discovery/v2/BUILD index f0a67f206c8dd..ac652cf1859a4 100644 --- a/api/envoy/service/discovery/v2/BUILD +++ b/api/envoy/service/discovery/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "ads", srcs = ["ads.proto"], has_services = 1, @@ -19,7 +19,7 @@ api_go_grpc_library( ], ) -api_proto_library( +api_proto_library_internal( name = "hds", srcs = ["hds.proto"], has_services = 1, @@ -40,7 +40,7 @@ api_go_grpc_library( ], ) -api_proto_library( +api_proto_library_internal( name = "sds", srcs = ["sds.proto"], has_services = 1, diff --git a/api/envoy/service/discovery/v2/ads.proto b/api/envoy/service/discovery/v2/ads.proto index 821ccb341db52..16953ee7b9a6c 100644 --- a/api/envoy/service/discovery/v2/ads.proto +++ b/api/envoy/service/discovery/v2/ads.proto @@ -27,4 +27,8 @@ service AggregatedDiscoveryService { rpc StreamAggregatedResources(stream envoy.api.v2.DiscoveryRequest) returns (stream envoy.api.v2.DiscoveryResponse) { } + + rpc IncrementalAggregatedResources(stream envoy.api.v2.IncrementalDiscoveryRequest) + returns (stream envoy.api.v2.IncrementalDiscoveryResponse) { + } } diff --git a/api/envoy/service/load_stats/v2/BUILD b/api/envoy/service/load_stats/v2/BUILD index 4068eafb3e973..66294100bf701 100644 --- a/api/envoy/service/load_stats/v2/BUILD +++ b/api/envoy/service/load_stats/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "lrs", srcs = ["lrs.proto"], has_services = 1, diff --git a/api/envoy/service/load_stats/v2/lrs.proto b/api/envoy/service/load_stats/v2/lrs.proto index 2181fa0ff16f9..43971649504fa 100644 --- a/api/envoy/service/load_stats/v2/lrs.proto +++ b/api/envoy/service/load_stats/v2/lrs.proto @@ -63,6 +63,12 @@ message LoadStatsResponse { // Clusters to report stats for. repeated string clusters = 1 [(validate.rules).repeated .min_items = 1]; - // The interval of time to collect stats. The default is 10 seconds. + // The minimum interval of time to collect stats over. This is only a minimum for two reasons: + // 1. There may be some delay from when the timer fires until stats sampling occurs. + // 2. For clusters that were already feature in the previous *LoadStatsResponse*, any traffic + // that is observed in between the corresponding previous *LoadStatsRequest* and this + // *LoadStatsResponse* will also be accumulated and billed to the cluster. This avoids a period + // of inobservability that might otherwise exists between the messages. New clusters are not + // subject to this consideration. google.protobuf.Duration load_reporting_interval = 2; } diff --git a/api/envoy/service/metrics/v2/BUILD b/api/envoy/service/metrics/v2/BUILD index bbad50c789e5d..6d14bfe414796 100644 --- a/api/envoy/service/metrics/v2/BUILD +++ b/api/envoy/service/metrics/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "metrics_service", srcs = ["metrics_service.proto"], has_services = 1, @@ -10,7 +10,7 @@ api_proto_library( deps = [ "//envoy/api/v2/core:base", "//envoy/api/v2/core:grpc_service", - "@promotheus_metrics_model//:client_model", + "@prometheus_metrics_model//:client_model", ], ) @@ -19,6 +19,6 @@ api_go_grpc_library( proto = ":metrics_service", deps = [ "//envoy/api/v2/core:base_go_proto", - "@promotheus_metrics_model//:client_model_go_proto", + "@prometheus_metrics_model//:client_model_go_proto", ], ) diff --git a/api/envoy/service/ratelimit/v2/BUILD b/api/envoy/service/ratelimit/v2/BUILD index be6fdbc915ee0..4ee72b6518882 100644 --- a/api/envoy/service/ratelimit/v2/BUILD +++ b/api/envoy/service/ratelimit/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "rls", srcs = ["rls.proto"], has_services = 1, diff --git a/api/envoy/service/trace/v2/BUILD b/api/envoy/service/trace/v2/BUILD index a5f13f2c482e7..49c935f12938d 100644 --- a/api/envoy/service/trace/v2/BUILD +++ b/api/envoy/service/trace/v2/BUILD @@ -1,8 +1,8 @@ -load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_grpc_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( name = "trace_service", srcs = ["trace_service.proto"], has_services = 1, diff --git a/api/envoy/type/BUILD b/api/envoy/type/BUILD index 4859476efbd9d..150e226517b50 100644 --- a/api/envoy/type/BUILD +++ b/api/envoy/type/BUILD @@ -1,8 +1,19 @@ -load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") licenses(["notice"]) # Apache 2 -api_proto_library( +api_proto_library_internal( + name = "http_status", + srcs = ["http_status.proto"], + visibility = ["//visibility:public"], +) + +api_go_proto_library( + name = "http_status", + proto = ":http_status", +) + +api_proto_library_internal( name = "percent", srcs = ["percent.proto"], visibility = ["//visibility:public"], @@ -13,7 +24,7 @@ api_go_proto_library( proto = ":percent", ) -api_proto_library( +api_proto_library_internal( name = "range", srcs = ["range.proto"], visibility = ["//visibility:public"], @@ -23,14 +34,3 @@ api_go_proto_library( name = "range", proto = ":range", ) - -api_proto_library( - name = "string_match", - srcs = ["string_match.proto"], - visibility = ["//visibility:public"], -) - -api_go_proto_library( - name = "string_match", - proto = ":string_match", -) diff --git a/api/envoy/type/http_status.proto b/api/envoy/type/http_status.proto new file mode 100644 index 0000000000000..35655613c198c --- /dev/null +++ b/api/envoy/type/http_status.proto @@ -0,0 +1,81 @@ +syntax = "proto3"; + +package envoy.type; + +import "validate/validate.proto"; + +// HTTP response codes supported in Envoy. +// For more details: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml +enum StatusCode { + // Empty - This code not part of the HTTP status code specification, but it is needed for proto + // `enum` type. + Empty = 0; + + Continue = 100; + + OK = 200; + Created = 201; + Accepted = 202; + NonAuthoritativeInformation = 203; + NoContent = 204; + ResetContent = 205; + PartialContent = 206; + MultiStatus = 207; + AlreadyReported = 208; + IMUsed = 226; + + MultipleChoices = 300; + MovedPermanently = 301; + Found = 302; + SeeOther = 303; + NotModified = 304; + UseProxy = 305; + TemporaryRedirect = 307; + PermanentRedirect = 308; + + BadRequest = 400; + Unauthorized = 401; + PaymentRequired = 402; + Forbidden = 403; + NotFound = 404; + MethodNotAllowed = 405; + NotAcceptable = 406; + ProxyAuthenticationRequired = 407; + RequestTimeout = 408; + Conflict = 409; + Gone = 410; + LengthRequired = 411; + PreconditionFailed = 412; + PayloadTooLarge = 413; + URITooLong = 414; + UnsupportedMediaType = 415; + RangeNotSatisfiable = 416; + ExpectationFailed = 417; + MisdirectedRequest = 421; + UnprocessableEntity = 422; + Locked = 423; + FailedDependency = 424; + UpgradeRequired = 426; + PreconditionRequired = 428; + TooManyRequests = 429; + RequestHeaderFieldsTooLarge = 431; + + InternalServerError = 500; + NotImplemented = 501; + BadGateway = 502; + ServiceUnavailable = 503; + GatewayTimeout = 504; + HTTPVersionNotSupported = 505; + VariantAlsoNegotiates = 506; + InsufficientStorage = 507; + LoopDetected = 508; + NotExtended = 510; + NetworkAuthenticationRequired = 511; +} + +// HTTP status. +message HttpStatus { + // Supplies HTTP response code. + StatusCode code = 1 + [(validate.rules).enum = {not_in: [0]}, (validate.rules).enum.defined_only = true]; +} diff --git a/api/envoy/type/matcher/BUILD b/api/envoy/type/matcher/BUILD new file mode 100644 index 0000000000000..eb261e6f7ddbb --- /dev/null +++ b/api/envoy/type/matcher/BUILD @@ -0,0 +1,50 @@ +load("//bazel:api_build_system.bzl", "api_go_proto_library", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "metadata", + srcs = ["metadata.proto"], + visibility = ["//visibility:public"], + deps = [ + ":number", + ":string", + ], +) + +api_go_proto_library( + name = "metadata", + proto = ":metadata", + deps = [ + ":number_go_proto", + ":string_go_proto", + ], +) + +api_proto_library_internal( + name = "number", + srcs = ["number.proto"], + visibility = ["//visibility:public"], + deps = [ + "//envoy/type:range", + ], +) + +api_go_proto_library( + name = "number", + proto = ":number", + deps = [ + "//envoy/type:range_go_proto", + ], +) + +api_proto_library_internal( + name = "string", + srcs = ["string.proto"], + visibility = ["//visibility:public"], +) + +api_go_proto_library( + name = "string", + proto = ":string", +) diff --git a/api/envoy/type/matcher/metadata.proto b/api/envoy/type/matcher/metadata.proto new file mode 100644 index 0000000000000..f899bc1305251 --- /dev/null +++ b/api/envoy/type/matcher/metadata.proto @@ -0,0 +1,123 @@ +syntax = "proto3"; + +package envoy.type.matcher; +option go_package = "matcher"; + +import "envoy/type/matcher/string.proto"; +import "envoy/type/matcher/number.proto"; + +import "validate/validate.proto"; + +// [#protodoc-title: MetadataMatcher] + +// MetadataMatcher provides a general interface to check if a given value is matched in +// :ref:`Metadata `. It uses `filter` and `path` to retrieve the value +// from the Metadata and then check if it's matched to the specified value. +// +// For example, for the following Metadata: +// +// .. code-block:: yaml +// +// filter_metadata: +// envoy.filters.http.rbac: +// fields: +// a: +// struct_value: +// fields: +// b: +// struct_value: +// fields: +// c: +// string_value: pro +// t: +// list_value: +// values: +// - string_value: m +// - string_value: n +// +// The following MetadataMatcher is matched as the path [a, b, c] will retrieve a string value "pro" +// from the Metadata which is matched to the specified prefix match. +// +// .. code-block:: yaml +// +// filter: envoy.filters.http.rbac +// path: +// - key: a +// - key: b +// - key: c +// value: +// string_match: +// prefix: pr +// +// The following MetadataMatcher is not matched as the path [a, t] is pointing to a list value in +// the Metadata which is not supported for now. +// +// .. code-block:: yaml +// +// filter: envoy.filters.http.rbac +// path: +// - key: a +// - key: t +// value: +// string_match: +// exact: m +// +// An example use of MetadataMatcher is specifying additional metadata in envoy.filters.http.rbac to +// enforce access control based on dynamic metadata in a request. See :ref:`Permission +// ` and :ref:`Principal +// `. +message MetadataMatcher { + // Specifies the segment in a path to retrieve value from Metadata. + // Note: Currently it's not supported to retrieve a value from a list in Metadata. This means it + // will always be not matched if the associated value of the key is a list. + message PathSegment { + oneof segment { + option (validate.required) = true; + + // If specified, use the key to retrieve the value in a Struct. + string key = 1 [(validate.rules).string.min_bytes = 1]; + } + } + + // Specifies the value to match. Only primitive value are supported. For non-primitive values, the + // result is always not matched. + message Value { + // NullMatch is an empty message to specify a null value. + message NullMatch { + } + + // Specifies how to match a value. + oneof match_pattern { + option (validate.required) = true; + + // If specified, a match occurs if and only if the target value is a NullValue. + NullMatch null_match = 1; + + // If specified, a match occurs if and only if the target value is a double value and is + // matched to this field. + DoubleMatcher double_match = 2; + + // If specified, a match occurs if and only if the target value is a string value and is + // matched to this field. + StringMatcher string_match = 3; + + // If specified, a match occurs if and only if the target value is a bool value and is equal + // to this field. + bool bool_match = 4; + + // If specified, value match will be performed based on whether the path is referring to a + // valid primitive value in the metadata. If the path is referring to a non-primitive value, + // the result is always not matched. + bool present_match = 5; + } + } + + // The filter name to retrieve the Struct from the Metadata. + string filter = 1 [(validate.rules).string.min_bytes = 1]; + + // The path to retrieve the Value from the Struct. + repeated PathSegment path = 2 [(validate.rules).repeated .min_items = 1]; + + // The MetadataMatcher is matched if the value retrieved by path is matched to this value. + Value value = 3 [(validate.rules).message.required = true]; +} diff --git a/api/envoy/type/matcher/number.proto b/api/envoy/type/matcher/number.proto new file mode 100644 index 0000000000000..9cf4ff1f10458 --- /dev/null +++ b/api/envoy/type/matcher/number.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package envoy.type.matcher; +option go_package = "matcher"; + +import "envoy/type/range.proto"; + +import "validate/validate.proto"; + +// [#protodoc-title: NumberMatcher] + +// Specifies the way to match a double value. +message DoubleMatcher { + oneof match_pattern { + option (validate.required) = true; + + // If specified, the input double value must be in the range specified here. + // Note: The range is using half-open interval semantics [start, end). + envoy.type.DoubleRange range = 1; + + // If specified, the input double value must be equal to the value specified here. + double exact = 2; + } +} diff --git a/api/envoy/type/matcher/string.proto b/api/envoy/type/matcher/string.proto new file mode 100644 index 0000000000000..afb419a613b39 --- /dev/null +++ b/api/envoy/type/matcher/string.proto @@ -0,0 +1,49 @@ +syntax = "proto3"; + +package envoy.type.matcher; +option go_package = "matcher"; + +import "validate/validate.proto"; + +// [#protodoc-title: StringMatcher] + +// Specifies the way to match a string. +message StringMatcher { + oneof match_pattern { + option (validate.required) = true; + + // The input string must match exactly the string specified here. + // + // Examples: + // + // * *abc* only matches the value *abc*. + string exact = 1; + + // The input string must have the prefix specified here. + // Note: empty prefix is not allowed, please use regex instead. + // + // Examples: + // + // * *abc* matches the value *abc.xyz* + string prefix = 2 [(validate.rules).string.min_bytes = 1]; + + // The input string must have the suffix specified here. + // Note: empty prefix is not allowed, please use regex instead. + // + // Examples: + // + // * *abc* matches the value *xyz.abc* + string suffix = 3 [(validate.rules).string.min_bytes = 1]; + + // The input string must match the regular expression specified here. + // The regex grammar is defined `here + // `_. + // + // Examples: + // + // * The regex *\d{3}* matches the value *123* + // * The regex *\d{3}* does not match the value *1234* + // * The regex *\d{3}* does not match the value *123.456* + string regex = 4; + } +} diff --git a/api/envoy/type/range.proto b/api/envoy/type/range.proto index fd6045e7fd289..115091ddf9f69 100644 --- a/api/envoy/type/range.proto +++ b/api/envoy/type/range.proto @@ -18,3 +18,13 @@ message Int64Range { // end of the range (exclusive) int64 end = 2; } + +// Specifies the double start and end of the range using half-open interval semantics [start, +// end). +message DoubleRange { + // start of the range (inclusive) + double start = 1; + + // end of the range (exclusive) + double end = 2; +} diff --git a/api/envoy/type/string_match.proto b/api/envoy/type/string_match.proto deleted file mode 100644 index c1e2468ad5899..0000000000000 --- a/api/envoy/type/string_match.proto +++ /dev/null @@ -1,30 +0,0 @@ -syntax = "proto3"; - -package envoy.type; -option go_package = "envoy_type"; - -import "gogoproto/gogo.proto"; - -option (gogoproto.equal_all) = true; - -// [#protodoc-title: StringMatch] - -// Specifies the way to match a string. -message StringMatch { - oneof match_pattern { - // The input string must match exactly the string specified here. - // Or it is a "*", which means that it matches any string. - string simple = 1; - - // The input string must have the prefix specified here. - string prefix = 2; - - // The input string must have the suffix specified here. - string suffix = 3; - - // The input string must match the regular expression specified here. - // The regex grammar is defined `here - // `_. - string regex = 4; - } -} diff --git a/api/test/validate/BUILD b/api/test/validate/BUILD index 2707e02cda541..2c98249c78859 100644 --- a/api/test/validate/BUILD +++ b/api/test/validate/BUILD @@ -1,4 +1,4 @@ -load("//bazel:api_build_system.bzl", "api_cc_test", "api_proto_library") +load("//bazel:api_build_system.bzl", "api_cc_test", "api_proto_library_internal") licenses(["notice"]) # Apache 2 diff --git a/bazel/BUILD b/bazel/BUILD index 6a5258d5440b0..223d2d0b7ee19 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -35,6 +35,35 @@ genrule( stamp = 1, ) +config_setting( + name = "windows_x86_64", + values = {"cpu": "x64_windows"}, +) + +config_setting( + name = "windows_opt_build", + values = { + "cpu": "x64_windows", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "windows_dbg_build", + values = { + "cpu": "x64_windows", + "compilation_mode": "dbg", + }, +) + +config_setting( + name = "windows_fastbuild_build", + values = { + "cpu": "x64_windows", + "compilation_mode": "fastbuild", + }, +) + config_setting( name = "opt_build", values = {"compilation_mode": "opt"}, diff --git a/bazel/README.md b/bazel/README.md index 6b9631f327113..0b68e7ec16e85 100644 --- a/bazel/README.md +++ b/bazel/README.md @@ -27,7 +27,7 @@ up-to-date with the latest security patches. See for how to update or override dependencies. 1. Install the latest version of [Bazel](https://bazel.build/versions/master/docs/install.html) in your environment. -2. Install external dependencies libtool, cmake, and realpath libraries separately. +2. Install external dependencies libtool, cmake, ninja, and realpath libraries separately. On Ubuntu, run the following commands: ``` apt-get install libtool @@ -35,6 +35,7 @@ On Ubuntu, run the following commands: apt-get install realpath apt-get install clang-format-5.0 apt-get install automake + apt-get install ninja-build ``` On Fedora (maybe also other red hat distros), run the following: @@ -51,6 +52,7 @@ brew install libtool brew install go brew install bazel brew install automake +brew install ninja ``` Envoy compiles and passes tests with the version of clang installed by XCode 9.3.0: @@ -353,6 +355,14 @@ then log back in and it should start working. The latest coverage report for master is available [here](https://s3.amazonaws.com/lyft-envoy/coverage/report-master/coverage.html). +It's also possible to specialize the coverage build to a single test target. This is useful +when doing things like exploring the coverage of a fuzzer over its corpus. This can be done with +the `COVERAGE_TARGET` and `VALIDATE_COVERAGE` environment variables, e.g.: + +``` +COVERAGE_TARGET=//test/common/common:base64_fuzz_test VALIDATE_COVERAGE=false test/run_envoy_bazel_coverage.sh +``` + # Cleaning the build and test artifacts `bazel clean` will nuke all the build/test artifacts from the Bazel cache for diff --git a/bazel/cc_configure.bzl b/bazel/cc_configure.bzl index 35b005557c6ea..eb1dead6b260a 100644 --- a/bazel/cc_configure.bzl +++ b/bazel/cc_configure.bzl @@ -5,86 +5,86 @@ load("@bazel_tools//tools/cpp:unix_cc_configure.bzl", "find_cc") # Stub for `repository_ctx.which()` that always succeeds. See comments in # `_find_cxx` for details. def _quiet_fake_which(program): - return struct(_envoy_fake_which = program) + return struct(_envoy_fake_which = program) # Stub for `repository_ctx.which()` that always fails. See comments in # `_find_cxx` for details. def _noisy_fake_which(program): - return None + return None # Find a good path for the C++ compiler, by hooking into Bazel's C compiler # detection. Uses `$CXX` if found, otherwise defaults to `g++` because Bazel # defaults to `gcc`. def _find_cxx(repository_ctx): - # Bazel's `find_cc` helper uses the repository context to inspect `$CC`. - # Replace this value with `$CXX` if set. - environ_cxx = repository_ctx.os.environ.get("CXX", "g++") - fake_os = struct( - environ = {"CC": environ_cxx}, - ) + # Bazel's `find_cc` helper uses the repository context to inspect `$CC`. + # Replace this value with `$CXX` if set. + environ_cxx = repository_ctx.os.environ.get("CXX", "g++") + fake_os = struct( + environ = {"CC": environ_cxx}, + ) - # We can't directly assign `repository_ctx.which` to a struct attribute - # because Skylark doesn't support bound method references. Instead, stub - # out `which()` using a two-pass approach: - # - # * The first pass uses a stub that always succeeds, passing back a special - # value containing the original parameter. - # * If we detect the special value, we know that `find_cc` found a compiler - # name but don't know if that name could be resolved to an executable path. - # So do the `which()` call ourselves. - # * If our `which()` failed, call `find_cc` again with a dummy which that - # always fails. The error raised by `find_cc` will be identical to what Bazel - # would generate for a missing C compiler. - # - # See https://github.com/bazelbuild/bazel/issues/4644 for more context. - real_cxx = find_cc(struct( - which = _quiet_fake_which, - os = fake_os, - ), {}) - if hasattr(real_cxx, "_envoy_fake_which"): - real_cxx = repository_ctx.which(real_cxx._envoy_fake_which) - if real_cxx == None: - find_cc(struct( - which = _noisy_fake_which, + # We can't directly assign `repository_ctx.which` to a struct attribute + # because Skylark doesn't support bound method references. Instead, stub + # out `which()` using a two-pass approach: + # + # * The first pass uses a stub that always succeeds, passing back a special + # value containing the original parameter. + # * If we detect the special value, we know that `find_cc` found a compiler + # name but don't know if that name could be resolved to an executable path. + # So do the `which()` call ourselves. + # * If our `which()` failed, call `find_cc` again with a dummy which that + # always fails. The error raised by `find_cc` will be identical to what Bazel + # would generate for a missing C compiler. + # + # See https://github.com/bazelbuild/bazel/issues/4644 for more context. + real_cxx = find_cc(struct( + which = _quiet_fake_which, os = fake_os, - ), {}) - return real_cxx + ), {}) + if hasattr(real_cxx, "_envoy_fake_which"): + real_cxx = repository_ctx.which(real_cxx._envoy_fake_which) + if real_cxx == None: + find_cc(struct( + which = _noisy_fake_which, + os = fake_os, + ), {}) + return real_cxx def _build_envoy_cc_wrapper(repository_ctx): - real_cc = find_cc(repository_ctx, {}) - real_cxx = _find_cxx(repository_ctx) + real_cc = find_cc(repository_ctx, {}) + real_cxx = _find_cxx(repository_ctx) - # Copy our CC wrapper script into @local_config_cc, with the true paths - # to the C and C++ compiler injected in. The wrapper will use these paths - # to invoke the compiler after deciding which one is correct for the current - # invocation. - # - # Since the script is Python, we can inject values using `repr(str(value))` - # and escaping will be handled correctly. - repository_ctx.template("extra_tools/envoy_cc_wrapper", repository_ctx.attr._envoy_cc_wrapper, { - "{ENVOY_REAL_CC}": repr(str(real_cc)), - "{ENVOY_REAL_CXX}": repr(str(real_cxx)), - }) - return repository_ctx.path("extra_tools/envoy_cc_wrapper") + # Copy our CC wrapper script into @local_config_cc, with the true paths + # to the C and C++ compiler injected in. The wrapper will use these paths + # to invoke the compiler after deciding which one is correct for the current + # invocation. + # + # Since the script is Python, we can inject values using `repr(str(value))` + # and escaping will be handled correctly. + repository_ctx.template("extra_tools/envoy_cc_wrapper", repository_ctx.attr._envoy_cc_wrapper, { + "{ENVOY_REAL_CC}": repr(str(real_cc)), + "{ENVOY_REAL_CXX}": repr(str(real_cxx)), + }) + return repository_ctx.path("extra_tools/envoy_cc_wrapper") def _needs_envoy_cc_wrapper(repository_ctx): - # When building for Linux we set additional C++ compiler options that aren't - # handled well by Bazel, so we need a wrapper around $CC to fix its - # compiler invocations. - cpu_value = get_cpu_value(repository_ctx) - return cpu_value not in ["freebsd", "x64_windows", "darwin"] + # When building for Linux we set additional C++ compiler options that aren't + # handled well by Bazel, so we need a wrapper around $CC to fix its + # compiler invocations. + cpu_value = get_cpu_value(repository_ctx) + return cpu_value not in ["freebsd", "x64_windows", "darwin"] def cc_autoconf_impl(repository_ctx): - overriden_tools = {} - if _needs_envoy_cc_wrapper(repository_ctx): - # Bazel uses "gcc" as a generic name for all C and C++ compilers. - overriden_tools["gcc"] = _build_envoy_cc_wrapper(repository_ctx) - return _upstream_cc_autoconf_impl(repository_ctx, overriden_tools=overriden_tools) + overriden_tools = {} + if _needs_envoy_cc_wrapper(repository_ctx): + # Bazel uses "gcc" as a generic name for all C and C++ compilers. + overriden_tools["gcc"] = _build_envoy_cc_wrapper(repository_ctx) + return _upstream_cc_autoconf_impl(repository_ctx, overriden_tools = overriden_tools) cc_autoconf = repository_rule( implementation = cc_autoconf_impl, attrs = { - "_envoy_cc_wrapper": attr.label(default="@envoy//bazel:cc_wrapper.py"), + "_envoy_cc_wrapper": attr.label(default = "@envoy//bazel:cc_wrapper.py"), }, environ = [ "ABI_LIBC_VERSION", @@ -116,8 +116,10 @@ cc_autoconf = repository_rule( "VS100COMNTOOLS", "VS110COMNTOOLS", "VS120COMNTOOLS", - "VS140COMNTOOLS"]) + "VS140COMNTOOLS", + ], +) def cc_configure(): - cc_autoconf(name="local_config_cc") - native.bind(name="cc_toolchain", actual="@local_config_cc//:toolchain") + cc_autoconf(name = "local_config_cc") + native.bind(name = "cc_toolchain", actual = "@local_config_cc//:toolchain") diff --git a/bazel/envoy_build_system.bzl b/bazel/envoy_build_system.bzl index bf7885cc7d6c5..9f79cdf88565c 100644 --- a/bazel/envoy_build_system.bzl +++ b/bazel/envoy_build_system.bzl @@ -5,7 +5,7 @@ def envoy_package(): # Compute the final copts based on various options. def envoy_copts(repository, test = False): - return [ + posix_options = [ "-Wall", "-Wextra", "-Werror", @@ -13,81 +13,110 @@ def envoy_copts(repository, test = False): "-Woverloaded-virtual", "-Wold-style-cast", "-std=c++14", - ] + select({ - # Bazel adds an implicit -DNDEBUG for opt. - repository + "//bazel:opt_build": [] if test else ["-ggdb3"], - repository + "//bazel:fastbuild_build": [], - repository + "//bazel:dbg_build": ["-ggdb3"], - }) + select({ - repository + "//bazel:disable_tcmalloc": ["-DABSL_MALLOC_HOOK_MMAP_DISABLE"], - "//conditions:default": ["-DTCMALLOC"], - }) + select({ - repository + "//bazel:disable_signal_trace": [], - "//conditions:default": ["-DENVOY_HANDLE_SIGNALS"], - }) + select({ - # TCLAP command line parser needs this to support int64_t/uint64_t - "@bazel_tools//tools/osx:darwin": ["-DHAVE_LONG_LONG"], - "//conditions:default": [], - }) + envoy_select_hot_restart(["-DENVOY_HOT_RESTART"], repository) + \ - envoy_select_perf_annotation(["-DENVOY_PERF_ANNOTATION"]) + \ - envoy_select_google_grpc(["-DENVOY_GOOGLE_GRPC"], repository) + ] + + msvc_options = [ + "-WX", + "-DWIN32", + "-DWIN32_LEAN_AND_MEAN", + # need win8 for ntohll + # https://msdn.microsoft.com/en-us/library/windows/desktop/aa383745(v=vs.85).aspx + "-D_WIN32_WINNT=0x0602", + "-DNTDDI_VERSION=0x06020000", + "-DCARES_STATICLIB", + "-DNGHTTP2_STATICLIB", + ] + + return select({ + repository + "//bazel:windows_x86_64": msvc_options, + "//conditions:default": posix_options, + }) + select({ + # Bazel adds an implicit -DNDEBUG for opt. + repository + "//bazel:opt_build": [] if test else ["-ggdb3"], + repository + "//bazel:fastbuild_build": [], + repository + "//bazel:dbg_build": ["-ggdb3"], + repository + "//bazel:windows_opt_build": [], + repository + "//bazel:windows_fastbuild_build": [], + repository + "//bazel:windows_dbg_build": [], + }) + select({ + repository + "//bazel:disable_tcmalloc": ["-DABSL_MALLOC_HOOK_MMAP_DISABLE"], + "//conditions:default": ["-DTCMALLOC"], + }) + select({ + repository + "//bazel:disable_signal_trace": [], + "//conditions:default": ["-DENVOY_HANDLE_SIGNALS"], + }) + select({ + # TCLAP command line parser needs this to support int64_t/uint64_t + "@bazel_tools//tools/osx:darwin": ["-DHAVE_LONG_LONG"], + "//conditions:default": [], + }) + envoy_select_hot_restart(["-DENVOY_HOT_RESTART"], repository) + \ + envoy_select_perf_annotation(["-DENVOY_PERF_ANNOTATION"]) + \ + envoy_select_google_grpc(["-DENVOY_GOOGLE_GRPC"], repository) def envoy_static_link_libstdcpp_linkopts(): - return envoy_select_force_libcpp(["--stdlib=libc++"], - ["-static-libstdc++", "-static-libgcc"]) + return envoy_select_force_libcpp( + ["--stdlib=libc++"], + ["-static-libstdc++", "-static-libgcc"], + ) # Compute the final linkopts based on various options. def envoy_linkopts(): return select({ - # The OSX system library transitively links common libraries (e.g., pthread). + # The OSX system library transitively links common libraries (e.g., pthread). + "@bazel_tools//tools/osx:darwin": [ + # See note here: http://luajit.org/install.html + "-pagezero_size 10000", + "-image_base 100000000", + ], + "@envoy//bazel:windows_x86_64": [ + "-DEFAULTLIB:advapi32.lib", + ], + "//conditions:default": [ + "-pthread", + "-lrt", + "-ldl", + "-Wl,--hash-style=gnu", + ], + }) + envoy_static_link_libstdcpp_linkopts() + \ + envoy_select_exported_symbols(["-Wl,-E"]) + +def _envoy_stamped_linkopts(): + return select({ + # Coverage builds in CI are failing to link when setting a build ID. + # + # /usr/bin/ld.gold: internal error in write_build_id, at ../../gold/layout.cc:5419 + "@envoy//bazel:coverage_build": [], + "@envoy//bazel:windows_x86_64": [], + + # MacOS doesn't have an official equivalent to the `.note.gnu.build-id` + # ELF section, so just stuff the raw ID into a new text section. "@bazel_tools//tools/osx:darwin": [ - # See note here: http://luajit.org/install.html - "-pagezero_size 10000", "-image_base 100000000", + "-sectcreate __TEXT __build_id", + "$(location @envoy//bazel:raw_build_id.ldscript)", ], + + # Note: assumes GNU GCC (or compatible) handling of `--build-id` flag. "//conditions:default": [ - "-pthread", - "-lrt", - "-ldl", - '-Wl,--hash-style=gnu', + "-Wl,@$(location @envoy//bazel:gnu_build_id.ldscript)", ], - }) + envoy_static_link_libstdcpp_linkopts() \ - + envoy_select_exported_symbols(["-Wl,-E"]) - -def _envoy_stamped_linkopts(): - return select({ - # Coverage builds in CI are failing to link when setting a build ID. - # - # /usr/bin/ld.gold: internal error in write_build_id, at ../../gold/layout.cc:5419 - "@envoy//bazel:coverage_build": [], - - # MacOS doesn't have an official equivalent to the `.note.gnu.build-id` - # ELF section, so just stuff the raw ID into a new text section. - "@bazel_tools//tools/osx:darwin": [ - "-sectcreate __TEXT __build_id", "$(location @envoy//bazel:raw_build_id.ldscript)" - ], - - # Note: assumes GNU GCC (or compatible) handling of `--build-id` flag. - "//conditions:default": [ - "-Wl,@$(location @envoy//bazel:gnu_build_id.ldscript)", - ], - }) + }) def _envoy_stamped_deps(): - return select({ - "@bazel_tools//tools/osx:darwin": [ - "@envoy//bazel:raw_build_id.ldscript" - ], - "//conditions:default": [ - "@envoy//bazel:gnu_build_id.ldscript", - ], - }) + return select({ + "@bazel_tools//tools/osx:darwin": [ + "@envoy//bazel:raw_build_id.ldscript", + ], + "//conditions:default": [ + "@envoy//bazel:gnu_build_id.ldscript", + ], + }) # Compute the test linkopts based on various options. def envoy_test_linkopts(): return select({ "@bazel_tools//tools/osx:darwin": [ # See note here: http://luajit.org/install.html - "-pagezero_size 10000", "-image_base 100000000", + "-pagezero_size 10000", + "-image_base 100000000", ], # TODO(mattklein123): It's not great that we universally link against the following libs. @@ -119,8 +148,8 @@ def tcmalloc_external_deps(repository): # exporting the package headers at (e.g. envoy/common). Source files can then # include using this path scheme (e.g. #include "envoy/common/time.h"). def envoy_include_prefix(path): - if path.startswith('source/') or path.startswith('include/'): - return '/'.join(path.split('/')[1:]) + if path.startswith("source/") or path.startswith("include/"): + return "/".join(path.split("/")[1:]) return None # Envoy C++ library targets that need no transformations or additional dependencies before being @@ -131,20 +160,22 @@ def envoy_basic_cc_library(name, **kargs): native.cc_library(name = name, **kargs) # Envoy C++ library targets should be specified with this function. -def envoy_cc_library(name, - srcs = [], - hdrs = [], - copts = [], - visibility = None, - external_deps = [], - tcmalloc_dep = None, - repository = "", - linkstamp = None, - tags = [], - deps = [], - strip_include_prefix = None): +def envoy_cc_library( + name, + srcs = [], + hdrs = [], + copts = [], + visibility = None, + external_deps = [], + tcmalloc_dep = None, + repository = "", + linkstamp = None, + tags = [], + deps = [], + strip_include_prefix = None): if tcmalloc_dep: deps += tcmalloc_external_deps(repository) + native.cc_library( name = name, srcs = srcs, @@ -155,29 +186,32 @@ def envoy_cc_library(name, deps = deps + [envoy_external_dep_path(dep) for dep in external_deps] + [ repository + "//include/envoy/common:base_includes", repository + "//source/common/common:fmt_lib", - envoy_external_dep_path('abseil_strings'), - envoy_external_dep_path('spdlog'), - envoy_external_dep_path('fmtlib'), + envoy_external_dep_path("abseil_strings"), + envoy_external_dep_path("spdlog"), + envoy_external_dep_path("fmtlib"), ], include_prefix = envoy_include_prefix(PACKAGE_NAME), alwayslink = 1, linkstatic = 1, - linkstamp = linkstamp, + linkstamp = select({ + repository + "//bazel:windows_x86_64": None, + "//conditions:default": linkstamp, + }), strip_include_prefix = strip_include_prefix, - ) + ) # Envoy C++ binary targets should be specified with this function. -def envoy_cc_binary(name, - srcs = [], - data = [], - testonly = 0, - visibility = None, - external_deps = [], - repository = "", - stamped = False, - deps = [], - linkopts = []): - +def envoy_cc_binary( + name, + srcs = [], + data = [], + testonly = 0, + visibility = None, + external_deps = [], + repository = "", + stamped = False, + deps = [], + linkopts = []): if not linkopts: linkopts = envoy_linkopts() if stamped: @@ -233,20 +267,21 @@ def envoy_cc_fuzz_test(name, corpus, deps = [], **kwargs): ) # Envoy C++ test targets should be specified with this function. -def envoy_cc_test(name, - srcs = [], - data = [], - # List of pairs (Bazel shell script target, shell script args) - repository = "", - external_deps = [], - deps = [], - tags = [], - args = [], - coverage = True, - local = False): +def envoy_cc_test( + name, + srcs = [], + data = [], + # List of pairs (Bazel shell script target, shell script args) + repository = "", + external_deps = [], + deps = [], + tags = [], + args = [], + coverage = True, + local = False): test_lib_tags = [] if coverage: - test_lib_tags.append("coverage_test_lib") + test_lib_tags.append("coverage_test_lib") envoy_cc_test_library( name = name + "_lib", srcs = srcs, @@ -264,7 +299,7 @@ def envoy_cc_test(name, malloc = tcmalloc_external_dep(repository), deps = [ ":" + name + "_lib", - repository + "//test:main" + repository + "//test:main", ], # from https://github.com/google/googletest/blob/6e1970e2376c14bf658eb88f655a054030353f9f/googlemock/src/gmock.cc#L51 # 2 - by default, mocks act as StrictMocks. @@ -275,14 +310,15 @@ def envoy_cc_test(name, # Envoy C++ test related libraries (that want gtest, gmock) should be specified # with this function. -def envoy_cc_test_library(name, - srcs = [], - hdrs = [], - data = [], - external_deps = [], - deps = [], - repository = "", - tags = []): +def envoy_cc_test_library( + name, + srcs = [], + hdrs = [], + data = [], + external_deps = [], + deps = [], + repository = "", + tags = []): native.cc_library( name = name, srcs = srcs, @@ -291,7 +327,7 @@ def envoy_cc_test_library(name, copts = envoy_copts(repository, test = True), testonly = 1, deps = deps + [envoy_external_dep_path(dep) for dep in external_deps] + [ - envoy_external_dep_path('googletest'), + envoy_external_dep_path("googletest"), repository + "//test/test_common:printers_includes", ], tags = tags, @@ -300,18 +336,22 @@ def envoy_cc_test_library(name, ) # Envoy test binaries should be specified with this function. -def envoy_cc_test_binary(name, - **kargs): - envoy_cc_binary(name, - testonly = 1, - linkopts = envoy_test_linkopts() + envoy_static_link_libstdcpp_linkopts(), - **kargs) +def envoy_cc_test_binary( + name, + **kargs): + envoy_cc_binary( + name, + testonly = 1, + linkopts = envoy_test_linkopts() + envoy_static_link_libstdcpp_linkopts(), + **kargs + ) # Envoy Python test binaries should be specified with this function. -def envoy_py_test_binary(name, - external_deps = [], - deps = [], - **kargs): +def envoy_py_test_binary( + name, + external_deps = [], + deps = [], + **kargs): native.py_binary( name = name, deps = deps + [envoy_external_dep_path(dep) for dep in external_deps], @@ -323,41 +363,46 @@ def envoy_cc_mock(name, **kargs): envoy_cc_test_library(name = name, **kargs) # Envoy shell tests that need to be included in coverage run should be specified with this function. -def envoy_sh_test(name, - srcs = [], - data = [], - **kargs): - test_runner_cc = name + "_test_runner.cc" - native.genrule( - name = name + "_gen_test_runner", - srcs = srcs, - outs = [test_runner_cc], - cmd = "$(location //bazel:gen_sh_test_runner.sh) $(SRCS) >> $@", - tools = ["//bazel:gen_sh_test_runner.sh"], - ) - envoy_cc_test_library( - name = name + "_lib", - srcs = [test_runner_cc], - data = srcs + data, - tags = ["coverage_test_lib"], - deps = ["//test/test_common:environment_lib"], - ) - native.sh_test( - name = name, - srcs = ["//bazel:sh_test_wrapper.sh"], - data = srcs + data, - args = srcs, - **kargs - ) +def envoy_sh_test( + name, + srcs = [], + data = [], + **kargs): + test_runner_cc = name + "_test_runner.cc" + native.genrule( + name = name + "_gen_test_runner", + srcs = srcs, + outs = [test_runner_cc], + cmd = "$(location //bazel:gen_sh_test_runner.sh) $(SRCS) >> $@", + tools = ["//bazel:gen_sh_test_runner.sh"], + ) + envoy_cc_test_library( + name = name + "_lib", + srcs = [test_runner_cc], + data = srcs + data, + tags = ["coverage_test_lib"], + deps = ["//test/test_common:environment_lib"], + ) + native.sh_test( + name = name, + srcs = ["//bazel:sh_test_wrapper.sh"], + data = srcs + data, + args = srcs, + **kargs + ) def _proto_header(proto_path): - if proto_path.endswith(".proto"): - return proto_path[:-5] + "pb.h" - return None + if proto_path.endswith(".proto"): + return proto_path[:-5] + "pb.h" + return None # Envoy proto targets should be specified with this function. -def envoy_proto_library(name, srcs = [], deps = [], external_deps = [], - generate_python = True): +def envoy_proto_library( + name, + srcs = [], + deps = [], + external_deps = [], + generate_python = True): # Ideally this would be native.{proto_library, cc_proto_library}. # Unfortunately, this doesn't work with http_api_protos due to the PGV # requirement to also use them in the non-native protobuf.bzl @@ -366,6 +411,10 @@ def envoy_proto_library(name, srcs = [], deps = [], external_deps = [], cc_proto_deps = [] py_proto_deps = ["@com_google_protobuf//:protobuf_python"] + if "api_httpbody_protos" in external_deps: + cc_proto_deps.append("@googleapis//:api_httpbody_protos") + py_proto_deps.append("@googleapis//:api_httpbody_protos_py") + if "http_api_protos" in external_deps: cc_proto_deps.append("@googleapis//:http_api_protos") py_proto_deps.append("@googleapis//:http_api_protos_py") @@ -403,6 +452,10 @@ def envoy_proto_descriptor(name, out, srcs = [], external_deps = []): input_files = ["$(location " + src + ")" for src in srcs] include_paths = [".", PACKAGE_NAME] + if "api_httpbody_protos" in external_deps: + srcs.append("@googleapis//:api_httpbody_protos_src") + include_paths.append("external/googleapis") + if "http_api_protos" in external_deps: srcs.append("@googleapis//:http_api_protos_src") include_paths.append("external/googleapis") @@ -432,13 +485,11 @@ def envoy_select_hot_restart(xs, repository = ""): "//conditions:default": xs, }) - def envoy_select_perf_annotation(xs): return select({ "@envoy//bazel:enable_perf_annotation": xs, "//conditions:default": [], - }) - + }) # Selects the given values if Google gRPC is enabled in the current build. def envoy_select_google_grpc(xs, repository = ""): @@ -458,5 +509,6 @@ def envoy_select_force_libcpp(if_libcpp, default = None): return select({ "@envoy//bazel:force_libcpp": if_libcpp, "@bazel_tools//tools/osx:darwin": [], + "@envoy//bazel:windows_x86_64": [], "//conditions:default": default or [], }) diff --git a/bazel/external/apache_thrift.BUILD b/bazel/external/apache_thrift.BUILD new file mode 100644 index 0000000000000..8b296fc00672b --- /dev/null +++ b/bazel/external/apache_thrift.BUILD @@ -0,0 +1,21 @@ +# The apache-thrift distribution does not keep the thrift files in a directory with the +# expected package name (it uses src/Thrift.py vs src/thrift/Thrift.py), so we provide a +# genrule to copy src/**/*.py to thrift/**/*.py. +src_files = glob(["src/**/*.py"]) + +genrule( + name = "thrift_files", + srcs = src_files, + outs = [f.replace("src/", "thrift/") for f in src_files], + cmd = '\n'.join( + ['mkdir -p $$(dirname $(location %s)) && cp $(location %s) $(location :%s)' % (f, f, f.replace('src/', 'thrift/')) for f in src_files] + ), + visibility = ["//visibility:private"], +) + +py_library( + name = "apache_thrift", + srcs = [":thrift_files"], + visibility = ["//visibility:public"], + deps = ["@six_archive//:six"], +) diff --git a/bazel/external/libcircllhist.BUILD b/bazel/external/libcircllhist.BUILD index 4e109f0b38d47..a937b65a382c6 100644 --- a/bazel/external/libcircllhist.BUILD +++ b/bazel/external/libcircllhist.BUILD @@ -6,4 +6,8 @@ cc_library( ], includes = ["src"], visibility = ["//visibility:public"], + copts = select({ + "@envoy//bazel:windows_x86_64": ["-DWIN32"], + "//conditions:default": [], + }), ) diff --git a/bazel/external/twitter_common_finagle_thrift.BUILD b/bazel/external/twitter_common_finagle_thrift.BUILD new file mode 100644 index 0000000000000..1ca6af126c596 --- /dev/null +++ b/bazel/external/twitter_common_finagle_thrift.BUILD @@ -0,0 +1,7 @@ +py_library( + name = "twitter_common_finagle_thrift", + srcs = glob([ + "gen/**/*.py", + ]), + visibility = ["//visibility:public"], +) diff --git a/bazel/external/twitter_common_lang.BUILD b/bazel/external/twitter_common_lang.BUILD new file mode 100644 index 0000000000000..f4300b37b05d2 --- /dev/null +++ b/bazel/external/twitter_common_lang.BUILD @@ -0,0 +1,7 @@ +py_library( + name = "twitter_common_lang", + srcs = glob([ + "twitter/**/*.py", + ]), + visibility = ["//visibility:public"], +) diff --git a/bazel/external/twitter_common_rpc.BUILD b/bazel/external/twitter_common_rpc.BUILD new file mode 100644 index 0000000000000..7a13ec511a667 --- /dev/null +++ b/bazel/external/twitter_common_rpc.BUILD @@ -0,0 +1,11 @@ +py_library( + name = "twitter_common_rpc", + srcs = glob([ + "twitter/**/*.py", + ]), + visibility = ["//visibility:public"], + deps = [ + "@com_github_twitter_common_lang//:twitter_common_lang", + "@com_github_twitter_common_finagle_thrift//:twitter_common_finagle_thrift" + ], +) diff --git a/bazel/genrule_repository.bzl b/bazel/genrule_repository.bzl index a72be286987dd..030cb9a3a38c1 100644 --- a/bazel/genrule_repository.bzl +++ b/bazel/genrule_repository.bzl @@ -1,9 +1,9 @@ def _genrule_repository(ctx): ctx.download_and_extract( ctx.attr.urls, - "", # output + "", # output ctx.attr.sha256, - "", # type + "", # type ctx.attr.strip_prefix, ) for ii, patch in enumerate(ctx.attr.patches): @@ -11,7 +11,7 @@ def _genrule_repository(ctx): ctx.symlink(patch, patch_input) patch_result = ctx.execute(["patch", "-p0", "--input", patch_input]) if patch_result.return_code != 0: - fail("Failed to apply patch %r: %s" % (patch, patch_result.stderr)) + fail("Failed to apply patch %r: %s" % (patch, patch_result.stderr)) # https://github.com/bazelbuild/bazel/issues/3766 genrule_cmd_file = Label("@envoy//bazel").relative(str(ctx.attr.genrule_cmd_file)) @@ -19,7 +19,9 @@ def _genrule_repository(ctx): cat_genrule_cmd = ctx.execute(["cat", "_envoy_genrule_cmd.genrule_cmd"]) if cat_genrule_cmd.return_code != 0: fail("Failed to read genrule command %r: %s" % ( - genrule_cmd_file, cat_genrule_cmd.stderr)) + genrule_cmd_file, + cat_genrule_cmd.stderr, + )) ctx.file("WORKSPACE", "workspace(name=%r)" % (ctx.name,)) ctx.symlink(ctx.attr.build_file, "BUILD.bazel") @@ -58,10 +60,10 @@ genrule_repository = repository_rule( ) def _genrule_cc_deps(ctx): - outs = depset() - for dep in ctx.attr.deps: - outs = dep.cc.transitive_headers + dep.cc.libs + outs - return DefaultInfo(files=outs) + outs = depset() + for dep in ctx.attr.deps: + outs = dep.cc.transitive_headers + dep.cc.libs + outs + return DefaultInfo(files = outs) genrule_cc_deps = rule( attrs = { @@ -75,67 +77,67 @@ genrule_cc_deps = rule( ) def _absolute_bin(path): - # If the binary path looks like it's relative to the current directory, - # transform it to be absolute by appending "${PWD}". - if "/" in path and not path.startswith("/"): - return '"${PWD}"/%r' % (path,) - return '%r' % (path,) + # If the binary path looks like it's relative to the current directory, + # transform it to be absolute by appending "${PWD}". + if "/" in path and not path.startswith("/"): + return '"${PWD}"/%r' % (path,) + return "%r" % (path,) def _genrule_environment(ctx): - lines = [] - - # Bazel uses the same command for C and C++ compilation. - c_compiler = ctx.var['CC'] - - # Bare minimum cflags to get included test binaries to link. - # - # See //tools:bazel.rc for the full set. - asan_flags = ["-fsanitize=address,undefined"] - tsan_flags = ["-fsanitize=thread"] - - # Older versions of GCC in Ubuntu, including GCC 5 used in CI images, - # incorrectly invoke the older `/usr/bin/ld` with gold-specific options when - # building with sanitizers enabled. Work around this by forcing use of gold - # in sanitize mode. - # - # This is not a great solution because it doesn't detect GCC when Bazel has - # wrapped it in an intermediate script, but it works well enough to keep CI - # running. - # - # https://stackoverflow.com/questions/37603238/fsanitize-not-using-gold-linker-in-gcc-6-1 - force_ld_gold = [] - if "gcc" in c_compiler or "g++" in c_compiler: - force_ld_gold = ["-fuse-ld=gold"] - - cc_flags = [] - ld_flags = [] - ld_libs = [] - if ctx.var.get('ENVOY_CONFIG_COVERAGE'): - ld_libs += ["-lgcov"] - if ctx.var.get('ENVOY_CONFIG_ASAN'): - cc_flags += asan_flags - ld_flags += asan_flags - ld_flags += force_ld_gold - if ctx.var.get('ENVOY_CONFIG_TSAN'): - cc_flags += tsan_flags - ld_flags += tsan_flags - ld_flags += force_ld_gold - - lines.append("export CFLAGS=%r" % (" ".join(cc_flags),)) - lines.append("export LDFLAGS=%r" % (" ".join(ld_flags),)) - lines.append("export LIBS=%r" % (" ".join(ld_libs),)) - lines.append("export CC=%s" % (_absolute_bin(c_compiler),)) - lines.append("export CXX=%s" % (_absolute_bin(c_compiler),)) - - # Some Autoconf helper binaries leak, which makes ./configure think the - # system is unable to do anything. Turn off leak checking during part of - # the build. - lines.append("export ASAN_OPTIONS=detect_leaks=0") - - lines.append("") - out = ctx.new_file(ctx.attr.name + ".sh") - ctx.file_action(out, "\n".join(lines)) - return DefaultInfo(files=depset([out])) + lines = [] + + # Bazel uses the same command for C and C++ compilation. + c_compiler = ctx.var["CC"] + + # Bare minimum cflags to get included test binaries to link. + # + # See //tools:bazel.rc for the full set. + asan_flags = ["-fsanitize=address,undefined"] + tsan_flags = ["-fsanitize=thread"] + + # Older versions of GCC in Ubuntu, including GCC 5 used in CI images, + # incorrectly invoke the older `/usr/bin/ld` with gold-specific options when + # building with sanitizers enabled. Work around this by forcing use of gold + # in sanitize mode. + # + # This is not a great solution because it doesn't detect GCC when Bazel has + # wrapped it in an intermediate script, but it works well enough to keep CI + # running. + # + # https://stackoverflow.com/questions/37603238/fsanitize-not-using-gold-linker-in-gcc-6-1 + force_ld_gold = [] + if "gcc" in c_compiler or "g++" in c_compiler: + force_ld_gold = ["-fuse-ld=gold"] + + cc_flags = [] + ld_flags = [] + ld_libs = [] + if ctx.var.get("ENVOY_CONFIG_COVERAGE"): + ld_libs += ["-lgcov"] + if ctx.var.get("ENVOY_CONFIG_ASAN"): + cc_flags += asan_flags + ld_flags += asan_flags + ld_flags += force_ld_gold + if ctx.var.get("ENVOY_CONFIG_TSAN"): + cc_flags += tsan_flags + ld_flags += tsan_flags + ld_flags += force_ld_gold + + lines.append("export CFLAGS=%r" % (" ".join(cc_flags),)) + lines.append("export LDFLAGS=%r" % (" ".join(ld_flags),)) + lines.append("export LIBS=%r" % (" ".join(ld_libs),)) + lines.append("export CC=%s" % (_absolute_bin(c_compiler),)) + lines.append("export CXX=%s" % (_absolute_bin(c_compiler),)) + + # Some Autoconf helper binaries leak, which makes ./configure think the + # system is unable to do anything. Turn off leak checking during part of + # the build. + lines.append("export ASAN_OPTIONS=detect_leaks=0") + + lines.append("") + out = ctx.new_file(ctx.attr.name + ".sh") + ctx.file_action(out, "\n".join(lines)) + return DefaultInfo(files = depset([out])) genrule_environment = rule( implementation = _genrule_environment, diff --git a/bazel/patched_http_archive.bzl b/bazel/patched_http_archive.bzl index 87b4be7737345..8a6d54881cdfe 100644 --- a/bazel/patched_http_archive.bzl +++ b/bazel/patched_http_archive.bzl @@ -1,9 +1,9 @@ def _patched_http_archive(ctx): ctx.download_and_extract( ctx.attr.urls, - "", # output + "", # output ctx.attr.sha256, - "", # type + "", # type ctx.attr.strip_prefix, ) for ii, patch in enumerate(ctx.attr.patches): @@ -11,7 +11,7 @@ def _patched_http_archive(ctx): ctx.symlink(patch, patch_input) patch_result = ctx.execute(["patch", "-p0", "--input", patch_input]) if patch_result.return_code != 0: - fail("Failed to apply patch %r: %s" % (patch, patch_result.stderr)) + fail("Failed to apply patch %r: %s" % (patch, patch_result.stderr)) patched_http_archive = repository_rule( attrs = { diff --git a/bazel/repositories.bat b/bazel/repositories.bat new file mode 100644 index 0000000000000..7b66957105932 --- /dev/null +++ b/bazel/repositories.bat @@ -0,0 +1,4 @@ +echo "Start" +@ECHO OFF +%BAZEL_SH% -c "./repositories.sh %*" +exit %ERRORLEVEL% diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index b55e5c70fc554..3d231c259ff00 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -7,6 +7,12 @@ load(":genrule_repository.bzl", "genrule_repository") load(":patched_http_archive.bzl", "patched_http_archive") load(":repository_locations.bzl", "REPOSITORY_LOCATIONS") load(":target_recipes.bzl", "TARGET_RECIPES") +load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", + "find_vc_path", + "setup_vc_env_vars", +) +load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "get_env_var") def _repository_impl(name, **kwargs): # `existing_rule_keys` contains the names of repositories that have already @@ -25,8 +31,9 @@ def _repository_impl(name, **kwargs): # user a useful error if they accidentally specify a tag. if "tag" in location: fail( - "Refusing to depend on Git tag %r for external dependency %r: use 'commit' instead." - % (location["tag"], name)) + "Refusing to depend on Git tag %r for external dependency %r: use 'commit' instead." % + (location["tag"], name), + ) if "commit" in location: # Git repository at given commit ID. Add a BUILD file if requested. @@ -35,13 +42,15 @@ def _repository_impl(name, **kwargs): name = name, remote = location["remote"], commit = location["commit"], - **kwargs) + **kwargs + ) else: git_repository( name = name, remote = location["remote"], commit = location["commit"], - **kwargs) + **kwargs + ) else: # HTTP # HTTP tarball at a given URL. Add a BUILD file if requested. if "build_file" in kwargs: @@ -50,33 +59,54 @@ def _repository_impl(name, **kwargs): urls = location["urls"], sha256 = location["sha256"], strip_prefix = location["strip_prefix"], - **kwargs) + **kwargs + ) else: native.http_archive( name = name, urls = location["urls"], sha256 = location["sha256"], strip_prefix = location["strip_prefix"], - **kwargs) + **kwargs + ) def _build_recipe_repository_impl(ctxt): # Setup the build directory with links to the relevant files. ctxt.symlink(Label("//bazel:repositories.sh"), "repositories.sh") - ctxt.symlink(Label("//ci/build_container:build_and_install_deps.sh"), - "build_and_install_deps.sh") + ctxt.symlink(Label("//bazel:repositories.bat"), "repositories.bat") + ctxt.symlink( + Label("//ci/build_container:build_and_install_deps.sh"), + "build_and_install_deps.sh", + ) ctxt.symlink(Label("//ci/build_container:recipe_wrapper.sh"), "recipe_wrapper.sh") ctxt.symlink(Label("//ci/build_container:Makefile"), "Makefile") for r in ctxt.attr.recipes: - ctxt.symlink(Label("//ci/build_container/build_recipes:" + r + ".sh"), - "build_recipes/" + r + ".sh") + ctxt.symlink( + Label("//ci/build_container/build_recipes:" + r + ".sh"), + "build_recipes/" + r + ".sh", + ) ctxt.symlink(Label("//ci/prebuilt:BUILD"), "BUILD") # Run the build script. - environment = {} + command = [] + env = {} + if ctxt.os.name.upper().startswith("WINDOWS"): + vc_path = find_vc_path(ctxt) + current_path = get_env_var(ctxt, "PATH", None, False) + env = setup_vc_env_vars(ctxt, vc_path) + env["PATH"] += (";%s" % current_path) + env["CC"] = "cl" + env["CXX"] = "cl" + env["CXXFLAGS"] = "-DNDEBUG" + env["CFLAGS"] = "-DNDEBUG" + command = ["./repositories.bat"] + ctxt.attr.recipes + else: + command = ["./repositories.sh"] + ctxt.attr.recipes + print("Fetching external dependencies...") result = ctxt.execute( - ["./repositories.sh"] + ctxt.attr.recipes, - environment = environment, + command, + environment = env, quiet = False, ) print(result.stdout) @@ -86,6 +116,7 @@ def _build_recipe_repository_impl(ctxt): print("\033[31;1m\033[48;5;226m External dependency build failed, check above log " + "for errors and ensure all prerequisites at " + "https://github.com/envoyproxy/envoy/blob/master/bazel/README.md#quick-start-bazel-build-for-developers are met.") + # This error message doesn't appear to the user :( https://github.com/bazelbuild/bazel/issues/3683 fail("External dep build failed") @@ -97,7 +128,7 @@ def _default_envoy_build_config_impl(ctx): _default_envoy_build_config = repository_rule( implementation = _default_envoy_build_config_impl, attrs = { - "config": attr.label(default="@envoy//source/extensions:extensions_build_config.bzl"), + "config": attr.label(default = "@envoy//source/extensions:extensions_build_config.bzl"), }, ) @@ -113,12 +144,12 @@ def _default_envoy_api_impl(ctx): "tools", ] for d in api_dirs: - ctx.symlink(ctx.path(ctx.attr.api).dirname.get_child(d), d) + ctx.symlink(ctx.path(ctx.attr.api).dirname.get_child(d), d) _default_envoy_api = repository_rule( implementation = _default_envoy_api_impl, attrs = { - "api": attr.label(default="@envoy//api:BUILD"), + "api": attr.label(default = "@envoy//api:BUILD"), }, ) @@ -141,6 +172,22 @@ def _python_deps(): name = "jinja2", actual = "@com_github_pallets_jinja//:jinja2", ) + _repository_impl( + name = "com_github_apache_thrift", + build_file = "@envoy//bazel/external:apache_thrift.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_lang", + build_file = "@envoy//bazel/external:twitter_common_lang.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_rpc", + build_file = "@envoy//bazel/external:twitter_common_rpc.BUILD", + ) + _repository_impl( + name = "com_github_twitter_common_finagle_thrift", + build_file = "@envoy//bazel/external:twitter_common_finagle_thrift.BUILD", + ) # Bazel native C++ dependencies. For the depedencies that doesn't provide autoconf/automake builds. def _cc_deps(): @@ -164,8 +211,12 @@ def _envoy_api_deps(): # Treat the data plane API as an external repo, this simplifies exporting the API to # https://github.com/envoyproxy/data-plane-api. if "envoy_api" not in native.existing_rules().keys(): - _default_envoy_api(name="envoy_api") + _default_envoy_api(name = "envoy_api") + native.bind( + name = "api_httpbody_protos", + actual = "@googleapis//:api_httpbody_protos", + ) native.bind( name = "http_api_protos", actual = "@googleapis//:http_api_protos", @@ -187,7 +238,7 @@ def envoy_dependencies(path = "@envoy_deps//", skip_targets = []): "CXX", "CFLAGS", "CXXFLAGS", - "LD_LIBRARY_PATH" + "LD_LIBRARY_PATH", ], # Don't pretend we're in the sandbox, we do some evil stuff with envoy_dep_cache. local = True, @@ -435,32 +486,32 @@ def _com_github_grpc_grpc(): # Rebind some stuff to match what the gRPC Bazel is expecting. native.bind( - name = "protobuf_headers", - actual = "@com_google_protobuf//:protobuf_headers", + name = "protobuf_headers", + actual = "@com_google_protobuf//:protobuf_headers", ) native.bind( - name = "libssl", - actual = "//external:ssl", + name = "libssl", + actual = "//external:ssl", ) native.bind( - name = "cares", - actual = "//external:ares", + name = "cares", + actual = "//external:ares", ) native.bind( - name = "grpc", - actual = "@com_github_grpc_grpc//:grpc++" + name = "grpc", + actual = "@com_github_grpc_grpc//:grpc++", ) native.bind( - name = "grpc_health_proto", - actual = "@envoy//bazel:grpc_health_proto", + name = "grpc_health_proto", + actual = "@envoy//bazel:grpc_health_proto", ) def _com_github_google_jwt_verify(): _repository_impl("com_github_google_jwt_verify") native.bind( - name = "jwt_verify_lib", - actual = "@com_github_google_jwt_verify//:jwt_verify_lib", + name = "jwt_verify_lib", + actual = "@com_github_google_jwt_verify//:jwt_verify_lib", ) diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index 097e76ad603b6..7a65f4c956f3a 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -1,19 +1,24 @@ REPOSITORY_LOCATIONS = dict( boringssl = dict( # Use commits from branch "chromium-stable-with-bazel" - commit = "2a52ce799382c87cd3119f3b44fbbebf97061ab6", # chromium-67.0.3396.62 + commit = "372daf7042ffe3da1335743e7c93d78f1399aba7", # chromium-68.0.3440.75 remote = "https://github.com/google/boringssl", ), com_google_absl = dict( commit = "92020a042c0cd46979db9f6f0cb32783dc07765e", # 2018-06-08 remote = "https://github.com/abseil/abseil-cpp", ), + com_github_apache_thrift = dict( + sha256 = "7d59ac4fdcb2c58037ebd4a9da5f9a49e3e034bf75b3f26d9fe48ba3d8806e6b", + urls = ["https://files.pythonhosted.org/packages/c6/b4/510617906f8e0c5660e7d96fbc5585113f83ad547a3989b80297ac72a74c/thrift-0.11.0.tar.gz"], # 0.11.0 + strip_prefix = "thrift-0.11.0", + ), com_github_bombela_backward = dict( commit = "44ae9609e860e3428cd057f7052e505b4819eb84", # 2018-02-06 remote = "https://github.com/bombela/backward-cpp", ), com_github_circonus_labs_libcircllhist = dict( - commit = "476687ac9cc636fc92ac3070246d757ae6854547", # 2018-05-08 + commit = "050da53a44dede7bda136b93a9aeef47bd91fa12", # 2018-07-02 remote = "https://github.com/circonus-labs/libcircllhist", ), com_github_cyan4973_xxhash = dict( @@ -43,16 +48,16 @@ REPOSITORY_LOCATIONS = dict( remote = "https://github.com/google/libprotobuf-mutator", ), com_github_grpc_grpc = dict( - commit = "bec3b5ada2c5e5d782dff0b7b5018df646b65cb0", # v1.12.0 + commit = "bec3b5ada2c5e5d782dff0b7b5018df646b65cb0", # v1.12.0 remote = "https://github.com/grpc/grpc.git", ), io_opentracing_cpp = dict( - commit = "3b36b084a4d7fffc196eac83203cf24dfb8696b3", # v1.4.2 + commit = "3b36b084a4d7fffc196eac83203cf24dfb8696b3", # v1.4.2 remote = "https://github.com/opentracing/opentracing-cpp", ), com_lightstep_tracer_cpp = dict( commit = "ae6a6bba65f8c4d438a6a3ac855751ca8f52e1dc", - remote = "https://github.com/lightstep/lightstep-tracer-cpp", # v0.7.1 + remote = "https://github.com/lightstep/lightstep-tracer-cpp", # v0.7.1 ), lightstep_vendored_googleapis = dict( commit = "d6f78d948c53f3b400bb46996eb3084359914f9b", @@ -63,9 +68,11 @@ REPOSITORY_LOCATIONS = dict( remote = "https://github.com/google/jwt_verify_lib", ), com_github_nodejs_http_parser = dict( - # 2018-05-30 snapshot to pick up a performance fix, nodejs/http-parser PR 422 + # 2018-07-20 snapshot to pick up: + # A performance fix, nodejs/http-parser PR 422. + # A bug fix, nodejs/http-parser PR 432. # TODO(brian-pane): Upgrade to the next http-parser release once it's available - commit = "cf69c8eda9fe79e4682598a7b3d39338dea319a3", + commit = "77310eeb839c4251c07184a5db8885a572a08352", remote = "https://github.com/nodejs/http-parser", ), com_github_pallets_jinja = dict( @@ -80,6 +87,21 @@ REPOSITORY_LOCATIONS = dict( commit = "f54b0e47a08782a6131cc3d60f94d038fa6e0a51", # v1.1.0 remote = "https://github.com/tencent/rapidjson", ), + com_github_twitter_common_lang = dict( + sha256 = "56d1d266fd4767941d11c27061a57bc1266a3342e551bde3780f9e9eb5ad0ed1", + urls = ["https://files.pythonhosted.org/packages/08/bc/d6409a813a9dccd4920a6262eb6e5889e90381453a5f58938ba4cf1d9420/twitter.common.lang-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.lang-0.3.9/src", + ), + com_github_twitter_common_rpc = dict( + sha256 = "0792b63fb2fb32d970c2e9a409d3d00633190a22eb185145fe3d9067fdaa4514", + urls = ["https://files.pythonhosted.org/packages/be/97/f5f701b703d0f25fbf148992cd58d55b4d08d3db785aad209255ee67e2d0/twitter.common.rpc-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.rpc-0.3.9/src", + ), + com_github_twitter_common_finagle_thrift = dict( + sha256 = "1e3a57d11f94f58745e6b83348ecd4fa74194618704f45444a15bc391fde497a", + urls = ["https://files.pythonhosted.org/packages/f9/e7/4f80d582578f8489226370762d2cf6bc9381175d1929eba1754e03f70708/twitter.common.finagle-thrift-0.3.9.tar.gz"], # 0.3.9 + strip_prefix = "twitter.common.finagle-thrift-0.3.9/src", + ), com_google_googletest = dict( commit = "43863938377a9ea1399c0596269e0890b5c5515a", remote = "https://github.com/google/googletest", diff --git a/bazel/target_recipes.bzl b/bazel/target_recipes.bzl index 002780148a4e2..6260336887927 100644 --- a/bazel/target_recipes.bzl +++ b/bazel/target_recipes.bzl @@ -5,7 +5,6 @@ TARGET_RECIPES = { "ares": "cares", "benchmark": "benchmark", "event": "libevent", - "event_pthreads": "libevent", "tcmalloc_and_profiler": "gperftools", "luajit": "luajit", "nghttp2": "nghttp2", diff --git a/ci/build_container/build_container_centos.sh b/ci/build_container/build_container_centos.sh index f26971230c3df..d416fddea6f7c 100755 --- a/ci/build_container/build_container_centos.sh +++ b/ci/build_container/build_container_centos.sh @@ -9,7 +9,7 @@ curl -L -o /etc/yum.repos.d/alonid-llvm-5.0.0-epel-7.repo \ # dependencies for bazel and build_recipes yum install -y java-1.8.0-openjdk-devel unzip which openssl rpm-build \ - cmake3 devtoolset-4-gcc-c++ git golang libtool make patch rsync wget \ + cmake3 devtoolset-4-gcc-c++ git golang libtool make ninja-build patch rsync wget \ clang-5.0.0 devtoolset-4-libatomic-devel llvm-5.0.0 python-virtualenv bc yum clean all diff --git a/ci/build_container/build_container_ubuntu.sh b/ci/build_container/build_container_ubuntu.sh index ff37f0fe1e912..e107bd1d2deb2 100755 --- a/ci/build_container/build_container_ubuntu.sh +++ b/ci/build_container/build_container_ubuntu.sh @@ -6,7 +6,7 @@ set -e apt-get update export DEBIAN_FRONTEND=noninteractive apt-get install -y wget software-properties-common make cmake git python python-pip \ - bc libtool automake zip time golang g++ gdb strace wireshark tshark + bc libtool ninja-build automake zip time golang g++ gdb strace wireshark tshark # clang head (currently 5.0) wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - apt-add-repository "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main" @@ -23,5 +23,5 @@ rm -rf /var/lib/apt/lists/* # virtualenv pip install virtualenv -EXPECTED_CXX_VERSION="g++ (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609" ./build_container_common.sh +EXPECTED_CXX_VERSION="g++ (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609" ./build_container_common.sh diff --git a/ci/build_container/build_recipes/benchmark.sh b/ci/build_container/build_recipes/benchmark.sh index 5e8f2f41ab2c5..6817ea42a291e 100644 --- a/ci/build_container/build_recipes/benchmark.sh +++ b/ci/build_container/build_recipes/benchmark.sh @@ -8,11 +8,17 @@ git clone https://github.com/google/benchmark.git mkdir build cd build -cmake -G "Unix Makefiles" ../benchmark \ +cmake -G "Ninja" ../benchmark \ -DCMAKE_BUILD_TYPE=RELEASE \ -DBENCHMARK_ENABLE_GTEST_TESTS=OFF -make -cp src/libbenchmark.a "$THIRDPARTY_BUILD"/lib +ninja + +benchmark_lib="libbenchmark.a" +if [[ "${OS}" == "Windows_NT" ]]; then + benchmark_lib="benchmark.lib" +fi + +cp "src/$benchmark_lib" "$THIRDPARTY_BUILD"/lib cd ../benchmark INCLUDE_DIR="$THIRDPARTY_BUILD/include/testing/base/public" diff --git a/ci/build_container/build_recipes/cares.sh b/ci/build_container/build_recipes/cares.sh index b3797f432e99d..d4191ae7fadd8 100755 --- a/ci/build_container/build_recipes/cares.sh +++ b/ci/build_container/build_recipes/cares.sh @@ -10,10 +10,31 @@ VERSION=cares-1_14_0 CPPFLAGS="$(for f in $CXXFLAGS; do if [[ $f =~ -D.* ]]; then echo $f; fi; done | tr '\n' ' ')" CFLAGS="$(for f in $CXXFLAGS; do if [[ ! $f =~ -D.* ]]; then echo $f; fi; done | tr '\n' ' ')" -wget -O c-ares-"$VERSION".tar.gz https://github.com/c-ares/c-ares/archive/"$VERSION".tar.gz +curl https://github.com/c-ares/c-ares/archive/"$VERSION".tar.gz -sLo c-ares-"$VERSION".tar.gz tar xf c-ares-"$VERSION".tar.gz cd c-ares-"$VERSION" -./buildconf -./configure --prefix="$THIRDPARTY_BUILD" --enable-shared=no --enable-lib-only \ - --enable-debug --enable-optimize -make V=1 install + +mkdir build +cd build + +build_type=RelWithDebInfo +if [[ "${OS}" == "Windows_NT" ]]; then + # On Windows, every object file in the final executable needs to be compiled to use the + # same version of the C Runtime Library. If Envoy is built with '-c dbg', then it will + # use the Debug C Runtime Library. Setting CMAKE_BUILD_TYPE to Debug will cause c-ares + # to use the debug version as well + # TODO: when '-c fastbuild' and '-c opt' work for Windows builds, set this appropriately + build_type=Debug +fi + +cmake -G "Ninja" -DCMAKE_INSTALL_PREFIX="$THIRDPARTY_BUILD" \ + -DCARES_SHARED=no \ + -DCARES_STATIC=on \ + -DCMAKE_BUILD_TYPE="$build_type" \ + .. +ninja +ninja install + +if [[ "${OS}" == "Windows_NT" ]]; then + cp "CMakeFiles/c-ares.dir/c-ares.pdb" "$THIRDPARTY_BUILD/lib/c-ares.pdb" +fi diff --git a/ci/build_container/build_recipes/gperftools.sh b/ci/build_container/build_recipes/gperftools.sh index 7c0c72d9c6e6e..de18e91a526d3 100755 --- a/ci/build_container/build_recipes/gperftools.sh +++ b/ci/build_container/build_recipes/gperftools.sh @@ -2,9 +2,13 @@ set -e +if [[ "${OS}" == "Windows_NT" ]]; then + exit 0 +fi + VERSION=2.7 -wget -O gperftools-"$VERSION".tar.gz https://github.com/gperftools/gperftools/releases/download/gperftools-"$VERSION"/gperftools-"$VERSION".tar.gz +curl https://github.com/gperftools/gperftools/releases/download/gperftools-"$VERSION"/gperftools-"$VERSION".tar.gz -sLo gperftools-"$VERSION".tar.gz tar xf gperftools-"$VERSION".tar.gz cd gperftools-"$VERSION" diff --git a/ci/build_container/build_recipes/libevent.sh b/ci/build_container/build_recipes/libevent.sh index c88d5bb3a2ed3..0bd783cf4bdd0 100755 --- a/ci/build_container/build_recipes/libevent.sh +++ b/ci/build_container/build_recipes/libevent.sh @@ -4,8 +4,33 @@ set -e VERSION=2.1.8-stable -wget -O libevent-"$VERSION".tar.gz https://github.com/libevent/libevent/releases/download/release-"$VERSION"/libevent-"$VERSION".tar.gz -tar xf libevent-"$VERSION".tar.gz -cd libevent-"$VERSION" -./configure --prefix="$THIRDPARTY_BUILD" --enable-shared=no --disable-libevent-regress --disable-openssl -make V=1 install +curl https://github.com/libevent/libevent/archive/release-"$VERSION".tar.gz -sLo libevent-release-"$VERSION".tar.gz +tar xf libevent-release-"$VERSION".tar.gz +cd libevent-release-"$VERSION" + +mkdir build +cd build + +# libevent defaults CMAKE_BUILD_TYPE to Release +build_type=Release +if [[ "${OS}" == "Windows_NT" ]]; then + # On Windows, every object file in the final executable needs to be compiled to use the + # same version of the C Runtime Library. If Envoy is built with '-c dbg', then it will + # use the Debug C Runtime Library. Setting CMAKE_BUILD_TYPE to Debug will cause libevent + # to use the debug version as well + # TODO: when '-c fastbuild' and '-c opt' work for Windows builds, set this appropriately + build_type=Debug +fi + +cmake -G "Ninja" \ + -DCMAKE_INSTALL_PREFIX="$THIRDPARTY_BUILD" \ + -DEVENT__DISABLE_OPENSSL:BOOL=on \ + -DEVENT__DISABLE_REGRESS:BOOL=on \ + -DCMAKE_BUILD_TYPE="$build_type" \ + .. +ninja +ninja install + +if [[ "${OS}" == "Windows_NT" ]]; then + cp "CMakeFiles/event.dir/event.pdb" "$THIRDPARTY_BUILD/lib/event.pdb" +fi diff --git a/ci/build_container/build_recipes/luajit.sh b/ci/build_container/build_recipes/luajit.sh index 4deba13a51d46..3b02133d34dc7 100644 --- a/ci/build_container/build_recipes/luajit.sh +++ b/ci/build_container/build_recipes/luajit.sh @@ -4,7 +4,7 @@ set -e VERSION=2.0.5 -wget -O LuaJIT-"$VERSION".tar.gz https://github.com/LuaJIT/LuaJIT/archive/v"$VERSION".tar.gz +curl https://github.com/LuaJIT/LuaJIT/archive/v"$VERSION".tar.gz -sLo LuaJIT-"$VERSION".tar.gz tar xf LuaJIT-"$VERSION".tar.gz cd LuaJIT-"$VERSION" @@ -46,15 +46,26 @@ index f7f81a4..e698517 100644 # Disable the JIT compiler, i.e. turn LuaJIT into a pure interpreter. #XCFLAGS+= -DLUAJIT_DISABLE_JIT @@ -564,7 +564,7 @@ endif - + Q= @ E= @echo -#Q= +Q= #E= @: - + ############################################################################## EOF -patch -p1 < ../luajit_make.diff -DEFAULT_CC=${CC} TARGET_CFLAGS=${CFLAGS} TARGET_LDFLAGS=${CFLAGS} CFLAGS="" make V=1 PREFIX="$THIRDPARTY_BUILD" install +if [[ "${OS}" == "Windows_NT" ]]; then + cd src + ./msvcbuild.bat debug + + mkdir -p "$THIRDPARTY_BUILD/include/luajit-2.0" + cp *.h* "$THIRDPARTY_BUILD/include/luajit-2.0" + cp luajit.lib "$THIRDPARTY_BUILD/lib" + cp *.pdb "$THIRDPARTY_BUILD/lib" +else + patch -p1 < ../luajit_make.diff + + DEFAULT_CC=${CC} TARGET_CFLAGS=${CFLAGS} TARGET_LDFLAGS=${CFLAGS} CFLAGS="" make V=1 PREFIX="$THIRDPARTY_BUILD" install +fi diff --git a/ci/build_container/build_recipes/nghttp2.sh b/ci/build_container/build_recipes/nghttp2.sh index 1b380f3856d2e..cea6ab963292a 100755 --- a/ci/build_container/build_recipes/nghttp2.sh +++ b/ci/build_container/build_recipes/nghttp2.sh @@ -2,10 +2,47 @@ set -e -VERSION=1.32.0 +# Use master branch, which contains a fix for the spurious limit of 100 concurrent streams: +# https://github.com/nghttp2/nghttp2/commit/2ba1389993729fcb6ee5794ac512f2b67b29952e +# TODO(PiotrSikora): switch back to releases once v1.33.0 is out. +VERSION=e5b3f9addd49bca27e2f99c5c65a564eb5c0cf6d # 2018-06-09 -wget -O nghttp2-"$VERSION".tar.gz https://github.com/nghttp2/nghttp2/releases/download/v"$VERSION"/nghttp2-"$VERSION".tar.gz +curl https://github.com/nghttp2/nghttp2/archive/"$VERSION".tar.gz -sLo nghttp2-"$VERSION".tar.gz tar xf nghttp2-"$VERSION".tar.gz cd nghttp2-"$VERSION" -./configure --prefix="$THIRDPARTY_BUILD" --enable-shared=no --enable-lib-only -make V=1 install + +# Allow nghttp2 to build as static lib on Windows +# TODO: remove once https://github.com/nghttp2/nghttp2/pull/1198 is merged +cat > nghttp2_cmakelists.diff << 'EOF' +diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt +index 17e422b2..e58070f5 100644 +--- a/lib/CMakeLists.txt ++++ b/lib/CMakeLists.txt +@@ -56,6 +56,7 @@ if(HAVE_CUNIT OR ENABLE_STATIC_LIB) + COMPILE_FLAGS "${WARNCFLAGS}" + VERSION ${LT_VERSION} SOVERSION ${LT_SOVERSION} + ARCHIVE_OUTPUT_NAME nghttp2 ++ ARCHIVE_OUTPUT_DIRECTORY static + ) + target_compile_definitions(nghttp2_static PUBLIC "-DNGHTTP2_STATICLIB") + if(ENABLE_STATIC_LIB) +EOF + +if [[ "${OS}" == "Windows_NT" ]]; then + git apply nghttp2_cmakelists.diff +fi + +mkdir build +cd build + +cmake -G "Ninja" -DCMAKE_INSTALL_PREFIX="$THIRDPARTY_BUILD" \ + -DCMAKE_INSTALL_LIBDIR="$THIRDPARTY_BUILD/lib" \ + -DENABLE_STATIC_LIB=on \ + -DENABLE_LIB_ONLY=on \ + .. +ninja +ninja install + +if [[ "${OS}" == "Windows_NT" ]]; then + cp "lib/CMakeFiles/nghttp2_static.dir/nghttp2_static.pdb" "$THIRDPARTY_BUILD/lib/nghttp2_static.pdb" +fi diff --git a/ci/build_container/build_recipes/yaml-cpp.sh b/ci/build_container/build_recipes/yaml-cpp.sh index db63dcb3a11ef..2c565cfd1bf6c 100755 --- a/ci/build_container/build_recipes/yaml-cpp.sh +++ b/ci/build_container/build_recipes/yaml-cpp.sh @@ -4,11 +4,31 @@ set -e VERSION=0.6.2 -wget -O yaml-cpp-"$VERSION".tar.gz https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-"$VERSION".tar.gz +curl https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-"$VERSION".tar.gz -sLo yaml-cpp-"$VERSION".tar.gz tar xf yaml-cpp-"$VERSION".tar.gz cd yaml-cpp-yaml-cpp-"$VERSION" -cmake -DCMAKE_INSTALL_PREFIX:PATH="$THIRDPARTY_BUILD" \ + +mkdir build +cd build + +build_type=RelWithDebInfo +if [[ "${OS}" == "Windows_NT" ]]; then + # On Windows, every object file in the final executable needs to be compiled to use the + # same version of the C Runtime Library. If Envoy is built with '-c dbg', then it will + # use the Debug C Runtime Library. Setting CMAKE_BUILD_TYPE to Debug will cause yaml-cpp + # to use the debug version as well + # TODO: when '-c fastbuild' and '-c opt' work for Windows builds, set this appropriately + build_type=Debug +fi + +cmake -G "Ninja" -DCMAKE_INSTALL_PREFIX:PATH="$THIRDPARTY_BUILD" \ -DCMAKE_CXX_FLAGS:STRING="${CXXFLAGS} ${CPPFLAGS}" \ -DCMAKE_C_FLAGS:STRING="${CFLAGS} ${CPPFLAGS}" \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo . -make VERBOSE=1 install + -DYAML_CPP_BUILD_TESTS=off \ + -DCMAKE_BUILD_TYPE="$build_type" \ + .. +ninja install + +if [[ "${OS}" == "Windows_NT" ]]; then + cp "CMakeFiles/yaml-cpp.dir/yaml-cpp.pdb" "$THIRDPARTY_BUILD/lib/yaml-cpp.pdb" +fi diff --git a/ci/build_container/build_recipes/zlib.sh b/ci/build_container/build_recipes/zlib.sh index fd22ea67f0af5..62997062f1491 100644 --- a/ci/build_container/build_recipes/zlib.sh +++ b/ci/build_container/build_recipes/zlib.sh @@ -4,8 +4,15 @@ set -e VERSION=1.2.11 -wget -O zlib-"$VERSION".tar.gz https://github.com/madler/zlib/archive/v"$VERSION".tar.gz +curl https://github.com/madler/zlib/archive/v"$VERSION".tar.gz -sLo zlib-"$VERSION".tar.gz tar xf zlib-"$VERSION".tar.gz cd zlib-"$VERSION" -./configure --prefix="$THIRDPARTY_BUILD" -make V=1 install +mkdir build +cd build +cmake -G "Ninja" -DCMAKE_INSTALL_PREFIX:PATH="$THIRDPARTY_BUILD" .. +ninja +ninja install + +if [[ "${OS}" == "Windows_NT" ]]; then + cp "CMakeFiles/zlibstatic.dir/zlibstatic.pdb" "$THIRDPARTY_BUILD/lib/zlibstatic.pdb" +fi diff --git a/ci/build_setup.ps1 b/ci/build_setup.ps1 new file mode 100755 index 0000000000000..12a3aeff987f5 --- /dev/null +++ b/ci/build_setup.ps1 @@ -0,0 +1,22 @@ +$ErrorActionPreference = "Stop"; +trap { $host.SetShouldExit(1) } + +if ("$env:NUM_CPUS" -eq "") { + $env:NUM_CPUS = (Get-WmiObject -class Win32_computersystem).NumberOfLogicalProcessors +} + +if ("$env:ENVOY_BAZEL_ROOT" -eq "") { + Write-Host "ENVOY_BAZEL_ROOT must be set!" + throw +} + +mkdir -force "$env:ENVOY_BAZEL_ROOT" > $nul + +$env:ENVOY_SRCDIR = [System.IO.Path]::GetFullPath("$PSScriptRoot\..") + +echo "ENVOY_BAZEL_ROOT: $env:ENVOY_BAZEL_ROOT" +echo "ENVOY_SRCDIR: $env:ENVOY_SRCDIR" + +$env:BAZEL_BASE_OPTIONS="--nomaster_bazelrc --output_base=$env:ENVOY_BAZEL_ROOT --bazelrc=$env:ENVOY_SRCDIR\windows\tools\bazel.rc" +$env:BAZEL_BUILD_OPTIONS="--strategy=Genrule=standalone --spawn_strategy=standalone --verbose_failures --jobs=$env:NUM_CPUS --show_task_finish $env:BAZEL_BUILD_EXTRA_OPTIONS" +$env:BAZEL_TEST_OPTIONS="$env:BAZEL_BUILD_OPTIONS --cache_test_results=no --test_output=all $env:BAZEL_EXTRA_TEST_OPTIONS" diff --git a/ci/build_setup.sh b/ci/build_setup.sh index 6264b66929cce..dba93c0dd78d3 100755 --- a/ci/build_setup.sh +++ b/ci/build_setup.sh @@ -87,7 +87,7 @@ if [ "$1" != "-nofetch" ]; then fi # This is the hash on https://github.com/envoyproxy/envoy-filter-example.git we pin to. - (cd "${ENVOY_FILTER_EXAMPLE_SRCDIR}" && git fetch origin && git checkout -f 4b6c55b726eda8a1f99e6f4ca1a87f6ce670604f) + (cd "${ENVOY_FILTER_EXAMPLE_SRCDIR}" && git fetch origin && git checkout -f 3e5b73305b961526ffcee7584251692a9a3ce4b3) cp -f "${ENVOY_SRCDIR}"/ci/WORKSPACE.filter.example "${ENVOY_FILTER_EXAMPLE_SRCDIR}"/WORKSPACE fi diff --git a/ci/do_ci.ps1 b/ci/do_ci.ps1 new file mode 100755 index 0000000000000..fa0aa691c1a7d --- /dev/null +++ b/ci/do_ci.ps1 @@ -0,0 +1,20 @@ +$ErrorActionPreference = "Stop"; +trap { $host.SetShouldExit(1) } + +. "$PSScriptRoot\build_setup.ps1" +Write-Host "building using $env:NUM_CPUS CPUs" + +function bazel_debug_binary_build() { + echo "Building..." + pushd "$env:ENVOY_SRCDIR" + bazel $env:BAZEL_BASE_OPTIONS.Split(" ") build $env:BAZEL_BUILD_OPTIONS.Split(" ") -c dbg "//source/exe:envoy-static" + $exit = $LASTEXITCODE + if ($exit -ne 0) { + popd + exit $exit + } + popd +} + +echo "bazel debug build..." +bazel_debug_binary_build diff --git a/ci/mac_ci_setup.sh b/ci/mac_ci_setup.sh index e44bd7d5430f1..decdfd75f4d97 100755 --- a/ci/mac_ci_setup.sh +++ b/ci/mac_ci_setup.sh @@ -21,7 +21,7 @@ if ! brew update; then exit 1 fi -DEPS="automake bazel cmake coreutils go libtool wget" +DEPS="automake bazel cmake coreutils go libtool wget ninja" for DEP in ${DEPS} do is_installed "${DEP}" || install "${DEP}" diff --git a/ci/prebuilt/BUILD b/ci/prebuilt/BUILD index 691644568905c..8997736ea30ae 100644 --- a/ci/prebuilt/BUILD +++ b/ci/prebuilt/BUILD @@ -2,36 +2,47 @@ licenses(["notice"]) # Apache 2 package(default_visibility = ["//visibility:public"]) +config_setting( + name = "windows_x86_64", + values = {"cpu": "x64_windows"}, +) + cc_library( name = "ares", - srcs = ["thirdparty_build/lib/libcares.a"], + srcs = select({ + ":windows_x86_64": ["thirdparty_build/lib/cares.lib"], + "//conditions:default": ["thirdparty_build/lib/libcares.a"], + }), hdrs = glob(["thirdparty_build/include/ares*.h"]), includes = ["thirdparty_build/include"], ) cc_library( name = "benchmark", - srcs = ["thirdparty_build/lib/libbenchmark.a"], + srcs = select({ + ":windows_x86_64": ["thirdparty_build/lib/benchmark.lib"], + "//conditions:default": ["thirdparty_build/lib/libbenchmark.a"], + }), hdrs = ["thirdparty_build/include/testing/base/public/benchmark.h"], includes = ["thirdparty_build/include"], ) cc_library( name = "event", - srcs = ["thirdparty_build/lib/libevent.a"], + srcs = select({ + ":windows_x86_64": ["thirdparty_build/lib/event.lib"], + "//conditions:default": ["thirdparty_build/lib/libevent.a"], + }), hdrs = glob(["thirdparty_build/include/event2/**/*.h"]), includes = ["thirdparty_build/include"], ) -cc_library( - name = "event_pthreads", - srcs = ["thirdparty_build/lib/libevent_pthreads.a"], - deps = [":event"], -) - cc_library( name = "luajit", - srcs = ["thirdparty_build/lib/libluajit-5.1.a"], + srcs = select({ + ":windows_x86_64": ["thirdparty_build/lib/luajit.lib"], + "//conditions:default": ["thirdparty_build/lib/libluajit-5.1.a"], + }), hdrs = glob(["thirdparty_build/include/luajit-2.0/*"]), includes = ["thirdparty_build/include"], # TODO(mattklein123): We should strip luajit-2.0 here for consumers. However, if we do that @@ -40,7 +51,10 @@ cc_library( cc_library( name = "nghttp2", - srcs = ["thirdparty_build/lib/libnghttp2.a"], + srcs = select({ + ":windows_x86_64": ["thirdparty_build/lib/nghttp2.lib"], + "//conditions:default": ["thirdparty_build/lib/libnghttp2.a"], + }), hdrs = glob(["thirdparty_build/include/nghttp2/**/*.h"]), includes = ["thirdparty_build/include"], ) @@ -54,14 +68,20 @@ cc_library( cc_library( name = "yaml_cpp", - srcs = ["thirdparty_build/lib/libyaml-cpp.a"], + srcs = select({ + ":windows_x86_64": glob(["thirdparty_build/lib/libyaml-cpp*.lib"]), + "//conditions:default": ["thirdparty_build/lib/libyaml-cpp.a"], + }), hdrs = glob(["thirdparty_build/include/yaml-cpp/**/*.h"]), includes = ["thirdparty_build/include"], ) cc_library( name = "zlib", - srcs = ["thirdparty_build/lib/libz.a"], + srcs = select({ + ":windows_x86_64": glob(["thirdparty_build/lib/zlibstaticd.lib"]), + "//conditions:default": ["thirdparty_build/lib/libz.a"], + }), hdrs = [ "thirdparty_build/include/zconf.h", "thirdparty_build/include/zlib.h", diff --git a/ci/run_envoy_docker.sh b/ci/run_envoy_docker.sh index d52172671f73b..3fd04a01e283e 100755 --- a/ci/run_envoy_docker.sh +++ b/ci/run_envoy_docker.sh @@ -18,7 +18,7 @@ USER_GROUP=root mkdir -p "${ENVOY_DOCKER_BUILD_DIR}" # Since we specify an explicit hash, docker-run will pull from the remote repo if missing. -docker run --rm -t -i -e http_proxy=${http_proxy} -e https_proxy=${https_proxy} \ +docker run --rm -t -i -e HTTP_PROXY=${http_proxy} -e HTTPS_PROXY=${https_proxy} \ -u "${USER}":"${USER_GROUP}" -v "${ENVOY_DOCKER_BUILD_DIR}":/build \ -v "$PWD":/source -e NUM_CPUS --cap-add SYS_PTRACE "${IMAGE_NAME}":"${IMAGE_ID}" \ /bin/bash -lc "groupadd --gid $(id -g) -f envoygroup && useradd -o --uid $(id -u) --gid $(id -g) \ diff --git a/configs/configgen.sh b/configs/configgen.sh index 703d09b2c5494..ff8b006da31e6 100755 --- a/configs/configgen.sh +++ b/configs/configgen.sh @@ -22,4 +22,4 @@ for FILE in $*; do done # tar is having issues with -C for some reason so just cd into OUT_DIR. -(cd "$OUT_DIR"; tar -cvf example_configs.tar *.json *.yaml certs/*.pem) +(cd "$OUT_DIR"; tar -hcvf example_configs.tar *.json *.yaml certs/*.pem) diff --git a/docs/build.sh b/docs/build.sh index a93ecea6caaf8..69ee02ce3f6e6 100755 --- a/docs/build.sh +++ b/docs/build.sh @@ -49,7 +49,9 @@ bazel --batch build ${BAZEL_BUILD_OPTIONS} @envoy_api//docs:protos --aspects \ # These are the protos we want to put in docs, this list will grow. # TODO(htuch): Factor this out of this script. PROTO_RST=" + /envoy/admin/v2alpha/clusters/envoy/admin/v2alpha/clusters.proto.rst /envoy/admin/v2alpha/config_dump/envoy/admin/v2alpha/config_dump.proto.rst + /envoy/admin/v2alpha/clusters/envoy/admin/v2alpha/metrics.proto.rst /envoy/api/v2/core/address/envoy/api/v2/core/address.proto.rst /envoy/api/v2/core/base/envoy/api/v2/core/base.proto.rst /envoy/api/v2/core/http_uri/envoy/api/v2/core/http_uri.proto.rst @@ -102,12 +104,17 @@ PROTO_RST=" /envoy/config/rbac/v2alpha/rbac/envoy/config/rbac/v2alpha/rbac.proto.rst /envoy/config/transport_socket/capture/v2alpha/capture/envoy/config/transport_socket/capture/v2alpha/capture.proto.rst /envoy/data/accesslog/v2/accesslog/envoy/data/accesslog/v2/accesslog.proto.rst + /envoy/data/core/v2alpha/health_check_event/envoy/data/core/v2alpha/health_check_event.proto.rst /envoy/data/tap/v2alpha/capture/envoy/data/tap/v2alpha/capture.proto.rst /envoy/service/accesslog/v2/als/envoy/service/accesslog/v2/als.proto.rst /envoy/service/auth/v2alpha/external_auth/envoy/service/auth/v2alpha/attribute_context.proto.rst /envoy/service/auth/v2alpha/external_auth/envoy/service/auth/v2alpha/external_auth.proto.rst + /envoy/type/http_status/envoy/type/http_status.proto.rst /envoy/type/percent/envoy/type/percent.proto.rst /envoy/type/range/envoy/type/range.proto.rst + /envoy/type/matcher/metadata/envoy/type/matcher/metadata.proto.rst + /envoy/type/matcher/number/envoy/type/matcher/number.proto.rst + /envoy/type/matcher/string/envoy/type/matcher/string.proto.rst " # Dump all the generated RST so they can be added to PROTO_RST easily. diff --git a/docs/root/api-v1/route_config/route.rst b/docs/root/api-v1/route_config/route.rst index 7b320b90ac77b..5f11b1708079a 100644 --- a/docs/root/api-v1/route_config/route.rst +++ b/docs/root/api-v1/route_config/route.rst @@ -487,7 +487,7 @@ string map. Nested objects are not supported. .. _config_http_conn_man_route_table_cors: -Cors +CORS -------- Settings on a route take precedence over settings on the virtual host. diff --git a/docs/root/api-v1/route_config/vhost.rst b/docs/root/api-v1/route_config/vhost.rst index 2d4662124101e..a677b79c53499 100644 --- a/docs/root/api-v1/route_config/vhost.rst +++ b/docs/root/api-v1/route_config/vhost.rst @@ -15,6 +15,7 @@ upstream cluster to route to or whether to perform a redirect. "name": "...", "domains": [], "routes": [], + "cors": {}, "require_ssl": "...", "virtual_clusters": [], "rate_limits": [], diff --git a/docs/root/api-v2/admin/admin.rst b/docs/root/api-v2/admin/admin.rst index 455ba12199f37..db0081b902492 100644 --- a/docs/root/api-v2/admin/admin.rst +++ b/docs/root/api-v2/admin/admin.rst @@ -6,3 +6,5 @@ Admin :maxdepth: 2 ../admin/v2alpha/config_dump.proto + ../admin/v2alpha/clusters.proto + ../admin/v2alpha/metrics.proto diff --git a/docs/root/api-v2/data/core/core.rst b/docs/root/api-v2/data/core/core.rst new file mode 100644 index 0000000000000..f9d7e77bf4d71 --- /dev/null +++ b/docs/root/api-v2/data/core/core.rst @@ -0,0 +1,8 @@ +Core data +========= + +.. toctree:: + :glob: + :maxdepth: 2 + + v2alpha/health_check_event.proto diff --git a/docs/root/api-v2/data/data.rst b/docs/root/api-v2/data/data.rst index e97e93bf879fc..fd7c5877e939f 100644 --- a/docs/root/api-v2/data/data.rst +++ b/docs/root/api-v2/data/data.rst @@ -6,4 +6,5 @@ Envoy data :maxdepth: 2 accesslog/accesslog + core/core tap/tap diff --git a/docs/root/api-v2/types/types.rst b/docs/root/api-v2/types/types.rst index 116d6c3cb519c..8b5750f9993f6 100644 --- a/docs/root/api-v2/types/types.rst +++ b/docs/root/api-v2/types/types.rst @@ -5,5 +5,9 @@ Types :glob: :maxdepth: 2 + ../type/http_status.proto ../type/percent.proto ../type/range.proto + ../type/matcher/metadata.proto + ../type/matcher/number.proto + ../type/matcher/string.proto diff --git a/docs/root/configuration/access_log.rst b/docs/root/configuration/access_log.rst index 88946fa75f669..94392f404c69f 100644 --- a/docs/root/configuration/access_log.rst +++ b/docs/root/configuration/access_log.rst @@ -102,6 +102,16 @@ The following command operators are supported: TCP Total duration in milliseconds of the downstream connection. +%RESPONSE_DURATION% + HTTP + Total duration in milliseconds of the request from the start time to the first byte read from the + upstream host. + + TCP + Not implemented ("-"). + +.. _config_access_log_format_response_flags: + %RESPONSE_FLAGS% Additional details about the response or connection, if any. For TCP connections, the response codes mentioned in the descriptions do not apply. Possible values are: @@ -121,6 +131,14 @@ The following command operators are supported: * **FI**: The request was aborted with a response code specified via :ref:`fault injection `. * **RL**: The request was ratelimited locally by the :ref:`HTTP rate limit filter ` in addition to 429 response code. +%RESPONSE_TX_DURATION% + HTTP + Total duration in milliseconds of the request from the first byte read from the upstream host to the last + byte sent downstream. + + TCP + Not implemented ("-"). + %UPSTREAM_HOST% Upstream host URL (e.g., tcp://ip:port for TCP connections). diff --git a/docs/root/configuration/cluster_manager/cluster_stats.rst b/docs/root/configuration/cluster_manager/cluster_stats.rst index 83fe8a10b0340..bfeb164c795a5 100644 --- a/docs/root/configuration/cluster_manager/cluster_stats.rst +++ b/docs/root/configuration/cluster_manager/cluster_stats.rst @@ -19,6 +19,10 @@ statistics. Any ``:`` character in the stats name is replaced with ``_``. cluster_added, Counter, Total clusters added (either via static config or CDS) cluster_modified, Counter, Total clusters modified (via CDS) cluster_removed, Counter, Total clusters removed (via CDS) + cluster_updated, Counter, Total cluster updates + cluster_updated_via_merge, Counter, Total cluster updates applied as merged updates + update_merge_cancelled, Counter, Total merged updates that got cancelled and delivered early + update_out_of_merge_window, Counter, Total updates which arrived out of a merge window active_clusters, Gauge, Number of currently active (warmed) clusters warming_clusters, Gauge, Number of currently warming (not active) clusters diff --git a/docs/root/configuration/health_checkers/redis.rst b/docs/root/configuration/health_checkers/redis.rst index 2439982ced8d2..1859c005adb1e 100644 --- a/docs/root/configuration/health_checkers/redis.rst +++ b/docs/root/configuration/health_checkers/redis.rst @@ -3,12 +3,22 @@ Redis ===== -The Redis health checker is a custom health checker which checks Redis upstream hosts. It sends -a Redis PING command and expect a PONG response. The upstream Redis server can respond with -anything other than PONG to cause an immediate active health check failure. Optionally, Envoy can -perform EXISTS on a user-specified key. If the key does not exist it is considered a passing healthcheck. -This allows the user to mark a Redis instance for maintenance by setting the specified -:ref:`key ` to any value and waiting for -traffic to drain. +The Redis health checker is a custom health checker (with :code:`envoy.health_checkers.redis` as name) +which checks Redis upstream hosts. It sends a Redis PING command and expect a PONG response. The upstream +Redis server can respond with anything other than PONG to cause an immediate active health check failure. +Optionally, Envoy can perform EXISTS on a user-specified key. If the key does not exist it is considered a +passing healthcheck. This allows the user to mark a Redis instance for maintenance by setting the +specified :ref:`key ` to any value and waiting +for traffic to drain. + +An example setting for :ref:`custom_health_check ` as a +Redis health checker is shown below: + +.. code-block:: yaml + + custom_health_check: + name: envoy.health_checkers.redis + config: + key: foo * :ref:`v2 API reference ` \ No newline at end of file diff --git a/docs/root/configuration/http_conn_man/stats.rst b/docs/root/configuration/http_conn_man/stats.rst index 1b9b13e1d63ec..9b32590f3362e 100644 --- a/docs/root/configuration/http_conn_man/stats.rst +++ b/docs/root/configuration/http_conn_man/stats.rst @@ -52,6 +52,7 @@ statistics: downstream_rq_5xx, Counter, Total 5xx responses downstream_rq_ws_on_non_ws_route, Counter, Total WebSocket upgrade requests rejected by non WebSocket routes downstream_rq_time, Histogram, Request time milliseconds + downstream_rq_idle_timeout, Counter, Total requests closed due to idle timeout rs_too_large, Counter, Total response errors due to buffering an overly large body Per user agent statistics diff --git a/docs/root/configuration/http_filters/cors_filter.rst b/docs/root/configuration/http_filters/cors_filter.rst index 436999a1d18dd..f7e0018cc5f4d 100644 --- a/docs/root/configuration/http_filters/cors_filter.rst +++ b/docs/root/configuration/http_filters/cors_filter.rst @@ -8,5 +8,5 @@ For the meaning of the headers please refer to the pages below. - https://developer.mozilla.org/en-US/docs/Web/HTTP/Access_control_CORS - https://www.w3.org/TR/cors/ -- :ref:`v1 API reference ` -- :ref:`v2 API reference ` +- :ref:`v1 API reference ` +- :ref:`v2 API reference ` diff --git a/docs/root/configuration/http_filters/ext_authz_filter.rst b/docs/root/configuration/http_filters/ext_authz_filter.rst index c9eabf8f94819..bd92156303970 100644 --- a/docs/root/configuration/http_filters/ext_authz_filter.rst +++ b/docs/root/configuration/http_filters/ext_authz_filter.rst @@ -5,9 +5,11 @@ External Authorization * External authorization :ref:`architecture overview ` * :ref:`HTTP filter v2 API reference ` -The external authorization HTTP filter calls an external gRPC service to check if the incoming +The external authorization HTTP filter calls an external gRPC or HTTP service to check if the incoming HTTP request is authorized or not. -If the request is deemed unauthorized then the request will be denied with 403 (Forbidden) response. +If the request is deemed unauthorized then the request will be denied normally with 403 (Forbidden) response. +Note that sending additional custom metadata from the authorization service to the upstream, or to the downstream is +also possible. This is explained in more details at :ref:`HTTP filter `. .. tip:: It is recommended that this filter is configured first in the filter chain so that requests are @@ -18,14 +20,14 @@ The content of the requests that are passed to an authorization service is speci .. _config_http_filters_ext_authz_http_configuration: -The HTTP filter, using a gRPC service, can be configured as follows. You can see all the +The HTTP filter, using a gRPC/HTTP service, can be configured as follows. You can see all the configuration options at :ref:`HTTP filter `. -Example -------- +Configuration Examples +----------------------------- -A sample filter configuration could be: +A sample filter configuration for a gRPC authorization server: .. code-block:: yaml @@ -36,6 +38,8 @@ A sample filter configuration could be: envoy_grpc: cluster_name: ext-authz +.. code-block:: yaml + clusters: - name: ext-authz type: static @@ -43,6 +47,30 @@ A sample filter configuration could be: hosts: - socket_address: { address: 127.0.0.1, port_value: 10003 } +A sample filter configuration for a raw HTTP authorization server: + +.. code-block:: yaml + + http_filters: + - name: envoy.ext_authz + config: + http_service: + server_uri: + uri: 127.0.0.1:10003 + cluster: ext-authz + timeout: 0.25s + failure_mode_allow: false + +.. code-block:: yaml + + clusters: + - name: ext-authz + connect_timeout: 0.25s + type: logical_dns + lb_policy: round_robin + hosts: + - socket_address: { address: 127.0.0.1, port_value: 10003 } + Statistics ---------- The HTTP filter outputs statistics in the *cluster..ext_authz.* namespace. diff --git a/docs/root/configuration/http_filters/grpc_json_transcoder_filter.rst b/docs/root/configuration/http_filters/grpc_json_transcoder_filter.rst index 0c77708dbac58..471297286a125 100644 --- a/docs/root/configuration/http_filters/grpc_json_transcoder_filter.rst +++ b/docs/root/configuration/http_filters/grpc_json_transcoder_filter.rst @@ -62,3 +62,16 @@ match the incoming request path, set `match_incoming_request_route` to true. }; } } + +Sending arbitrary content +------------------------- + +By default, when transcoding occurs, gRPC-JSON encodes the message output of a gRPC service method into +JSON and sets the HTTP response `Content-Type` header to `application/json`. To send abritrary content, +a gRPC service method can use +`google.api.HttpBody `_ +as its output message type. The implementation needs to set +`content_type `_ +(which sets the value of the HTTP response `Content-Type` header) and +`data `_ +(which sets the HTTP response body) accordingly. \ No newline at end of file diff --git a/docs/root/configuration/http_filters/lua_filter.rst b/docs/root/configuration/http_filters/lua_filter.rst index 0f5fb6022d358..afb2970a57869 100644 --- a/docs/root/configuration/http_filters/lua_filter.rst +++ b/docs/root/configuration/http_filters/lua_filter.rst @@ -290,6 +290,28 @@ under the filter name i.e. *envoy.lua*. Below is an example of a *metadata* in a Returns a :ref:`metadata object `. +requestInfo() +^^^^^^^^^^^^^ + +.. code-block:: lua + + requestInfo = handle:requestInfo() + +Returns :repo:`information ` related to the current request. + +Returns a :ref:`request info object `. + +connection() +^^^^^^^^^^^^ + +.. code-block:: lua + + connection = handle:connection() + +Returns the current request's underlying :repo:`connection `. + +Returns a :ref:`connection object `. + .. _config_http_filters_lua_header_wrapper: Header object API @@ -390,7 +412,7 @@ get() metadata:get(key) Gets a metadata. *key* is a string that supplies the metadata key. Returns the corresponding -value of the given metadata key. The type of the value can be: *null*, *boolean*, *number*, +value of the given metadata key. The type of the value can be: *nil*, *boolean*, *number*, *string* and *table*. __pairs() @@ -403,3 +425,89 @@ __pairs() Iterates through every *metadata* entry. *key* is a string that supplies a *metadata* key. *value* is *metadata* entry value. + +.. _config_http_filters_lua_request_info_wrapper: + +Request info object API +----------------------- + +protocol() +^^^^^^^^^^ + +.. code-block:: lua + + requestInfo:protocol() + +Returns the string representation of :repo:`HTTP protocol ` +used by the current request. The possible values are: *HTTP/1.0*, *HTTP/1.1*, and *HTTP/2*. + +dynamicMetadata() +^^^^^^^^^^^^^^^^^ + +.. code-block:: lua + + requestInfo:dynamicMetadata() + +Returns a :ref:`dynamic metadata object `. + +.. _config_http_filters_lua_request_info_dynamic_metadata_wrapper: + +Dynamic metadata object API +--------------------------- + +get() +^^^^^ + +.. code-block:: lua + + dynamicMetadata:get(filterName) + + -- to get a value from a returned table. + dynamicMetadata:get(filterName)[key] + +Gets an entry in dynamic metadata struct. *filterName* is a string that supplies the filter name, e.g. *envoy.lb*. +Returns the corresponding *table* of a given *filterName*. + +set() +^^^^^ + +.. code-block:: lua + + dynamicMetadata:set(filterName, key, value) + +Sets key-value pair of a *filterName*'s metadata. *filterName* is a key specifying the target filter name, +e.g. *envoy.lb*. The type of *key* and *value* is *string*. + +__pairs() +^^^^^^^^^ + +.. code-block:: lua + + for key, value in pairs(dynamicMetadata) do + end + +Iterates through every *dynamicMetadata* entry. *key* is a string that supplies a *dynamicMetadata* +key. *value* is *dynamicMetadata* entry value. + +.. _config_http_filters_lua_connection_wrapper: + +Connection object API +--------------------- + +ssl() +^^^^^^^^ + +.. code-block:: lua + + if connection:ssl() == nil then + print("plain") + else + print("secure") + end + +Returns :repo:`SSL connection ` object when the connection is +secured and *nil* when it is not. + +.. note:: + + Currently the SSL connection object has no exposed APIs. diff --git a/docs/root/configuration/listener_filters/listener_filters.rst b/docs/root/configuration/listener_filters/listener_filters.rst index 6c5a7857d5a4c..4f3e2f353398c 100644 --- a/docs/root/configuration/listener_filters/listener_filters.rst +++ b/docs/root/configuration/listener_filters/listener_filters.rst @@ -9,4 +9,5 @@ Envoy has the follow builtin listener filters. :maxdepth: 2 original_dst_filter + proxy_protocol tls_inspector diff --git a/docs/root/configuration/listener_filters/proxy_protocol.rst b/docs/root/configuration/listener_filters/proxy_protocol.rst new file mode 100644 index 0000000000000..2d196609fb004 --- /dev/null +++ b/docs/root/configuration/listener_filters/proxy_protocol.rst @@ -0,0 +1,26 @@ +.. _config_listener_filters_proxy_protocol: + +Proxy Protocol +============== + +This listener filter adds support for +`HAProxy Proxy Protocol `_. + +In this mode, the upstream connection is assumed to come from a proxy +which places the original coordinates (IP, PORT) into a connection-string. +Envoy then extracts these and uses them as the remote address. + +In Proxy Protocol v2 there exists the concept of extensions (TLV) +tags that are optional. This implementation skips over these without +using them. + +This implementation supports both version 1 and version 2, it +automatically determines on a per-connection basis which of the two +versions is present. Note: if the filter is enabled, the Proxy Protocol +must be present on the connection (either version 1 or version 2), +the standard does not allow parsing to determine if it is present or not. + +If there is a protocol error or an unsupported address family +(e.g. AF_UNIX) the connection will be closed and an error thrown. + +* :ref:`v2 API reference ` diff --git a/docs/root/configuration/overview/v2_overview.rst b/docs/root/configuration/overview/v2_overview.rst index 1e5629e8c80ea..3c8e0aff06c99 100644 --- a/docs/root/configuration/overview/v2_overview.rst +++ b/docs/root/configuration/overview/v2_overview.rst @@ -37,12 +37,9 @@ flag, i.e.: .. code-block:: console - ./envoy -c .{json,yaml,pb,pb_text} --v2-config-only + ./envoy -c .{json,yaml,pb,pb_text} -where the extension reflects the underlying v2 config representation. The -:option:`--v2-config-only` flag is not strictly required as Envoy will attempt -to autodetect the config file version, but this option provides an enhanced -debug experience when configuration parsing fails. +where the extension reflects the underlying v2 config representation. The :ref:`Bootstrap ` message is the root of the configuration. A key concept in the :ref:`Bootstrap ` @@ -98,7 +95,14 @@ A minimal fully static bootstrap config is provided below: connect_timeout: 0.25s type: STATIC lb_policy: ROUND_ROBIN - hosts: [{ socket_address: { address: 127.0.0.2, port_value: 1234 }}] + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 1234 Mostly static with dynamic EDS ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -145,13 +149,22 @@ on 127.0.0.3:5678 is provided below: eds_config: api_config_source: api_type: GRPC - cluster_names: [xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: xds_cluster - name: xds_cluster connect_timeout: 0.25s type: STATIC lb_policy: ROUND_ROBIN http2_protocol_options: {} - hosts: [{ socket_address: { address: 127.0.0.3, port_value: 5678 }}] + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 5678 Notice above that *xds_cluster* is defined to point Envoy at the management server. Even in an otherwise completely dynamic configurations, some static resources need to @@ -198,11 +211,15 @@ below: lds_config: api_config_source: api_type: GRPC - cluster_names: [xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: xds_cluster cds_config: api_config_source: api_type: GRPC - cluster_names: [xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: xds_cluster static_resources: clusters: @@ -211,7 +228,14 @@ below: type: STATIC lb_policy: ROUND_ROBIN http2_protocol_options: {} - hosts: [{ socket_address: { address: 127.0.0.3, port_value: 5678 }}] + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 127.0.0.1 + port_value: 5678 The management server could respond to LDS requests with: @@ -236,7 +260,9 @@ The management server could respond to LDS requests with: config_source: api_config_source: api_type: GRPC - cluster_names: [xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: xds_cluster http_filters: - name: envoy.router @@ -270,7 +296,9 @@ The management server could respond to CDS requests with: eds_config: api_config_source: api_type: GRPC - cluster_names: [xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: xds_cluster The management server could respond to EDS requests with: @@ -324,7 +352,9 @@ for the service definition. This is used by Envoy as a client when cds_config: api_config_source: api_type: GRPC - cluster_names: [some_xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: some_xds_cluster is set in the :ref:`dynamic_resources ` of the :ref:`Bootstrap @@ -341,7 +371,9 @@ for the service definition. This is used by Envoy as a client when eds_config: api_config_source: api_type: GRPC - cluster_names: [some_xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: some_xds_cluster is set in the :ref:`eds_cluster_config ` field of the :ref:`Cluster @@ -358,7 +390,9 @@ for the service definition. This is used by Envoy as a client when lds_config: api_config_source: api_type: GRPC - cluster_names: [some_xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: some_xds_cluster is set in the :ref:`dynamic_resources ` of the :ref:`Bootstrap @@ -376,7 +410,9 @@ for the service definition. This is used by Envoy as a client when config_source: api_config_source: api_type: GRPC - cluster_names: [some_xds_cluster] + grpc_services: + envoy_grpc: + cluster_name: some_xds_cluster is set in the :ref:`rds ` field of the :ref:`HttpConnectionManager @@ -496,7 +532,9 @@ for the service definition. This is used by Envoy as a client when ads_config: api_type: GRPC - cluster_names: [some_ads_cluster] + grpc_services: + envoy_grpc: + cluster_name: some_ads_cluster is set in the :ref:`dynamic_resources ` of the :ref:`Bootstrap @@ -526,11 +564,11 @@ the shared ADS channel. Management Server Unreachability -------------------------------- -When Envoy instance looses connectivity with the management server, Envoy will latch on to -the previous configuration while actively retrying in the background to reestablish the -connection with the management server. +When an Envoy instance loses connectivity with the management server, Envoy will latch on to +the previous configuration while actively retrying in the background to reestablish the +connection with the management server. -Envoy debug logs the fact that it is not able to establish a connection with the management server +Envoy debug logs the fact that it is not able to establish a connection with the management server every time it attempts a connection. :ref:`upstream_cx_connect_fail ` a cluster level statistic diff --git a/docs/root/intro/arch_overview/health_checking.rst b/docs/root/intro/arch_overview/health_checking.rst index 2f44702a06bc4..08812c0b62e56 100644 --- a/docs/root/intro/arch_overview/health_checking.rst +++ b/docs/root/intro/arch_overview/health_checking.rst @@ -24,6 +24,44 @@ unhealthy, successes required before marking a host healthy, etc.): maintenance by setting the specified key to any value and waiting for traffic to drain. See :ref:`redis_key `. +.. _arch_overview_per_cluster_health_check_config: + +Per cluster member health check config +-------------------------------------- + +If active health checking is configured for an upstream cluster, a specific additional configuration +for each registered member can be specified by setting the +:ref:`HealthCheckConfig` +in the :ref:`Endpoint` of an :ref:`LbEndpoint` +of each defined :ref:`LocalityLbEndpoints` in a +:ref:`ClusterLoadAssignment`. + +An example of setting up :ref:`health check config` +to set a :ref:`cluster member`'s alternative health check +:ref:`port` is: + +.. code-block:: yaml + + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + health_check_config: + port_value: 8080 + address: + socket_address: + address: localhost + port_value: 80 + +.. _arch_overview_health_check_logging: + +Health check event logging +-------------------------- + +A per-healthchecker log of ejection and addition events can optionally be produced by Envoy by +specifying a log file path in `the HealthCheck config `. +The log is structured as JSON dumps of `HealthCheckEvent messages `. + Passive health checking ----------------------- diff --git a/docs/root/intro/arch_overview/http_connection_management.rst b/docs/root/intro/arch_overview/http_connection_management.rst index 4f1d415b48e3d..40415481322fc 100644 --- a/docs/root/intro/arch_overview/http_connection_management.rst +++ b/docs/root/intro/arch_overview/http_connection_management.rst @@ -42,3 +42,27 @@ table `. The route table can be specified in one of * Statically. * Dynamically via the :ref:`RDS API `. + +Timeouts +-------- + +Various configurable timeouts apply to an HTTP connection and its constituent streams: + +* Connection-level :ref:`idle timeout + `: + this applies to the idle period where no streams are active. +* Connection-level :ref:`drain timeout + `: + this spans between an Envoy originated GOAWAY and connection termination. +* Stream-level idle timeout: this applies to each individual stream. It may be configured at both + the :ref:`connection manager + ` + and :ref:`per-route ` granularity. + Header/data/trailer events on the stream reset the idle timeout. +* Stream-level :ref:`per-route upstream timeout `: this + applies to the upstream response, i.e. a maximum bound on the time from the end of the downstream + request until the end of the upstream response. This may also be specified at the :ref:`per-retry + ` granularity. +* Stream-level :ref:`per-route gRPC max timeout + `: this bounds the upstream timeout and allows + the timeout to be overriden via the *grpc-timeout* request header. diff --git a/docs/root/intro/arch_overview/redis.rst b/docs/root/intro/arch_overview/redis.rst index b93830edba161..ff7d4696a4abb 100644 --- a/docs/root/intro/arch_overview/redis.rst +++ b/docs/root/intro/arch_overview/redis.rst @@ -43,8 +43,10 @@ For filter configuration details, see the Redis proxy filter The corresponding cluster definition should be configured with :ref:`ring hash load balancing `. -If active healthchecking is desired, the cluster should be configured with a -:ref:`Redis healthcheck `. +If :ref:`active health checking ` is desired, the +cluster should be configured with a :ref:`custom health check +` which configured as a +:ref:`Redis health checker `. If passive healthchecking is desired, also configure :ref:`outlier detection `. diff --git a/docs/root/intro/arch_overview/websocket.rst b/docs/root/intro/arch_overview/websocket.rst index 9d65b3b680e74..35aa8477fc535 100644 --- a/docs/root/intro/arch_overview/websocket.rst +++ b/docs/root/intro/arch_overview/websocket.rst @@ -1,7 +1,27 @@ .. _arch_overview_websocket: -WebSocket support -================= +Envoy currently supports two modes of Upgrade behavior, the new generic upgrade mode, and +the old WebSocket-only TCP proxy mode. + + +New style Upgrade support +========================= + +The new style Upgrade support is intended mainly for WebSocket but may be used for non-WebSocket +upgrades as well. The new style of upgrades pass both the HTTP headers and the upgrade payload +through an HTTP filter chain. One may configure the +:ref:`upgrade_configs ` +in one of two ways. If only the +`upgrade_type ` +is specified, both the upgrade headers, any request and response body, and WebSocket payload will +pass through the default HTTP filter chain. To avoid the use of HTTP-only filters for upgrade payload, +one can set up custom +`filters ` +for the given upgrade type, up to and including only using the router filter to send the WebSocket +data upstream. + +Old style WebSocket support +=========================== Envoy supports upgrading a HTTP/1.1 connection to a WebSocket connection. Connection upgrade will be allowed only if the downstream client @@ -18,8 +38,8 @@ retries, rate limits and shadowing are not supported for WebSocket routes. However, prefix rewriting, explicit and automatic host rewriting, traffic shifting and splitting are supported. -Connection semantics --------------------- +Old style Connection semantics +------------------------------ Even though WebSocket upgrades occur over HTTP/1.1 connections, WebSockets proxying works similarly to plain TCP proxy, i.e., Envoy does not interpret diff --git a/docs/root/intro/version_history.rst b/docs/root/intro/version_history.rst index 3b81f97a9a046..62f92f2070040 100644 --- a/docs/root/intro/version_history.rst +++ b/docs/root/intro/version_history.rst @@ -3,15 +3,57 @@ Version history 1.8.0 (Pending) =============== +* access log: added :ref:`response flag filter ` + to filter based on the presence of Envoy response flags. +* access log: added RESPONSE_DURATION and RESPONSE_TX_DURATION. +* admin: added :http:get:`/hystrix_event_stream` as an endpoint for monitoring envoy's statistics + through `Hystrix dashboard `_. +* grpc-json: added support for building HTTP response from + `google.api.HttpBody `_. +* cluster: added :ref:`option ` to merge + health check/weight/metadata updates within the given duration. +* config: v1 disabled by default. v1 support remains available until October via flipping --v2-config-only=false. +* config: v1 disabled by default. v1 support remains available until October via setting :option:`--allow-deprecated-v1-api`. +* health check: added support for :ref:`custom health check `. +* health check: added support for :ref:`specifying jitter as a percentage `. +* health_check: added support for :ref:`health check event logging `. +* health_check: added support for specifying :ref:`custom request headers ` + to HTTP health checker requests. +* http: added support for a per-stream idle timeout. This applies at both :ref:`connection manager + ` + and :ref:`per-route granularity `. The timeout + defaults to 5 minutes; if you have other timeouts (e.g. connection idle timeout, upstream + response per-retry) that are longer than this in duration, you may want to consider setting a + non-default per-stream idle timeout. +* http: added support for a :ref:`per-stream idle timeout + `. This defaults to 5 minutes; if you have + other timeouts (e.g. connection idle timeout, upstream response per-retry) that are longer than + this in duration, you may want to consider setting a non-default per-stream idle timeout. +* http: added generic :ref:`Upgrade support + `. +* http: better handling of HEAD requests. Now sending transfer-encoding: chunked rather than content-length: 0. * http: response filters not applied to early error paths such as http_parser generated 400s. +* http: :ref:`hpack_table_size ` now controls + dynamic table size of both: encoder and decoder. +* listeners: added the ability to match :ref:`FilterChain ` using + :ref:`destination_port ` and + :ref:`prefix_ranges `. +* lua: added :ref:`connection() ` wrapper and *ssl()* API. +* lua: added :ref:`requestInfo() ` wrapper and *protocol()* API. +* lua: added :ref:`requestInfo():dynamicMetadata() ` API. +* proxy_protocol: added support for HAProxy Proxy Protocol v2 (AF_INET/AF_INET6 only). * ratelimit: added support for :repo:`api/envoy/service/ratelimit/v2/rls.proto`. Lyft's reference implementation of the `ratelimit `_ service also supports the data-plane-api proto as of v1.1.0. Envoy can use either proto to send client requests to a ratelimit server with the use of the :ref:`use_data_plane_proto` boolean flag in the ratelimit configuration. Support for the legacy proto :repo:`source/common/ratelimit/ratelimit.proto` is deprecated and will be removed at the start of the 1.9.0 release cycle. +* router: added ability to set request/response headers at the :ref:`envoy_api_msg_route.Route` level. * tracing: added support for configuration of :ref:`tracing sampling `. +* thrift_proxy: introduced thrift routing, moved configuration to correct location +* upstream: added configuration option to the subset load balancer to take locality weights into account when + selecting a host from a subset. 1.7.0 =============== @@ -26,6 +68,8 @@ Version history * access log: improved WebSocket logging. * admin: added :http:get:`/config_dump` for dumping the current configuration and associated xDS version information (if applicable). +* admin: added :http:get:`/clusters?format=json` for outputing a JSON-serialized proto detailing + the current status of all clusters. * admin: added :http:get:`/stats/prometheus` as an alternative endpoint for getting stats in prometheus format. * admin: added :ref:`/runtime_modify endpoint ` to add or change runtime values. * admin: mutations must be sent as POSTs, rather than GETs. Mutations include: @@ -43,8 +87,10 @@ Version history to close tcp_proxy upstream connections when health checks fail. * cluster: added :ref:`option ` to drain connections from hosts after they are removed from service discovery, regardless of health status. -* cluster: fixed bug preventing the deletion of all endpoints in a priority. -* debug: added symbolized stack traces (where supported). +* cluster: fixed bug preventing the deletion of all endpoints in a priority +* debug: added symbolized stack traces (where supported) +* ext-authz filter: added support to raw HTTP authorization. +* ext-authz filter: added support to gRPC responses to carry HTTP attributes. * grpc: support added for the full set of :ref:`Google gRPC call credentials `. * gzip filter: added :ref:`stats ` to the filter. @@ -61,8 +107,7 @@ Version history * health check: health check connections can now be configured to use http/2. * health check http filter: added :ref:`generic header matching ` - to trigger health check response. Deprecated the - :ref:`endpoint option `. + to trigger health check response. Deprecated the endpoint option. * http: filters can now optionally support :ref:`virtual host `, :ref:`route `, and @@ -85,8 +130,7 @@ Version history * listeners: added the ability to match :ref:`FilterChain ` using :ref:`application_protocols ` (e.g. ALPN for TLS protocol). -* listeners: :ref:`sni_domains ` has been deprecated/renamed to - :ref:`server_names `. +* listeners: `sni_domains` has been deprecated/renamed to :ref:`server_names `. * listeners: removed restriction on all filter chains having identical filters. * load balancer: added :ref:`weighted round robin ` support. The round robin diff --git a/docs/root/operations/admin.rst b/docs/root/operations/admin.rst index 8a9947b5c6334..38cc64f7fe182 100644 --- a/docs/root/operations/admin.rst +++ b/docs/root/operations/admin.rst @@ -32,8 +32,8 @@ modify different aspects of the server: In the future additional security options will be added to the administration interface. This work is tracked in `this `_ issue. - All mutations should be sent as HTTP POST operations. For a limited time, they will continue - to work with HTTP GET, with a warning logged. + All mutations must be sent as HTTP POST operations. When a mutation is requested via GET, + the request has no effect, and an HTTP 400 (Invalid Request) response is returned. .. http:get:: / @@ -110,6 +110,11 @@ modify different aspects of the server: */failed_outlier_check*: The host has failed an outlier detection check. +.. http:get:: /clusters?format=json + + Dump the */clusters* output in a JSON-serialized proto. See the + :ref:`definition ` for more information. + .. _operations_admin_interface_config_dump: .. http:get:: /config_dump @@ -314,3 +319,39 @@ The fields are: Use the /runtime_modify endpoint with care. Changes are effectively immediately. It is **critical** that the admin interface is :ref:`properly secured `. + + .. _operations_admin_interface_hystrix_event_stream: + +.. http:get:: /hystrix_event_stream + + This endpoint is intended to be used as the stream source for + `Hystrix dashboard `_. + a GET to this endpoint will trriger a stream of statistics from envoy in + `text/event-stream `_ + format, as expected by the Hystrix dashboard. + + If invoked from a browser or a terminal, the response will be shown as a continous stream, + sent in intervals defined by the :ref:`Bootstrap ` + :ref:`stats_flush_interval ` + + This handler is enabled only when a Hystrix sink is enabled in the config file as documented + :ref:`here `. + + As Envoy's and Hystrix resiliency mechanisms differ, some of the statistics shown in the dashboard + had to be adapted: + + * **Thread pool rejections** - Generally similar to what's called short circuited in Envoy, + and counted by *upstream_rq_pending_overflow*, although the term thread pool is not accurate for + Envoy. Both in Hystrix and Envoy, the result is rejected requests which are not passed upstream. + * **circuit breaker status (closed or open)** - Since in Envoy, a circuit is opened based on the + current number of connections/requests in queue, there is no sleeping window for circuit breaker, + circuit open/closed is momentary. Hence, we set the circuit breaker status to "forced closed". + * **Short-circuited (rejected)** - The term exists in Envoy but refers to requests not sent because + of passing a limit (queue or connections), while in Hystrix it refers to requests not sent because + of high percentage of service unavailable responses during some time frame. + In Envoy, service unavailable response will cause **outlier detection** - removing a node off the + load balancer pool, but requests are not rejected as a result. Therefore, this counter is always + set to '0'. + * Latency information is currently unavailable. + + diff --git a/docs/root/operations/cli.rst b/docs/root/operations/cli.rst index e50e99462e32a..c6081452c6770 100644 --- a/docs/root/operations/cli.rst +++ b/docs/root/operations/cli.rst @@ -11,8 +11,8 @@ following are the command line options that Envoy supports. *(optional)* The path to the v1 or v2 :ref:`JSON/YAML/proto3 configuration file `. If this flag is missing, :option:`--config-yaml` is required. This will be parsed as a :ref:`v2 bootstrap configuration file - ` and on failure, subject to - :option:`--v2-config-only`, will be considered as a :ref:`v1 JSON + `. On failure, if :option:`--allow-deprecated-v1-api`, + is set, it will be considered as a :ref:`v1 JSON configuration file `. For v2 configuration files, valid extensions are ``.json``, ``.yaml``, ``.pb`` and ``.pb_text``, which indicate JSON, YAML, `binary proto3 @@ -34,9 +34,14 @@ following are the command line options that Envoy supports. .. option:: --v2-config-only + *(deprecated)* This flag used to allow opting into only using a + :ref:`v2 bootstrap configuration file `. This is now set by default. + +.. option:: --allow-deprecated-v1-api + *(optional)* This flag determines whether the configuration file should only be parsed as a :ref:`v2 bootstrap configuration file - `. If false (default), when a v2 bootstrap + `. If specified when a v2 bootstrap config parse fails, a second attempt to parse the config as a :ref:`v1 JSON configuration file ` will be made. diff --git a/examples/grpc-bridge/Dockerfile-python b/examples/grpc-bridge/Dockerfile-python index a807a4d1c95ac..02aa308c2acb1 100644 --- a/examples/grpc-bridge/Dockerfile-python +++ b/examples/grpc-bridge/Dockerfile-python @@ -3,7 +3,7 @@ FROM envoyproxy/envoy:latest RUN apt-get update RUN apt-get -q install -y python-dev \ python-pip -RUN pip install -q grpcio requests +RUN pip install -q grpcio protobuf requests ADD ./client /client RUN chmod a+x /client/client.py RUN mkdir /var/log/envoy/ diff --git a/examples/grpc-bridge/config/s2s-grpc-envoy.yaml b/examples/grpc-bridge/config/s2s-grpc-envoy.yaml index 31ccd3ff0fe41..baaac35e57ea2 100644 --- a/examples/grpc-bridge/config/s2s-grpc-envoy.yaml +++ b/examples/grpc-bridge/config/s2s-grpc-envoy.yaml @@ -21,7 +21,7 @@ static_resources: prefix: "/" headers: - name: content-type - value: application/grpc + exact_match: application/grpc route: cluster: local_service_grpc http_filters: diff --git a/include/envoy/api/os_sys_calls.h b/include/envoy/api/os_sys_calls.h index ae6a96c9e5588..9378b35e3cc05 100644 --- a/include/envoy/api/os_sys_calls.h +++ b/include/envoy/api/os_sys_calls.h @@ -1,5 +1,6 @@ #pragma once +#include #include // for mode_t #include // for sockaddr #include @@ -13,6 +14,22 @@ namespace Envoy { namespace Api { +/** + * SysCallResult holds the rc and errno values resulting from a system call. + */ +struct SysCallResult { + + /** + * The return code from the system call. + */ + int rc_; + + /** + * The errno value as captured after the system call. + */ + int errno_; +}; + class OsSysCalls { public: virtual ~OsSysCalls() {} @@ -22,6 +39,11 @@ class OsSysCalls { */ virtual int bind(int sockfd, const sockaddr* addr, socklen_t addrlen) PURE; + /** + * @see ioctl (man 2 ioctl) + */ + virtual int ioctl(int sockfd, unsigned long int request, void* argp) PURE; + /** * Open file by full_path with given flags and mode. * @return file descriptor. @@ -90,6 +112,11 @@ class OsSysCalls { * @see man 2 getsockopt */ virtual int getsockopt(int sockfd, int level, int optname, void* optval, socklen_t* optlen) PURE; + + /** + * @see man 2 socket + */ + virtual int socket(int domain, int type, int protocol) PURE; }; typedef std::unique_ptr OsSysCallsPtr; diff --git a/include/envoy/buffer/BUILD b/include/envoy/buffer/BUILD index 084ed91ebf24f..01dcb26234196 100644 --- a/include/envoy/buffer/BUILD +++ b/include/envoy/buffer/BUILD @@ -11,4 +11,5 @@ envoy_package() envoy_cc_library( name = "buffer_interface", hdrs = ["buffer.h"], + deps = ["//include/envoy/api:os_sys_calls_interface"], ) diff --git a/include/envoy/buffer/buffer.h b/include/envoy/buffer/buffer.h index ae9cff3785ae7..cfe8611e0848a 100644 --- a/include/envoy/buffer/buffer.h +++ b/include/envoy/buffer/buffer.h @@ -5,6 +5,7 @@ #include #include +#include "envoy/api/os_sys_calls.h" #include "envoy/common/pure.h" namespace Envoy { @@ -142,9 +143,10 @@ class Instance { * Read from a file descriptor directly into the buffer. * @param fd supplies the descriptor to read from. * @param max_length supplies the maximum length to read. - * @return the number of bytes read or -1 if there was an error. + * @return a Api::SysCallResult with rc_ = the number of bytes read if successful, or rc_ = -1 + * for failure. If the call is successful, errno_ shouldn't be used. */ - virtual int read(int fd, uint64_t max_length) PURE; + virtual Api::SysCallResult read(int fd, uint64_t max_length) PURE; /** * Reserve space in the buffer. @@ -164,12 +166,19 @@ class Instance { */ virtual ssize_t search(const void* data, uint64_t size, size_t start) const PURE; + /** + * Constructs a flattened string from a buffer. + * @return the flattened string. + */ + virtual std::string toString() const PURE; + /** * Write the buffer out to a file descriptor. * @param fd supplies the descriptor to write to. - * @return the number of bytes written or -1 if there was an error. + * @return a Api::SysCallResult with rc_ = the number of bytes written if successful, or rc_ = -1 + * for failure. If the call is successful, errno_ shouldn't be used. */ - virtual int write(int fd) PURE; + virtual Api::SysCallResult write(int fd) PURE; }; typedef std::unique_ptr InstancePtr; diff --git a/include/envoy/common/BUILD b/include/envoy/common/BUILD index dee9f10b4479b..507c15d284fe9 100644 --- a/include/envoy/common/BUILD +++ b/include/envoy/common/BUILD @@ -37,3 +37,8 @@ envoy_cc_library( name = "callback", hdrs = ["callback.h"], ) + +envoy_cc_library( + name = "backoff_strategy_interface", + hdrs = ["backoff_strategy.h"], +) diff --git a/include/envoy/common/backoff_strategy.h b/include/envoy/common/backoff_strategy.h new file mode 100644 index 0000000000000..63114e047fc27 --- /dev/null +++ b/include/envoy/common/backoff_strategy.h @@ -0,0 +1,25 @@ +#pragma once + +#include "envoy/common/pure.h" + +namespace Envoy { +/** + * Generic interface for all backoff strategy implementations. + */ +class BackOffStrategy { +public: + virtual ~BackOffStrategy() {} + + /** + * @return the next backoff interval in milli seconds. + */ + virtual uint64_t nextBackOffMs() PURE; + + /** + * Resets the intervals so that the back off intervals can start again. + */ + virtual void reset() PURE; +}; + +typedef std::unique_ptr BackOffStrategyPtr; +} // namespace Envoy \ No newline at end of file diff --git a/include/envoy/http/header_map.h b/include/envoy/http/header_map.h index 7bdf89436a8d8..c85d77b53e061 100644 --- a/include/envoy/http/header_map.h +++ b/include/envoy/http/header_map.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "envoy/common/pure.h" @@ -106,31 +107,6 @@ class HeaderString { */ bool find(const char* str) const { return strstr(c_str(), str); } - /** - * HeaderString is in token list form, each token separated by commas or whitespace, - * see https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.1 for more information, - * header field value's case sensitivity depends on each header. - * @return whether contains token in case insensitive manner. - */ - bool caseInsensitiveContains(const char* token) const { - // Avoid dead loop if token argument is empty. - const int n = strlen(token); - if (n == 0) { - return false; - } - - // Find token substring, skip if it's partial of other token. - const char* tokens = c_str(); - for (const char* p = tokens; (p = strcasestr(p, token)); p += n) { - if ((p == tokens || *(p - 1) == ' ' || *(p - 1) == ',') && - (*(p + n) == '\0' || *(p + n) == ' ' || *(p + n) == ',')) { - return true; - } - } - - return false; - } - /** * Set the value of the string by copying data into it. This overwrites any existing string. */ @@ -281,6 +257,7 @@ class HeaderEntry { HEADER_FUNC(KeepAlive) \ HEADER_FUNC(LastModified) \ HEADER_FUNC(Method) \ + HEADER_FUNC(NoChunks) \ HEADER_FUNC(Origin) \ HEADER_FUNC(OtSpanContext) \ HEADER_FUNC(Path) \ @@ -498,5 +475,10 @@ class HeaderMap { typedef std::unique_ptr HeaderMapPtr; +/** + * Convenient container type for storing Http::LowerCaseString and std::string key/value pairs. + */ +typedef std::vector> HeaderVector; + } // namespace Http } // namespace Envoy diff --git a/include/envoy/network/BUILD b/include/envoy/network/BUILD index 31781824fa098..5f7e0ebd06acc 100644 --- a/include/envoy/network/BUILD +++ b/include/envoy/network/BUILD @@ -11,6 +11,7 @@ envoy_package() envoy_cc_library( name = "address_interface", hdrs = ["address.h"], + deps = ["//include/envoy/api:os_sys_calls_interface"], ) envoy_cc_library( diff --git a/include/envoy/network/address.h b/include/envoy/network/address.h index 949747484018f..13004a1f19b03 100644 --- a/include/envoy/network/address.h +++ b/include/envoy/network/address.h @@ -8,6 +8,7 @@ #include #include +#include "envoy/api/os_sys_calls.h" #include "envoy/common/pure.h" #include "absl/numeric/int128.h" @@ -128,19 +129,19 @@ class Instance { * Bind a socket to this address. The socket should have been created with a call to socket() on * an Instance of the same address family. * @param fd supplies the platform socket handle. - * @return 0 for success and -1 for failure. The error code associated with a failure will - * be accessible in a plaform dependent fashion (e.g. errno for Unix platforms). + * @return a Api::SysCallResult with rc_ = 0 for success and rc_ = -1 for failure. If the call + * is successful, errno_ shouldn't be used. */ - virtual int bind(int fd) const PURE; + virtual Api::SysCallResult bind(int fd) const PURE; /** * Connect a socket to this address. The socket should have been created with a call to socket() * on this object. * @param fd supplies the platform socket handle. - * @return 0 for success and -1 for failure. The error code associated with a failure will - * be accessible in a plaform dependent fashion (e.g. errno for Unix platforms). + * @return a Api::SysCallResult with rc_ = 0 for success and rc_ = -1 for failure. If the call + * is successful, errno_ shouldn't be used. */ - virtual int connect(int fd) const PURE; + virtual Api::SysCallResult connect(int fd) const PURE; /** * @return the IP address information IFF type() == Type::Ip, otherwise nullptr. @@ -150,9 +151,8 @@ class Instance { /** * Create a socket for this address. * @param type supplies the socket type to create. - * @return the file descriptor naming the socket for success and -1 for failure. The error - * code associated with a failure will be accessible in a plaform dependent fashion (e.g. - * errno for Unix platforms). + * @return the file descriptor naming the socket. In case of a failure, the program would be + * aborted. */ virtual int socket(SocketType type) const PURE; diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 84e0395b3cf47..b7c57f06e4c93 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -186,6 +186,11 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual const Ssl::Connection* ssl() const PURE; + /** + * @return requested server name (e.g. SNI in TLS), if any. + */ + virtual absl::string_view requestedServerName() const PURE; + /** * @return State the current state of the connection. */ diff --git a/include/envoy/registry/registry.h b/include/envoy/registry/registry.h index d9890a79456df..727934540c9a6 100644 --- a/include/envoy/registry/registry.h +++ b/include/envoy/registry/registry.h @@ -82,7 +82,7 @@ template class FactoryRegistry { } factories().emplace(factory.name(), &factory); - RELEASE_ASSERT(getFactory(factory.name()) == &factory); + RELEASE_ASSERT(getFactory(factory.name()) == &factory, ""); return displaced; } @@ -92,7 +92,7 @@ template class FactoryRegistry { */ static void removeFactoryForTest(const std::string& name) { auto result = factories().erase(name); - RELEASE_ASSERT(result == 1); + RELEASE_ASSERT(result == 1, ""); } /** diff --git a/include/envoy/request_info/request_info.h b/include/envoy/request_info/request_info.h index b3bb71248e721..ee21a3ae7e8b1 100644 --- a/include/envoy/request_info/request_info.h +++ b/include/envoy/request_info/request_info.h @@ -64,6 +64,13 @@ class RequestInfo { */ virtual void setResponseFlag(ResponseFlag response_flag) PURE; + /** + * @param response_flags the response_flags to intersect with. + * @return true if the intersection of the response_flags argument and the currently set response + * flags is non-empty. + */ + virtual bool intersectResponseFlags(uint64_t response_flags) const PURE; + /** * @param host the selected upstream host for the request. */ @@ -216,7 +223,12 @@ class RequestInfo { /** * @return whether response flag is set or not. */ - virtual bool getResponseFlag(ResponseFlag response_flag) const PURE; + virtual bool hasResponseFlag(ResponseFlag response_flag) const PURE; + + /** + * @return whether any response flag is set or not. + */ + virtual bool hasAnyResponseFlag() const PURE; /** * @return upstream host description. diff --git a/include/envoy/router/rds.h b/include/envoy/router/rds.h index 827740aee3fca..8ff43f213f4f5 100644 --- a/include/envoy/router/rds.h +++ b/include/envoy/router/rds.h @@ -45,7 +45,7 @@ class RouteConfigProvider { virtual SystemTime lastUpdated() const PURE; }; -typedef std::shared_ptr RouteConfigProviderSharedPtr; +typedef std::unique_ptr RouteConfigProviderPtr; } // namespace Router } // namespace Envoy diff --git a/include/envoy/router/route_config_provider_manager.h b/include/envoy/router/route_config_provider_manager.h index 98b8db5a937ea..daacf2e8e6deb 100644 --- a/include/envoy/router/route_config_provider_manager.h +++ b/include/envoy/router/route_config_provider_manager.h @@ -27,43 +27,29 @@ class RouteConfigProviderManager { virtual ~RouteConfigProviderManager() {} /** - * Get a RouteConfigProviderSharedPtr for a route from RDS. Ownership of the RouteConfigProvider - * is shared by all the HttpConnectionManagers who own a RouteConfigProviderSharedPtr. The - * RouteConfigProviderManager holds weak_ptrs to the RouteConfigProviders. Clean up of the weak - * ptrs happen from the destructor of the RouteConfigProvider. This function creates a - * RouteConfigProvider if there isn't one with the same (route_config_name, cluster) already. - * Otherwise, it returns a RouteConfigProviderSharedPtr created from the manager held weak_ptr. + * Get a RouteConfigProviderPtr for a route from RDS. Ownership of the RouteConfigProvider is the + * HttpConnectionManagers who calls this function. The RouteConfigProviderManager holds raw + * pointers to the RouteConfigProviders. Clean up of the pointers happen from the destructor of + * the RouteConfigProvider. This method creates a RouteConfigProvider which may share the + * underlying RDS subscription with the same (route_config_name, cluster). * @param rds supplies the proto configuration of an RDS-configured RouteConfigProvider. * @param factory_context is the context to use for the route config provider. * @param stat_prefix supplies the stat_prefix to use for the provider stats. */ - virtual RouteConfigProviderSharedPtr getRdsRouteConfigProvider( + virtual RouteConfigProviderPtr createRdsRouteConfigProvider( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix) PURE; /** * Get a RouteConfigSharedPtr for a statically defined route. Ownership is as described for - * getRdsRouteConfigProvider above. Unlike getRdsRouteConfigProvider(), this method always creates - * a new RouteConfigProvider. + * getRdsRouteConfigProvider above. This method always create a new RouteConfigProvider. * @param route_config supplies the RouteConfiguration for this route * @param runtime supplies the runtime loader. * @param cm supplies the ClusterManager. */ - virtual RouteConfigProviderSharedPtr - getStaticRouteConfigProvider(const envoy::api::v2::RouteConfiguration& route_config, - Server::Configuration::FactoryContext& factory_context) PURE; - - /** - * @return std::vector a list of all the - * dynamic (RDS) RouteConfigProviders currently loaded. - */ - virtual std::vector getRdsRouteConfigProviders() PURE; - - /** - * @return std::vector a list of all the - * static RouteConfigProviders currently loaded. - */ - virtual std::vector getStaticRouteConfigProviders() PURE; + virtual RouteConfigProviderPtr + createStaticRouteConfigProvider(const envoy::api::v2::RouteConfiguration& route_config, + Server::Configuration::FactoryContext& factory_context) PURE; }; } // namespace Router diff --git a/include/envoy/router/router.h b/include/envoy/router/router.h index 8699f95d3f137..a69c402251910 100644 --- a/include/envoy/router/router.h +++ b/include/envoy/router/router.h @@ -99,6 +99,11 @@ class CorsPolicy { */ virtual const std::list& allowOrigins() const PURE; + /* + * @return std::list& regexes that match allowed origins. + */ + virtual const std::list& allowOriginRegexes() const PURE; + /** * @return std::string access-control-allow-methods value. */ @@ -467,6 +472,12 @@ class RouteEntry : public ResponseEntry { */ virtual std::chrono::milliseconds timeout() const PURE; + /** + * @return optional the route's idle timeout. Zero indicates a + * disabled idle timeout, while nullopt indicates deference to the global timeout. + */ + virtual absl::optional idleTimeout() const PURE; + /** * @return absl::optional the maximum allowed timeout value derived * from 'grpc-timeout' header of a gRPC request. Non-present value disables use of 'grpc-timeout' diff --git a/include/envoy/server/BUILD b/include/envoy/server/BUILD index fae78b50ab2ad..42acd80d8f161 100644 --- a/include/envoy/server/BUILD +++ b/include/envoy/server/BUILD @@ -111,6 +111,7 @@ envoy_cc_library( hdrs = ["options.h"], deps = [ "//include/envoy/network:address_interface", + "//include/envoy/stats:stats_interface", ], ) @@ -180,3 +181,27 @@ envoy_cc_library( "//source/common/protobuf", ], ) + +envoy_cc_library( + name = "resource_monitor_interface", + hdrs = ["resource_monitor.h"], + deps = [ + "//source/common/protobuf", + ], +) + +envoy_cc_library( + name = "resource_monitor_config_interface", + hdrs = ["resource_monitor_config.h"], + deps = [ + ":resource_monitor_interface", + "//include/envoy/event:dispatcher_interface", + ], +) + +envoy_cc_library( + name = "overload_manager_interface", + hdrs = ["overload_manager.h"], + deps = [ + ], +) diff --git a/include/envoy/server/admin.h b/include/envoy/server/admin.h index 2ed9f3e45e61f..00dca9d31cf5f 100644 --- a/include/envoy/server/admin.h +++ b/include/envoy/server/admin.h @@ -37,7 +37,7 @@ class AdminStream { * @return Http::StreamDecoderFilterCallbacks& to be used by the handler to get HTTP request data * for streaming. */ - virtual const Http::StreamDecoderFilterCallbacks& getDecoderFilterCallbacks() const PURE; + virtual Http::StreamDecoderFilterCallbacks& getDecoderFilterCallbacks() const PURE; /** * @return Http::HeaderMap& to be used by handler to parse header information sent with the @@ -54,7 +54,7 @@ class AdminStream { */ #define MAKE_ADMIN_HANDLER(X) \ [this](absl::string_view path_and_query, Http::HeaderMap& response_headers, \ - Buffer::Instance& data, AdminStream& admin_stream) -> Http::Code { \ + Buffer::Instance& data, Server::AdminStream& admin_stream) -> Http::Code { \ return X(path_and_query, response_headers, data, admin_stream); \ } diff --git a/include/envoy/server/filter_config.h b/include/envoy/server/filter_config.h index 6fc02e5980bcd..46d1a265cb231 100644 --- a/include/envoy/server/filter_config.h +++ b/include/envoy/server/filter_config.h @@ -208,7 +208,7 @@ class NamedNetworkFilterConfigFactory { FactoryContext& context) { UNREFERENCED_PARAMETER(config); UNREFERENCED_PARAMETER(context); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } /** @@ -260,7 +260,7 @@ class NamedHttpFilterConfigFactory { UNREFERENCED_PARAMETER(config); UNREFERENCED_PARAMETER(stat_prefix); UNREFERENCED_PARAMETER(context); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } /** diff --git a/include/envoy/server/health_checker_config.h b/include/envoy/server/health_checker_config.h index 68af121915baa..089c65f4d3977 100644 --- a/include/envoy/server/health_checker_config.h +++ b/include/envoy/server/health_checker_config.h @@ -31,6 +31,12 @@ class HealthCheckerFactoryContext { * for all singleton processing. */ virtual Event::Dispatcher& dispatcher() PURE; + + /* + * @return Upstream::HealthCheckEventLoggerPtr the health check event logger for the + * created health checkers. This function may not be idempotent. + */ + virtual Upstream::HealthCheckEventLoggerPtr eventLogger() PURE; }; /** @@ -63,4 +69,4 @@ class CustomHealthCheckerFactory { } // namespace Configuration } // namespace Server -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/include/envoy/server/options.h b/include/envoy/server/options.h index 0d5a62e084913..88d43d8c6bf19 100644 --- a/include/envoy/server/options.h +++ b/include/envoy/server/options.h @@ -6,6 +6,7 @@ #include "envoy/common/pure.h" #include "envoy/network/address.h" +#include "envoy/stats/stats.h" #include "spdlog/spdlog.h" @@ -149,10 +150,10 @@ class Options { virtual uint64_t maxStats() const PURE; /** - * @return uint64_t the maximum name length of the name field in + * @return StatsOptions& the max stat name / suffix lengths for stats. * router/cluster/listener. */ - virtual uint64_t maxObjNameLength() const PURE; + virtual const Stats::StatsOptions& statsOptions() const PURE; /** * @return bool indicating whether the hot restart functionality has been disabled via cli flags. diff --git a/include/envoy/server/overload_manager.h b/include/envoy/server/overload_manager.h new file mode 100644 index 0000000000000..40cf4cf3e969f --- /dev/null +++ b/include/envoy/server/overload_manager.h @@ -0,0 +1,46 @@ +#pragma once + +#include "envoy/common/pure.h" + +namespace Envoy { +namespace Server { + +enum class OverloadActionState { + /** + * Indicates that an overload action is active because at least one of its triggers has fired. + */ + Active, + /** + * Indicates that an overload action is inactive because none of its triggers have fired. + */ + Inactive +}; + +/** + * Callback invoked when an overload action changes state. + */ +typedef std::function OverloadActionCb; + +/** + * The OverloadManager protects the Envoy instance from being overwhelmed by client + * requests. It monitors a set of resources and notifies registered listeners if + * configured thresholds for those resources have been exceeded. + */ +class OverloadManager { +public: + virtual ~OverloadManager() {} + + /** + * Register a callback to be invoked when the specified overload action changes state + * (ie. becomes activated or inactivated). Must be called before the start method is called. + * @param action const std::string& the name of the overload action to register for + * @param dispatcher Event::Dispatcher& the dispatcher on which callbacks will be posted + * @param callback OverloadActionCb the callback to post when the overload action + * changes state + */ + virtual void registerForAction(const std::string& action, Event::Dispatcher& dispatcher, + OverloadActionCb callback) PURE; +}; + +} // namespace Server +} // namespace Envoy diff --git a/include/envoy/server/resource_monitor.h b/include/envoy/server/resource_monitor.h new file mode 100644 index 0000000000000..3fd01b52ac3b7 --- /dev/null +++ b/include/envoy/server/resource_monitor.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +namespace Envoy { +namespace Server { + +// Struct for reporting usage for a particular resource. +struct ResourceUsage { + // Fraction of (resource usage)/(resource limit). + double resource_pressure_; +}; + +class ResourceMonitor { +public: + virtual ~ResourceMonitor() {} + + /** + * Notifies caller of updated resource usage. + */ + class Callbacks { + public: + virtual ~Callbacks() {} + + /** + * Called when the request for updated resource usage succeeds. + * @param usage the updated resource usage + */ + virtual void onSuccess(const ResourceUsage& usage) PURE; + + /** + * Called when the request for updated resource usage fails. + * @param error the exception caught when trying to get updated resource usage + */ + virtual void onFailure(const EnvoyException& error) PURE; + }; + + /** + * Recalculate resource usage. + * This must be non-blocking so if RPCs need to be made they should be + * done asynchronously and invoke the callback when finished. + */ + virtual void updateResourceUsage(Callbacks& callbacks) PURE; +}; + +typedef std::unique_ptr ResourceMonitorPtr; + +} // namespace Server +} // namespace Envoy diff --git a/include/envoy/server/resource_monitor_config.h b/include/envoy/server/resource_monitor_config.h new file mode 100644 index 0000000000000..ceea1a685c9b1 --- /dev/null +++ b/include/envoy/server/resource_monitor_config.h @@ -0,0 +1,60 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/event/dispatcher.h" +#include "envoy/server/resource_monitor.h" + +#include "common/protobuf/protobuf.h" + +namespace Envoy { +namespace Server { +namespace Configuration { + +class ResourceMonitorFactoryContext { +public: + virtual ~ResourceMonitorFactoryContext() {} + + /** + * @return Event::Dispatcher& the main thread's dispatcher. This dispatcher should be used + * for all singleton processing. + */ + virtual Event::Dispatcher& dispatcher() PURE; +}; + +/** + * Implemented by each resource monitor and registered via Registry::registerFactory() + * or the convenience class RegistryFactory. + */ +class ResourceMonitorFactory { +public: + virtual ~ResourceMonitorFactory() {} + + /** + * Create a particular resource monitor implementation. + * @param config const ProtoBuf::Message& supplies the config for the resource monitor + * implementation. + * @param context ResourceMonitorFactoryContext& supplies the resource monitor's context. + * @return ResourceMonitorPtr the resource monitor instance. Should not be nullptr. + * @throw EnvoyException if the implementation is unable to produce an instance with + * the provided parameters. + */ + virtual ResourceMonitorPtr createResourceMonitor(const Protobuf::Message& config, + ResourceMonitorFactoryContext& context) PURE; + + /** + * @return ProtobufTypes::MessagePtr create empty config proto message. The resource monitor + * config, which arrives in an opaque google.protobuf.Struct message, will be converted + * to JSON and then parsed into this empty proto. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular implementation of a resource + * monitor produced by the factory. + */ + virtual std::string name() PURE; +}; + +} // namespace Configuration +} // namespace Server +} // namespace Envoy diff --git a/include/envoy/ssl/connection.h b/include/envoy/ssl/connection.h index cd8ebe55cf06a..98f8506497b8d 100644 --- a/include/envoy/ssl/connection.h +++ b/include/envoy/ssl/connection.h @@ -38,6 +38,12 @@ class Connection { */ virtual const std::string& sha256PeerCertificateDigest() const PURE; + /** + * @return std::string the serial number field of the peer certificate. Returns "" if + * there is no peer certificate, or no serial number. + **/ + virtual std::string serialNumberPeerCertificate() const PURE; + /** * @return std::string the subject field of the peer certificate in RFC 2253 format. Returns "" if * there is no peer certificate, or no subject. diff --git a/include/envoy/ssl/context.h b/include/envoy/ssl/context.h index 5af2fa804fb8a..b3d63bfe45b78 100644 --- a/include/envoy/ssl/context.h +++ b/include/envoy/ssl/context.h @@ -32,12 +32,13 @@ class Context { */ virtual std::string getCertChainInformation() const PURE; }; +typedef std::shared_ptr ContextSharedPtr; class ClientContext : public virtual Context {}; -typedef std::unique_ptr ClientContextPtr; +typedef std::shared_ptr ClientContextSharedPtr; class ServerContext : public virtual Context {}; -typedef std::unique_ptr ServerContextPtr; +typedef std::shared_ptr ServerContextSharedPtr; } // namespace Ssl } // namespace Envoy diff --git a/include/envoy/ssl/context_manager.h b/include/envoy/ssl/context_manager.h index 7489800c99caf..ea63ab9981f05 100644 --- a/include/envoy/ssl/context_manager.h +++ b/include/envoy/ssl/context_manager.h @@ -19,13 +19,13 @@ class ContextManager { /** * Builds a ClientContext from a ClientContextConfig. */ - virtual ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) PURE; + virtual ClientContextSharedPtr createSslClientContext(Stats::Scope& scope, + const ClientContextConfig& config) PURE; /** * Builds a ServerContext from a ServerContextConfig. */ - virtual ServerContextPtr + virtual ServerContextSharedPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) PURE; diff --git a/include/envoy/stats/stats.h b/include/envoy/stats/stats.h index 76e88253548d5..c39e41749699a 100644 --- a/include/envoy/stats/stats.h +++ b/include/envoy/stats/stats.h @@ -24,6 +24,43 @@ class Instance; namespace Stats { +/** + * Struct stored under Server::Options to hold information about the maximum object name length and + * maximum stat suffix length of a stat. These have defaults in StatsOptionsImpl, and the maximum + * object name length can be overridden. The default initialization is used in IsolatedStatImpl, and + * the user-overridden struct is stored in Options. + * + * As noted in the comment above StatsOptionsImpl in source/common/stats/stats_impl.h, a stat name + * often contains both a string whose length is user-defined (cluster_name in the below example), + * and a specific statistic name generated by Envoy. To make room for growth on both fronts, we + * limit the max allowed length of each separately. + * + * name / stat name + * |----------------------------------------------------------------| + * cluster..outlier_detection.ejections_consecutive_5xx + * |--------------------------------------| |-----------------------| + * object name suffix + */ +class StatsOptions { +public: + virtual ~StatsOptions() {} + + /** + * The max allowed length of a complete stat name, including suffix. + */ + virtual size_t maxNameLength() const PURE; + + /** + * The max allowed length of the object part of a stat name. + */ + virtual size_t maxObjNameLength() const PURE; + + /** + * The max allowed length of a stat suffix. + */ + virtual size_t maxStatSuffixLength() const PURE; +}; + /** * General representation of a tag. */ @@ -329,6 +366,12 @@ class Scope { * @return a histogram within the scope's namespace with a particular value type. */ virtual Histogram& histogram(const std::string& name) PURE; + + /** + * @return a reference to the top-level StatsOptions struct, containing information about the + * maximum allowable object name length and stat suffix length. + */ + virtual const Stats::StatsOptions& statsOptions() const PURE; }; /** @@ -424,7 +467,7 @@ class StatDataAllocator { * @return CounterSharedPtr a counter, or nullptr if allocation failed, in which case * tag_extracted_name and tags are not moved. */ - virtual CounterSharedPtr makeCounter(const std::string& name, std::string&& tag_extracted_name, + virtual CounterSharedPtr makeCounter(absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) PURE; /** @@ -434,9 +477,14 @@ class StatDataAllocator { * @return GaugeSharedPtr a gauge, or nullptr if allocation failed, in which case * tag_extracted_name and tags are not moved. */ - virtual GaugeSharedPtr makeGauge(const std::string& name, std::string&& tag_extracted_name, + virtual GaugeSharedPtr makeGauge(absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) PURE; + /** + * Determines whether this stats allocator requires bounded stat-name size. + */ + virtual bool requiresBoundedStatNameSize() const PURE; + // TODO(jmarantz): create a parallel mechanism to instantiate histograms. At // the moment, histograms don't fit the same pattern of counters and gaugaes // as they are not actually created in the context of a stats allocator. diff --git a/include/envoy/tcp/BUILD b/include/envoy/tcp/BUILD new file mode 100644 index 0000000000000..3716bc5cb4f64 --- /dev/null +++ b/include/envoy/tcp/BUILD @@ -0,0 +1,19 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "conn_pool_interface", + hdrs = ["conn_pool.h"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/event:deferred_deletable", + "//include/envoy/upstream:upstream_interface", + ], +) diff --git a/include/envoy/tcp/conn_pool.h b/include/envoy/tcp/conn_pool.h new file mode 100644 index 0000000000000..8237af37fea31 --- /dev/null +++ b/include/envoy/tcp/conn_pool.h @@ -0,0 +1,161 @@ +#pragma once + +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/pure.h" +#include "envoy/event/deferred_deletable.h" +#include "envoy/upstream/upstream.h" + +namespace Envoy { +namespace Tcp { +namespace ConnectionPool { + +/** + * Handle that allows a pending connection request to be canceled before it is completed. + */ +class Cancellable { +public: + virtual ~Cancellable() {} + + /** + * Cancel the pending request. + */ + virtual void cancel() PURE; +}; + +/** + * Reason that a pool connection could not be obtained. + */ +enum class PoolFailureReason { + // A resource overflowed and policy prevented a new connection from being created. + Overflow, + // A local connection failure took place while creating a new connection. + LocalConnectionFailure, + // A remote connection failure took place while creating a new connection. + RemoteConnectionFailure, + // A timeout occurred while creating a new connection. + Timeout, +}; + +/* + * UpstreamCallbacks for connection pool upstream connection callbacks and data. Note that + * onEvent(Connected) is never triggered since the event always occurs before a ConnectionPool + * caller is assigned a connection. + */ +class UpstreamCallbacks : public Network::ConnectionCallbacks { +public: + virtual ~UpstreamCallbacks() {} + + /* + * Invoked when data is delivered from the upstream connection while the connection is owned by a + * ConnectionPool::Instance caller. + * @param data supplies data from the upstream + * @param end_stream whether the data is the last data frame + */ + virtual void onUpstreamData(Buffer::Instance& data, bool end_stream) PURE; +}; + +/* + * ConnectionData wraps a ClientConnection allocated to a caller. Open ClientConnections are + * released back to the pool for re-use when their containing ConnectionData is destroyed. + */ +class ConnectionData { +public: + virtual ~ConnectionData() {} + + /** + * @return the ClientConnection for the connection. + */ + virtual Network::ClientConnection& connection() PURE; + + /** + * Sets the ConnectionPool::UpstreamCallbacks for the connection. If no callback is attached, + * data from the upstream will cause the connection to be closed. Callbacks cease when the + * connection is released. + * @param callback the UpstreamCallbacks to invoke for upstream data + */ + virtual void addUpstreamCallbacks(ConnectionPool::UpstreamCallbacks& callback) PURE; +}; + +typedef std::unique_ptr ConnectionDataPtr; + +/** + * Pool callbacks invoked in the context of a newConnection() call, either synchronously or + * asynchronously. + */ +class Callbacks { +public: + virtual ~Callbacks() {} + + /** + * Called when a pool error occurred and no connection could be acquired for making the request. + * @param reason supplies the failure reason. + * @param host supplies the description of the host that caused the failure. This may be nullptr + * if no host was involved in the failure (for example overflow). + */ + virtual void onPoolFailure(PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) PURE; + + /** + * Called when a connection is available to process a request/response. Connections may be + * released back to the pool for re-use by resetting the ConnectionDataPtr. If the connection is + * no longer viable for reuse (e.g. due to some kind of protocol error), the underlying + * ClientConnection should be closed to prevent its reuse. + * + * @param conn supplies the connection data to use. + * @param host supplies the description of the host that will carry the request. For logical + * connection pools the description may be different each time this is called. + */ + virtual void onPoolReady(ConnectionDataPtr&& conn, + Upstream::HostDescriptionConstSharedPtr host) PURE; +}; + +/** + * An instance of a generic connection pool. + */ +class Instance : public Event::DeferredDeletable { +public: + virtual ~Instance() {} + + /** + * Called when a connection pool has been drained of pending requests, busy connections, and + * ready connections. + */ + typedef std::function DrainedCb; + + /** + * Register a callback that gets called when the connection pool is fully drained. No actual + * draining is done. The owner of the connection pool is responsible for not creating any + * new connections. + */ + virtual void addDrainedCallback(DrainedCb cb) PURE; + + /** + * Actively drain all existing connection pool connections. This method can be used in cases + * where the connection pool is not being destroyed, but the caller wishes to make sure that + * all new requests take place on a new connection. For example, when a health check failure + * occurs. + */ + virtual void drainConnections() PURE; + + /** + * Create a new connection on the pool. + * @param cb supplies the callbacks to invoke when the connection is ready or has failed. The + * callbacks may be invoked immediately within the context of this call if there is a + * ready connection or an immediate failure. In this case, the routine returns nullptr. + * @return Cancellable* If no connection is ready, the callback is not invoked, and a handle + * is returned that can be used to cancel the request. Otherwise, one of the + * callbacks is called and the routine returns nullptr. NOTE: Once a callback + * is called, the handle is no longer valid and any further cancellation + * should be done by resetting the connection. + */ + virtual Cancellable* newConnection(Callbacks& callbacks) PURE; +}; + +typedef std::unique_ptr InstancePtr; + +} // namespace ConnectionPool +} // namespace Tcp +} // namespace Envoy diff --git a/include/envoy/upstream/BUILD b/include/envoy/upstream/BUILD index 60f3c7ba4bdc9..e59ddb8a392fa 100644 --- a/include/envoy/upstream/BUILD +++ b/include/envoy/upstream/BUILD @@ -12,6 +12,7 @@ envoy_cc_library( name = "cluster_manager_interface", hdrs = ["cluster_manager.h"], deps = [ + ":health_checker_interface", ":load_balancer_interface", ":thread_local_cluster_interface", ":upstream_interface", @@ -24,6 +25,7 @@ envoy_cc_library( "//include/envoy/runtime:runtime_interface", "//include/envoy/secret:secret_manager_interface", "//include/envoy/server:admin_interface", + "//include/envoy/tcp:conn_pool_interface", "@envoy_api//envoy/api/v2:cds_cc", "@envoy_api//envoy/config/bootstrap/v2:bootstrap_cc", ], @@ -32,7 +34,10 @@ envoy_cc_library( envoy_cc_library( name = "health_checker_interface", hdrs = ["health_checker.h"], - deps = [":upstream_interface"], + deps = [ + ":upstream_interface", + "@envoy_api//envoy/data/core/v2alpha:health_check_event_cc", + ], ) envoy_cc_library( @@ -101,6 +106,8 @@ envoy_cc_library( "//include/envoy/http:codec_interface", "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", + "//include/envoy/runtime:runtime_interface", "//include/envoy/ssl:context_interface", + "//include/envoy/ssl:context_manager_interface", ], ) diff --git a/include/envoy/upstream/cluster_manager.h b/include/envoy/upstream/cluster_manager.h index f2ec6e672953e..07c93b67e48a5 100644 --- a/include/envoy/upstream/cluster_manager.h +++ b/include/envoy/upstream/cluster_manager.h @@ -17,6 +17,9 @@ #include "envoy/runtime/runtime.h" #include "envoy/secret/secret_manager.h" #include "envoy/server/admin.h" +#include "envoy/ssl/context_manager.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/health_checker.h" #include "envoy/upstream/load_balancer.h" #include "envoy/upstream/thread_local_cluster.h" #include "envoy/upstream/upstream.h" @@ -119,6 +122,18 @@ class ClusterManager { Http::Protocol protocol, LoadBalancerContext* context) PURE; + /** + * Allocate a load balanced TCP connection pool for a cluster. This is *per-thread* so that + * callers do not need to worry about per thread synchronization. The load balancing policy that + * is used is the one defined on the cluster when it was created. + * + * Can return nullptr if there is no host available in the cluster or if the cluster does not + * exist. + */ + virtual Tcp::ConnectionPool::Instance* tcpConnPoolForCluster(const std::string& cluster, + ResourcePriority priority, + LoadBalancerContext* context) PURE; + /** * Allocate a load balanced TCP connection for a cluster. The created connection is already * bound to the correct *per-thread* dispatcher, so no further synchronization is needed. The @@ -248,12 +263,22 @@ class ClusterManagerFactory { ResourcePriority priority, Http::Protocol protocol, const Network::ConnectionSocket::OptionsSharedPtr& options) PURE; + /** + * Allocate a TCP connection pool for the host. Pools are separated by 'priority' and + * 'options->hashKey()', if any. + */ + virtual Tcp::ConnectionPool::InstancePtr + allocateTcpConnPool(Event::Dispatcher& dispatcher, HostConstSharedPtr host, + ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options) PURE; + /** * Allocate a cluster from configuration proto. */ virtual ClusterSharedPtr clusterFromProto(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, + AccessLog::AccessLogManager& log_manager, bool added_via_api) PURE; /** @@ -269,5 +294,31 @@ class ClusterManagerFactory { virtual Secret::SecretManager& secretManager() PURE; }; +/** + * Factory for creating ClusterInfo + */ +class ClusterInfoFactory { +public: + virtual ~ClusterInfoFactory() {} + + /** + * This method returns a Upstream::ClusterInfoConstSharedPtr + * + * @param runtime supplies the runtime loader. + * @param cluster supplies the owning cluster. + * @param bind_config supplies information on binding newly established connections. + * @param stats supplies a store for all known counters, gauges, and timers. + * @param ssl_context_manager supplies a manager for all SSL contexts. + * @param secret_manager supplies a manager for static secrets. + * @param added_via_api denotes whether this was added via API. + * @return Upstream::ClusterInfoConstSharedPtr + */ + virtual Upstream::ClusterInfoConstSharedPtr + createClusterInfo(Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, Secret::SecretManager& secret_manager, + bool added_via_api) PURE; +}; + } // namespace Upstream } // namespace Envoy diff --git a/include/envoy/upstream/health_checker.h b/include/envoy/upstream/health_checker.h index e0827eb80e845..e9fafe5cb69cb 100644 --- a/include/envoy/upstream/health_checker.h +++ b/include/envoy/upstream/health_checker.h @@ -3,6 +3,7 @@ #include #include +#include "envoy/data/core/v2alpha/health_check_event.pb.h" #include "envoy/upstream/upstream.h" namespace Envoy { @@ -59,5 +60,36 @@ typedef std::shared_ptr HealthCheckerSharedPtr; std::ostream& operator<<(std::ostream& out, HealthState state); std::ostream& operator<<(std::ostream& out, HealthTransition changed_state); +/** + * Sink for health check event logs. + */ +class HealthCheckEventLogger { +public: + virtual ~HealthCheckEventLogger() {} + + /** + * Log an unhealthy host ejection event. + * @param health_checker_type supplies the type of health checker that generated the event. + * @param host supplies the host that generated the event. + * @param failure_type supplies the type of health check failure + */ + virtual void + logEjectUnhealthy(envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, + envoy::data::core::v2alpha::HealthCheckFailureType failure_type) PURE; + + /** + * Log a healthy host addition event. + * @param health_checker_type supplies the type of health checker that generated the event. + * @param host supplies the host that generated the event. + * @param healthy_threshold supplied the configured healthy threshold for this health check + * @param first_check whether this is a fast path success on the first health check for this host + */ + virtual void logAddHealthy(envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, bool first_check) PURE; +}; + +typedef std::unique_ptr HealthCheckEventLoggerPtr; + } // namespace Upstream } // namespace Envoy diff --git a/include/envoy/upstream/host_description.h b/include/envoy/upstream/host_description.h index aff175735b9af..9f5eb67f7ba20 100644 --- a/include/envoy/upstream/host_description.h +++ b/include/envoy/upstream/host_description.h @@ -52,10 +52,20 @@ class HostDescription { */ virtual bool canary() const PURE; + /** + * Update the canary status of the host. + */ + virtual void canary(bool is_canary) PURE; + /** * @return the metadata associated with this host */ - virtual const envoy::api::v2::core::Metadata& metadata() const PURE; + virtual const std::shared_ptr metadata() const PURE; + + /** + * Set the current metadata. + */ + virtual void metadata(const envoy::api::v2::core::Metadata& new_metadata) PURE; /** * @return the cluster the host is a member of. diff --git a/include/envoy/upstream/load_balancer_type.h b/include/envoy/upstream/load_balancer_type.h index 5a8e21828d1b4..cc2e1b3029e59 100644 --- a/include/envoy/upstream/load_balancer_type.h +++ b/include/envoy/upstream/load_balancer_type.h @@ -46,6 +46,11 @@ class LoadBalancerSubsetInfo { * sorted keys used to define load balancer subsets. */ virtual const std::vector>& subsetKeys() const PURE; + + /* + * @return bool whether routing to subsets should take locality weights into account. + */ + virtual bool localityWeightAware() const PURE; }; } // namespace Upstream diff --git a/source/common/access_log/BUILD b/source/common/access_log/BUILD index 40896979eae9d..afbca0e51e2cf 100644 --- a/source/common/access_log/BUILD +++ b/source/common/access_log/BUILD @@ -51,7 +51,10 @@ envoy_cc_library( "//source/common/http:header_utility_lib", "//source/common/http:headers_lib", "//source/common/http:utility_lib", + "//source/common/protobuf:utility_lib", + "//source/common/request_info:request_info_lib", "//source/common/runtime:uuid_util_lib", "//source/common/tracing:http_tracer_lib", + "@envoy_api//envoy/config/filter/accesslog/v2:accesslog_cc", ], ) diff --git a/source/common/access_log/access_log_formatter.cc b/source/common/access_log/access_log_formatter.cc index 82f8205200565..88ec956d9d558 100644 --- a/source/common/access_log/access_log_formatter.cc +++ b/source/common/access_log/access_log_formatter.cc @@ -35,14 +35,16 @@ FormatterPtr AccessLogFormatUtils::defaultAccessLogFormatter() { std::string AccessLogFormatUtils::durationToString(const absl::optional& time) { if (time) { - return fmt::FormatInt( - std::chrono::duration_cast(time.value()).count()) - .str(); + return durationToString(time.value()); } else { return UnspecifiedValueString; } } +std::string AccessLogFormatUtils::durationToString(const std::chrono::nanoseconds& time) { + return fmt::FormatInt(std::chrono::duration_cast(time).count()).str(); +} + const std::string& AccessLogFormatUtils::protocolToString(const absl::optional& protocol) { if (protocol) { @@ -221,6 +223,18 @@ RequestInfoFormatter::RequestInfoFormatter(const std::string& field_name) { field_extractor_ = [](const RequestInfo::RequestInfo& request_info) { return AccessLogFormatUtils::durationToString(request_info.firstUpstreamRxByteReceived()); }; + } else if (field_name == "RESPONSE_TX_DURATION") { + field_extractor_ = [](const RequestInfo::RequestInfo& request_info) { + auto downstream = request_info.lastDownstreamTxByteSent(); + auto upstream = request_info.firstUpstreamRxByteReceived(); + + if (downstream && upstream) { + auto val = downstream.value() - upstream.value(); + return AccessLogFormatUtils::durationToString(val); + } + + return UnspecifiedValueString; + }; } else if (field_name == "BYTES_RECEIVED") { field_extractor_ = [](const RequestInfo::RequestInfo& request_info) { return fmt::FormatInt(request_info.bytesReceived()).str(); @@ -388,7 +402,7 @@ std::string MetadataFormatter::format(const envoy::api::v2::core::Metadata& meta } ProtobufTypes::String json; const auto status = Protobuf::util::MessageToJsonString(*data, &json); - RELEASE_ASSERT(status.ok()); + RELEASE_ASSERT(status.ok(), ""); if (max_length_ && json.length() > max_length_.value()) { return json.substr(0, max_length_.value()); } diff --git a/source/common/access_log/access_log_formatter.h b/source/common/access_log/access_log_formatter.h index b0a1246345c06..b6cb0a8a2775d 100644 --- a/source/common/access_log/access_log_formatter.h +++ b/source/common/access_log/access_log_formatter.h @@ -73,6 +73,7 @@ class AccessLogFormatUtils { static FormatterPtr defaultAccessLogFormatter(); static const std::string& protocolToString(const absl::optional& protocol); static std::string durationToString(const absl::optional& time); + static std::string durationToString(const std::chrono::nanoseconds& time); private: AccessLogFormatUtils(); diff --git a/source/common/access_log/access_log_impl.cc b/source/common/access_log/access_log_impl.cc index 74f0f5b677fdb..163b695cf0253 100644 --- a/source/common/access_log/access_log_impl.cc +++ b/source/common/access_log/access_log_impl.cc @@ -4,6 +4,7 @@ #include #include "envoy/common/time.h" +#include "envoy/config/filter/accesslog/v2/accesslog.pb.validate.h" #include "envoy/filesystem/filesystem.h" #include "envoy/http/header_map.h" #include "envoy/runtime/runtime.h" @@ -17,6 +18,8 @@ #include "common/http/header_utility.h" #include "common/http/headers.h" #include "common/http/utility.h" +#include "common/protobuf/utility.h" +#include "common/request_info/utility.h" #include "common/runtime/uuid_util.h" #include "common/tracing/http_tracer_impl.h" @@ -44,7 +47,7 @@ bool ComparisonFilter::compareAgainstValue(uint64_t lhs) { case envoy::config::filter::accesslog::v2::ComparisonFilter::LE: return lhs <= value; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -68,8 +71,11 @@ FilterFactory::fromProto(const envoy::config::filter::accesslog::v2::AccessLogFi return FilterPtr{new OrFilter(config.or_filter(), runtime, random)}; case envoy::config::filter::accesslog::v2::AccessLogFilter::kHeaderFilter: return FilterPtr{new HeaderFilter(config.header_filter())}; + case envoy::config::filter::accesslog::v2::AccessLogFilter::kResponseFlagFilter: + MessageUtil::validate(config); + return FilterPtr{new ResponseFlagFilter(config.response_flag_filter())}; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -174,6 +180,24 @@ bool HeaderFilter::evaluate(const RequestInfo::RequestInfo&, return Http::HeaderUtility::matchHeaders(request_headers, header_data_); } +ResponseFlagFilter::ResponseFlagFilter( + const envoy::config::filter::accesslog::v2::ResponseFlagFilter& config) { + for (int i = 0; i < config.flags_size(); i++) { + absl::optional response_flag = + RequestInfo::ResponseFlagUtils::toResponseFlag(config.flags(i)); + // The config has been validated. Therefore, every flag in the config will have a mapping. + ASSERT(response_flag.has_value()); + configured_flags_ |= response_flag.value(); + } +} + +bool ResponseFlagFilter::evaluate(const RequestInfo::RequestInfo& info, const Http::HeaderMap&) { + if (configured_flags_ != 0) { + return info.intersectResponseFlags(configured_flags_); + } + return info.hasAnyResponseFlag(); +} + InstanceSharedPtr AccessLogFactory::fromProto(const envoy::config::filter::accesslog::v2::AccessLog& config, Server::Configuration::FactoryContext& context) { diff --git a/source/common/access_log/access_log_impl.h b/source/common/access_log/access_log_impl.h index 63c22735f996a..cad54fd633012 100644 --- a/source/common/access_log/access_log_impl.h +++ b/source/common/access_log/access_log_impl.h @@ -6,7 +6,6 @@ #include "envoy/access_log/access_log.h" #include "envoy/config/filter/accesslog/v2/accesslog.pb.h" -#include "envoy/request_info/request_info.h" #include "envoy/runtime/runtime.h" #include "envoy/server/access_log_config.h" @@ -166,6 +165,21 @@ class HeaderFilter : public Filter { std::vector header_data_; }; +/** + * Filter requests that had a response with an Envoy response flag set. + */ +class ResponseFlagFilter : public Filter { +public: + ResponseFlagFilter(const envoy::config::filter::accesslog::v2::ResponseFlagFilter& config); + + // AccessLog::Filter + bool evaluate(const RequestInfo::RequestInfo& info, + const Http::HeaderMap& request_headers) override; + +private: + uint64_t configured_flags_{}; +}; + /** * Access log factory that reads the configuration from proto. */ diff --git a/source/common/api/os_sys_calls_impl.cc b/source/common/api/os_sys_calls_impl.cc index 7012290927660..80cfd24bf602c 100644 --- a/source/common/api/os_sys_calls_impl.cc +++ b/source/common/api/os_sys_calls_impl.cc @@ -11,6 +11,10 @@ int OsSysCallsImpl::bind(int sockfd, const sockaddr* addr, socklen_t addrlen) { return ::bind(sockfd, addr, addrlen); } +int OsSysCallsImpl::ioctl(int sockfd, unsigned long int request, void* argp) { + return ::ioctl(sockfd, request, argp); +} + int OsSysCallsImpl::open(const std::string& full_path, int flags, int mode) { return ::open(full_path.c_str(), flags, mode); } @@ -57,5 +61,9 @@ int OsSysCallsImpl::getsockopt(int sockfd, int level, int optname, void* optval, return ::getsockopt(sockfd, level, optname, optval, optlen); } +int OsSysCallsImpl::socket(int domain, int type, int protocol) { + return ::socket(domain, type, protocol); +} + } // namespace Api } // namespace Envoy diff --git a/source/common/api/os_sys_calls_impl.h b/source/common/api/os_sys_calls_impl.h index d1985622615d5..db325862367df 100644 --- a/source/common/api/os_sys_calls_impl.h +++ b/source/common/api/os_sys_calls_impl.h @@ -11,6 +11,7 @@ class OsSysCallsImpl : public OsSysCalls { public: // Api::OsSysCalls int bind(int sockfd, const sockaddr* addr, socklen_t addrlen) override; + int ioctl(int sockfd, unsigned long int request, void* argp) override; int open(const std::string& full_path, int flags, int mode) override; ssize_t write(int fd, const void* buffer, size_t num_bytes) override; ssize_t writev(int fd, const iovec* iovec, int num_iovec) override; @@ -24,6 +25,7 @@ class OsSysCallsImpl : public OsSysCalls { int stat(const char* pathname, struct stat* buf) override; int setsockopt(int sockfd, int level, int optname, const void* optval, socklen_t optlen) override; int getsockopt(int sockfd, int level, int optname, void* optval, socklen_t* optlen) override; + int socket(int domain, int type, int protocol) override; }; typedef ThreadSafeSingleton OsSysCallsSingleton; diff --git a/source/common/buffer/buffer_impl.cc b/source/common/buffer/buffer_impl.cc index f7bdfcd12aa46..888d87077376f 100644 --- a/source/common/buffer/buffer_impl.cc +++ b/source/common/buffer/buffer_impl.cc @@ -94,9 +94,9 @@ void OwnedImpl::move(Instance& rhs, uint64_t length) { static_cast(rhs).postProcess(); } -int OwnedImpl::read(int fd, uint64_t max_length) { +Api::SysCallResult OwnedImpl::read(int fd, uint64_t max_length) { if (max_length == 0) { - return 0; + return {0, 0}; } constexpr uint64_t MaxSlices = 2; RawSlice slices[MaxSlices]; @@ -115,8 +115,9 @@ int OwnedImpl::read(int fd, uint64_t max_length) { ASSERT(num_bytes_to_read <= max_length); auto& os_syscalls = Api::OsSysCallsSingleton::get(); const ssize_t rc = os_syscalls.readv(fd, iov, static_cast(num_slices_to_read)); + const int error = errno; if (rc < 0) { - return rc; + return {static_cast(rc), error}; } uint64_t num_slices_to_commit = 0; uint64_t bytes_to_commit = rc; @@ -130,7 +131,7 @@ int OwnedImpl::read(int fd, uint64_t max_length) { } ASSERT(num_slices_to_commit <= num_slices); commit(slices, num_slices_to_commit); - return rc; + return {static_cast(rc), error}; } uint64_t OwnedImpl::reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) { @@ -151,7 +152,7 @@ ssize_t OwnedImpl::search(const void* data, uint64_t size, size_t start) const { return result_ptr.pos; } -int OwnedImpl::write(int fd) { +Api::SysCallResult OwnedImpl::write(int fd) { constexpr uint64_t MaxSlices = 16; RawSlice slices[MaxSlices]; const uint64_t num_slices = std::min(getRawSlices(slices, MaxSlices), MaxSlices); @@ -165,14 +166,15 @@ int OwnedImpl::write(int fd) { } } if (num_slices_to_write == 0) { - return 0; + return {0, 0}; } auto& os_syscalls = Api::OsSysCallsSingleton::get(); const ssize_t rc = os_syscalls.writev(fd, iov, num_slices_to_write); + const int error = errno; if (rc > 0) { drain(static_cast(rc)); } - return static_cast(rc); + return {static_cast(rc), error}; } OwnedImpl::OwnedImpl() : buffer_(evbuffer_new()) {} diff --git a/source/common/buffer/buffer_impl.h b/source/common/buffer/buffer_impl.h index 4d9a2f4aa7b6f..993f4990405d8 100644 --- a/source/common/buffer/buffer_impl.h +++ b/source/common/buffer/buffer_impl.h @@ -80,17 +80,12 @@ class OwnedImpl : public LibEventInstance { void* linearize(uint32_t size) override; void move(Instance& rhs) override; void move(Instance& rhs, uint64_t length) override; - int read(int fd, uint64_t max_length) override; + Api::SysCallResult read(int fd, uint64_t max_length) override; uint64_t reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) override; ssize_t search(const void* data, uint64_t size, size_t start) const override; - int write(int fd) override; + Api::SysCallResult write(int fd) override; void postProcess() override {} - - /** - * Construct a flattened string from a buffer. - * @return the flattened string. - */ - std::string toString() const; + std::string toString() const override; Event::Libevent::BufferPtr& buffer() override { return buffer_; } diff --git a/source/common/buffer/watermark_buffer.cc b/source/common/buffer/watermark_buffer.cc index 9eb32b1815ee7..fe2c1981e54f0 100644 --- a/source/common/buffer/watermark_buffer.cc +++ b/source/common/buffer/watermark_buffer.cc @@ -40,10 +40,10 @@ void WatermarkBuffer::move(Instance& rhs, uint64_t length) { checkHighWatermark(); } -int WatermarkBuffer::read(int fd, uint64_t max_length) { - int bytes_read = OwnedImpl::read(fd, max_length); +Api::SysCallResult WatermarkBuffer::read(int fd, uint64_t max_length) { + Api::SysCallResult result = OwnedImpl::read(fd, max_length); checkHighWatermark(); - return bytes_read; + return result; } uint64_t WatermarkBuffer::reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) { @@ -52,10 +52,10 @@ uint64_t WatermarkBuffer::reserve(uint64_t length, RawSlice* iovecs, uint64_t nu return bytes_reserved; } -int WatermarkBuffer::write(int fd) { - int bytes_written = OwnedImpl::write(fd); +Api::SysCallResult WatermarkBuffer::write(int fd) { + Api::SysCallResult result = OwnedImpl::write(fd); checkLowWatermark(); - return bytes_written; + return result; } void WatermarkBuffer::setWatermarks(uint32_t low_watermark, uint32_t high_watermark) { diff --git a/source/common/buffer/watermark_buffer.h b/source/common/buffer/watermark_buffer.h index 5be55409ef1e4..fb74ccde04f4c 100644 --- a/source/common/buffer/watermark_buffer.h +++ b/source/common/buffer/watermark_buffer.h @@ -28,9 +28,9 @@ class WatermarkBuffer : public OwnedImpl { void drain(uint64_t size) override; void move(Instance& rhs) override; void move(Instance& rhs, uint64_t length) override; - int read(int fd, uint64_t max_length) override; + Api::SysCallResult read(int fd, uint64_t max_length) override; uint64_t reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) override; - int write(int fd) override; + Api::SysCallResult write(int fd) override; void postProcess() override { checkLowWatermark(); } void setWatermarks(uint32_t watermark) { setWatermarks(watermark / 2, watermark); } diff --git a/source/common/buffer/zero_copy_input_stream_impl.cc b/source/common/buffer/zero_copy_input_stream_impl.cc index f6030f69eec72..9159045b5c332 100644 --- a/source/common/buffer/zero_copy_input_stream_impl.cc +++ b/source/common/buffer/zero_copy_input_stream_impl.cc @@ -44,7 +44,7 @@ bool ZeroCopyInputStreamImpl::Next(const void** data, int* size) { return false; } -bool ZeroCopyInputStreamImpl::Skip(int) { NOT_IMPLEMENTED; } +bool ZeroCopyInputStreamImpl::Skip(int) { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void ZeroCopyInputStreamImpl::BackUp(int count) { ASSERT(count >= 0); diff --git a/source/common/common/BUILD b/source/common/common/BUILD index fc0e43d110f1b..b921a9d154449 100644 --- a/source/common/common/BUILD +++ b/source/common/common/BUILD @@ -16,6 +16,17 @@ envoy_cc_library( deps = [":minimal_logger_lib"], ) +envoy_cc_library( + name = "backoff_lib", + srcs = ["backoff_strategy.cc"], + hdrs = ["backoff_strategy.h"], + deps = [ + ":assert_lib", + "//include/envoy/common:backoff_strategy_interface", + "//include/envoy/runtime:runtime_interface", + ], +) + envoy_cc_library( name = "base64_lib", srcs = ["base64.cc"], @@ -113,6 +124,21 @@ envoy_cc_library( hdrs = ["macros.h"], ) +envoy_cc_library( + name = "matchers_lib", + srcs = ["matchers.cc"], + hdrs = ["matchers.h"], + external_deps = ["abseil_optional"], + deps = [ + ":utility_lib", + "//source/common/config:metadata_lib", + "//source/common/protobuf", + "@envoy_api//envoy/type/matcher:metadata_cc", + "@envoy_api//envoy/type/matcher:number_cc", + "@envoy_api//envoy/type/matcher:string_cc", + ], +) + envoy_cc_library( name = "non_copyable", hdrs = ["non_copyable.h"], @@ -209,6 +235,8 @@ envoy_cc_library( deps = [ ":assert_lib", ":logger_lib", + "//include/envoy/stats:stats_interface", + "//source/common/stats:stats_lib", ], ) diff --git a/source/common/common/assert.h b/source/common/common/assert.h index 64ca60856e20e..9d4f4c5275215 100644 --- a/source/common/common/assert.h +++ b/source/common/common/assert.h @@ -3,21 +3,33 @@ #include "common/common/logger.h" namespace Envoy { + /** * assert macro that uses our builtin logging which gives us thread ID and can log to various * sinks. + * + * The old style release assert was of the form RELEASE_ASSERT(foo == bar); + * where it would log stack traces and the failed conditional and crash if the + * condition is not met. The are many legacy RELEASE_ASSERTS in Envoy which + * were converted to RELEASE_ASSERT(foo == bar, ""); + * + * The new style of release assert is of the form + * RELEASE_ASSERT(foo == bar, "reason foo should actually be bar"); + * new uses of RELEASE_ASSERT should supply a verbose explanation of what went wrong. */ -#define RELEASE_ASSERT(X) \ +#define RELEASE_ASSERT(X, DETAILS) \ do { \ if (!(X)) { \ + const std::string& details = (DETAILS); \ ENVOY_LOG_TO_LOGGER(Envoy::Logger::Registry::getLog(Envoy::Logger::Id::assert), critical, \ - "assert failure: {}", #X); \ + "assert failure: {}.{}{}", #X, \ + details.empty() ? "" : " Details: ", details); \ abort(); \ } \ } while (false) #ifndef NDEBUG -#define ASSERT(X) RELEASE_ASSERT(X) +#define ASSERT(X) RELEASE_ASSERT(X, "") #else // This non-implementation ensures that its argument is a valid expression that can be statically // casted to a bool, but the expression is never evaluated and will be compiled away. @@ -36,10 +48,14 @@ namespace Envoy { "panic: {}", X); \ abort(); -#define NOT_IMPLEMENTED PANIC("not implemented") +// NOT_IMPLEMENTED_GCOVR_EXCL_LINE is for overridden functions that are expressly not implemented. +// The macro name includes "GCOVR_EXCL_LINE" to exclude the macro's usage from code coverage +// reports. +#define NOT_IMPLEMENTED_GCOVR_EXCL_LINE PANIC("not implemented") -// NOT_REACHED is for spots the compiler insists on having a return, but where we know that it -// shouldn't be possible to arrive there, assuming no horrendous bugs. For example, after a -// switch (some_enum) with all enum values included in the cases. -#define NOT_REACHED PANIC("not reached") +// NOT_REACHED_GCOVR_EXCL_LINE is for spots the compiler insists on having a return, but where we +// know that it shouldn't be possible to arrive there, assuming no horrendous bugs. For example, +// after a switch (some_enum) with all enum values included in the cases. The macro name includes +// "GCOVR_EXCL_LINE" to exclude the macro's usage from code coverage reports. +#define NOT_REACHED_GCOVR_EXCL_LINE PANIC("not reached") } // Envoy diff --git a/source/common/common/backoff_strategy.cc b/source/common/common/backoff_strategy.cc new file mode 100644 index 0000000000000..d8a21be12c801 --- /dev/null +++ b/source/common/common/backoff_strategy.cc @@ -0,0 +1,22 @@ +#include "common/common/backoff_strategy.h" + +namespace Envoy { + +JitteredBackOffStrategy::JitteredBackOffStrategy(uint64_t base_interval, uint64_t max_interval, + Runtime::RandomGenerator& random) + : base_interval_(base_interval), max_interval_(max_interval), random_(random) { + ASSERT(base_interval_ <= max_interval_); +} + +uint64_t JitteredBackOffStrategy::nextBackOffMs() { + const uint64_t multiplier = (1 << current_retry_) - 1; + const uint64_t base_backoff = multiplier * base_interval_; + if (base_backoff <= max_interval_) { + current_retry_++; + } + return std::min(random_.random() % base_backoff, max_interval_); +} + +void JitteredBackOffStrategy::reset() { current_retry_ = 1; } + +} // namespace Envoy \ No newline at end of file diff --git a/source/common/common/backoff_strategy.h b/source/common/common/backoff_strategy.h new file mode 100644 index 0000000000000..787320ecb5a98 --- /dev/null +++ b/source/common/common/backoff_strategy.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +#include "envoy/common/backoff_strategy.h" +#include "envoy/runtime/runtime.h" + +#include "common/common/assert.h" + +namespace Envoy { + +/** + * Implementation of BackOffStrategy that uses a fully jittered exponential backoff algorithm. + */ +class JitteredBackOffStrategy : public BackOffStrategy { + +public: + /** + * Constructs fully jittered backoff strategy. + * @param base_interval the base_interval to be used for next backoff computation. + * @param max_interval the cap on the next backoff value. + * @param random the random generator + */ + JitteredBackOffStrategy(uint64_t base_interval, uint64_t max_interval, + Runtime::RandomGenerator& random); + + // BackOffStrategy methods + uint64_t nextBackOffMs() override; + void reset() override; + +private: + const uint64_t base_interval_; + const uint64_t max_interval_{}; + uint64_t current_retry_{1}; + Runtime::RandomGenerator& random_; +}; +} // namespace Envoy diff --git a/source/common/common/block_memory_hash_set.h b/source/common/common/block_memory_hash_set.h index 6677e8a6d37a9..5724e527067b0 100644 --- a/source/common/common/block_memory_hash_set.h +++ b/source/common/common/block_memory_hash_set.h @@ -5,10 +5,12 @@ #include #include "envoy/common/exception.h" +#include "envoy/stats/stats.h" #include "common/common/assert.h" #include "common/common/fmt.h" #include "common/common/logger.h" +#include "common/stats/stats_impl.h" #include "absl/strings/string_view.h" @@ -57,21 +59,24 @@ template class BlockMemoryHashSet : public Logger::Loggable class BlockMemoryHashSet : public Logger::Loggableoptions); } - - /** - * Returns the options structure that was used to construct the set. - */ - const BlockMemoryHashSetOptions& options() const { return control_->options; } + uint64_t numBytes(const Stats::StatsOptions& stats_options) const { + return numBytes(control_->hash_set_options, stats_options); + } /** Examines the data structures to see if they are sane, assert-failing on any trouble. */ void sanityCheck() { - RELEASE_ASSERT(control_->size <= control_->options.capacity); + RELEASE_ASSERT(control_->size <= control_->hash_set_options.capacity, ""); // As a sanity check, make sure there are control_->size values // reachable from the slots, each of which has a valid @@ -107,22 +110,22 @@ template class BlockMemoryHashSet : public Logger::Loggableoptions.num_slots; ++slot) { + for (uint32_t slot = 0; slot < control_->hash_set_options.num_slots; ++slot) { uint32_t next = 0; // initialized to silence compilers. for (uint32_t cell_index = slots_[slot]; (cell_index != Sentinal) && (num_values <= control_->size); cell_index = next) { - RELEASE_ASSERT(cell_index < control_->options.capacity); + RELEASE_ASSERT(cell_index < control_->hash_set_options.capacity, ""); Cell& cell = getCell(cell_index); absl::string_view key = cell.value.key(); - RELEASE_ASSERT(computeSlot(key) == slot); + RELEASE_ASSERT(computeSlot(key) == slot, ""); next = cell.next_cell_index; ++num_values; } } - RELEASE_ASSERT(num_values == control_->size); + RELEASE_ASSERT(num_values == control_->size, ""); uint32_t num_free_entries = 0; - uint32_t expected_free_entries = control_->options.capacity - control_->size; + uint32_t expected_free_entries = control_->hash_set_options.capacity - control_->size; // Don't infinite-loop with a corruption; break when we see there's a problem. for (uint32_t cell_index = control_->free_cell_index; @@ -130,7 +133,7 @@ template class BlockMemoryHashSet : public Logger::Loggable class BlockMemoryHashSet : public Logger::Loggablesize >= control_->options.capacity) { + if (control_->size >= control_->hash_set_options.capacity) { return ValueCreatedPair(nullptr, false); } const uint32_t slot = computeSlot(key); @@ -162,7 +165,7 @@ template class BlockMemoryHashSet : public Logger::Loggableinitialize(key); + value->initialize(key, stats_options_); ++control_->size; return ValueCreatedPair(value, true); } @@ -215,9 +218,9 @@ template class BlockMemoryHashSet : public Logger::Loggableoptions.toString(), - control_->hash_signature, numBytes()); + std::string version(const Stats::StatsOptions& stats_options) { + return fmt::format("options={} hash={} size={}", control_->hash_set_options.toString(), + control_->hash_signature, numBytes(stats_options)); } private: @@ -228,20 +231,20 @@ template class BlockMemoryHashSet : public Logger::Loggablehash_signature = Value::hash(signatureStringToHash()); - control_->num_bytes = numBytes(options); - control_->options = options; + control_->num_bytes = numBytes(hash_set_options, stats_options_); + control_->hash_set_options = hash_set_options; control_->size = 0; control_->free_cell_index = 0; // Initialize all the slots; - for (uint32_t slot = 0; slot < options.num_slots; ++slot) { + for (uint32_t slot = 0; slot < hash_set_options.num_slots; ++slot) { slots_[slot] = Sentinal; } // Initialize the free-cell list. - const uint32_t last_cell = options.capacity - 1; + const uint32_t last_cell = hash_set_options.capacity - 1; for (uint32_t cell_index = 0; cell_index < last_cell; ++cell_index) { Cell& cell = getCell(cell_index); cell.next_cell_index = cell_index + 1; @@ -254,10 +257,10 @@ template class BlockMemoryHashSet : public Logger::Loggablenum_bytes) { - ENVOY_LOG(error, "BlockMemoryHashSet unexpected memory size {} != {}", numBytes(options), - control_->num_bytes); + bool attach(const BlockMemoryHashSetOptions& hash_set_options) { + if (numBytes(hash_set_options, stats_options_) != control_->num_bytes) { + ENVOY_LOG(error, "BlockMemoryHashSet unexpected memory size {} != {}", + numBytes(hash_set_options, stats_options_), control_->num_bytes); return false; } if (Value::hash(signatureStringToHash()) != control_->hash_signature) { @@ -269,7 +272,7 @@ template class BlockMemoryHashSet : public Logger::Loggableoptions.num_slots; + return Value::hash(key) % control_->hash_set_options.num_slots; } /** @@ -290,11 +293,11 @@ template class BlockMemoryHashSet : public Logger::Loggable class BlockMemoryHashSet : public Logger::Loggable 0) && ((alignment & (alignment - 1)) == 0)); + RELEASE_ASSERT((alignment > 0) && ((alignment & (alignment - 1)) == 0), ""); return (size + alignment - 1) & ~(alignment - 1); } @@ -323,10 +326,11 @@ template class BlockMemoryHashSet : public Logger::Loggable class BlockMemoryHashSet : public Logger::Loggable(cells_) + cellOffset(cell_index); - RELEASE_ASSERT((reinterpret_cast(ptr) & (calculateAlignment() - 1)) == 0); + char* ptr = reinterpret_cast(cells_) + cellOffset(cell_index, stats_options_); + RELEASE_ASSERT((reinterpret_cast(ptr) & (calculateAlignment() - 1)) == 0, ""); return *reinterpret_cast(ptr); } /** Maps out the segments of memory for us to work with. */ - void mapMemorySegments(const BlockMemoryHashSetOptions& options, uint8_t* memory) { + void mapMemorySegments(const BlockMemoryHashSetOptions& hash_set_options, uint8_t* memory) { // Note that we are not examining or mutating memory here, just looking at the pointer, // so we don't need to hold any locks. cells_ = reinterpret_cast(memory); // First because Value may need to be aligned. - memory += cellOffset(options.capacity); + memory += cellOffset(hash_set_options.capacity, stats_options_); control_ = reinterpret_cast(memory); memory += sizeof(Control); slots_ = reinterpret_cast(memory); @@ -356,6 +360,7 @@ template class BlockMemoryHashSet : public Logger::Loggablematch(value.number_value()); + case ProtobufWkt::Value::kStringValue: + return string_matcher_.has_value() && string_matcher_->match(value.string_value()); + case ProtobufWkt::Value::kBoolValue: + return (bool_matcher_.has_value() && *bool_matcher_ == value.bool_value()); + case ProtobufWkt::Value::kListValue: + case ProtobufWkt::Value::kStructValue: + case ProtobufWkt::Value::KIND_NOT_SET: + return false; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +} // namespace Matchers +} // namespace Envoy diff --git a/source/common/common/matchers.h b/source/common/common/matchers.h new file mode 100644 index 0000000000000..36197b7f7c1d6 --- /dev/null +++ b/source/common/common/matchers.h @@ -0,0 +1,76 @@ +#pragma once + +#include + +#include "envoy/api/v2/core/base.pb.h" +#include "envoy/type/matcher/metadata.pb.h" +#include "envoy/type/matcher/number.pb.h" +#include "envoy/type/matcher/string.pb.h" + +#include "common/common/utility.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Matchers { + +class DoubleMatcher { +public: + DoubleMatcher(const envoy::type::matcher::DoubleMatcher& matcher) : matcher_(matcher) {} + + /** + * Check whether the value is matched to the matcher. + * @param value the double value to check. + * @return true if it's matched otherwise false. + */ + bool match(double value) const; + +private: + const envoy::type::matcher::DoubleMatcher matcher_; +}; + +class StringMatcher { +public: + StringMatcher(const envoy::type::matcher::StringMatcher& matcher) : matcher_(matcher) { + if (matcher.match_pattern_case() == envoy::type::matcher::StringMatcher::kRegex) { + regex_ = RegexUtil::parseRegex(matcher_.regex()); + } + } + + /** + * Check whether the value is matched to the matcher. + * @param value the string to check. + * @return true if it's matched otherwise false. + */ + bool match(const std::string& value) const; + +private: + const envoy::type::matcher::StringMatcher matcher_; + std::regex regex_; +}; + +class MetadataMatcher { +public: + MetadataMatcher(const envoy::type::matcher::MetadataMatcher& matcher); + + /** + * Check whether the metadata is matched to the matcher. + * @param metadata the metadata to check. + * @return true if it's matched otherwise false. + */ + bool match(const envoy::api::v2::core::Metadata& metadata) const; + +private: + const envoy::type::matcher::MetadataMatcher matcher_; + std::vector path_; + + bool null_matcher_{false}; + absl::optional bool_matcher_; + bool present_matcher_{false}; + + absl::optional double_matcher_; + absl::optional string_matcher_; +}; + +} // namespace Matchers +} // namespace Envoy diff --git a/source/common/common/thread.cc b/source/common/common/thread.cc index 5dae2f0f198f3..953e280642447 100644 --- a/source/common/common/thread.cc +++ b/source/common/common/thread.cc @@ -15,14 +15,14 @@ namespace Envoy { namespace Thread { Thread::Thread(std::function thread_routine) : thread_routine_(thread_routine) { - RELEASE_ASSERT(Logger::Registry::initialized()); + RELEASE_ASSERT(Logger::Registry::initialized(), ""); int rc = pthread_create(&thread_id_, nullptr, [](void* arg) -> void* { static_cast(arg)->thread_routine_(); return nullptr; }, this); - RELEASE_ASSERT(rc == 0); + RELEASE_ASSERT(rc == 0, ""); } int32_t Thread::currentThreadId() { @@ -39,7 +39,7 @@ int32_t Thread::currentThreadId() { void Thread::join() { int rc = pthread_join(thread_id_, nullptr); - RELEASE_ASSERT(rc == 0); + RELEASE_ASSERT(rc == 0, ""); } } // namespace Thread diff --git a/source/common/compressor/zlib_compressor_impl.cc b/source/common/compressor/zlib_compressor_impl.cc index 65a4ab045baac..b427fb91e1a6f 100644 --- a/source/common/compressor/zlib_compressor_impl.cc +++ b/source/common/compressor/zlib_compressor_impl.cc @@ -27,7 +27,7 @@ void ZlibCompressorImpl::init(CompressionLevel comp_level, CompressionStrategy c ASSERT(initialized_ == false); const int result = deflateInit2(zstream_ptr_.get(), static_cast(comp_level), Z_DEFLATED, window_bits, memory_level, static_cast(comp_strategy)); - RELEASE_ASSERT(result >= 0); + RELEASE_ASSERT(result >= 0, ""); initialized_ = true; } @@ -57,7 +57,7 @@ bool ZlibCompressorImpl::deflateNext(int64_t flush_state) { switch (flush_state) { case Z_FINISH: if (result != Z_OK && result != Z_BUF_ERROR) { - RELEASE_ASSERT(result == Z_STREAM_END); + RELEASE_ASSERT(result == Z_STREAM_END, ""); return false; } FALLTHRU; @@ -65,7 +65,7 @@ bool ZlibCompressorImpl::deflateNext(int64_t flush_state) { if (result == Z_BUF_ERROR && zstream_ptr_->avail_in == 0) { return false; // This means that zlib needs more input, so stop here. } - RELEASE_ASSERT(result == Z_OK); + RELEASE_ASSERT(result == Z_OK, ""); } return true; diff --git a/source/common/config/BUILD b/source/common/config/BUILD index 83067aa8806eb..a463141313e37 100644 --- a/source/common/config/BUILD +++ b/source/common/config/BUILD @@ -35,6 +35,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/json:config_schemas_lib", "//source/common/protobuf:utility_lib", + "//source/common/stats:stats_lib", "//source/extensions/stat_sinks:well_known_names", "@envoy_api//envoy/config/bootstrap/v2:bootstrap_cc", ], @@ -66,6 +67,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/json:config_schemas_lib", "//source/common/network:utility_lib", + "//source/common/stats:stats_lib", "@envoy_api//envoy/api/v2:cds_cc", "@envoy_api//envoy/api/v2/cluster:circuit_breaker_cc", ], @@ -106,6 +108,7 @@ envoy_cc_library( ":rds_json_lib", ":utility_lib", "//include/envoy/json:json_object_interface", + "//include/envoy/stats:stats_interface", "//source/common/common:assert_lib", "//source/common/common:utility_lib", "//source/common/json:config_schemas_lib", @@ -142,6 +145,7 @@ envoy_cc_library( "//include/envoy/config:subscription_interface", "//include/envoy/grpc:async_client_interface", "//include/envoy/upstream:cluster_manager_interface", + "//source/common/common:backoff_lib", "//source/common/common:minimal_logger_lib", "//source/common/common:token_bucket_impl_lib", "//source/common/protobuf", @@ -217,6 +221,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/json:config_schemas_lib", "//source/common/network:utility_lib", + "//source/common/stats:stats_lib", "//source/extensions/filters/network:well_known_names", "@envoy_api//envoy/api/v2:lds_cc", ], @@ -270,6 +275,7 @@ envoy_cc_library( "//source/common/common:assert_lib", "//source/common/config:utility_lib", "//source/common/json:config_schemas_lib", + "//source/common/stats:stats_lib", "//source/extensions/filters/http:well_known_names", "@envoy_api//envoy/api/v2:rds_cc", ], diff --git a/source/common/config/bootstrap_json.cc b/source/common/config/bootstrap_json.cc index 5db8f7e536a96..4373f5f69e37a 100644 --- a/source/common/config/bootstrap_json.cc +++ b/source/common/config/bootstrap_json.cc @@ -15,7 +15,8 @@ namespace Envoy { namespace Config { void BootstrapJson::translateClusterManagerBootstrap( - const Json::Object& json_cluster_manager, envoy::config::bootstrap::v2::Bootstrap& bootstrap) { + const Json::Object& json_cluster_manager, envoy::config::bootstrap::v2::Bootstrap& bootstrap, + const Stats::StatsOptions& stats_options) { json_cluster_manager.validateSchema(Json::Schema::CLUSTER_MANAGER_SCHEMA); absl::optional eds_config; @@ -24,7 +25,7 @@ void BootstrapJson::translateClusterManagerBootstrap( auto* cluster = bootstrap.mutable_static_resources()->mutable_clusters()->Add(); Config::CdsJson::translateCluster(*json_sds->getObject("cluster"), absl::optional(), - *cluster); + *cluster, stats_options); Config::Utility::translateEdsConfig( *json_sds, *bootstrap.mutable_dynamic_resources()->mutable_deprecated_v1()->mutable_sds_config()); @@ -34,7 +35,8 @@ void BootstrapJson::translateClusterManagerBootstrap( if (json_cluster_manager.hasObject("cds")) { const auto json_cds = json_cluster_manager.getObject("cds"); auto* cluster = bootstrap.mutable_static_resources()->mutable_clusters()->Add(); - Config::CdsJson::translateCluster(*json_cds->getObject("cluster"), eds_config, *cluster); + Config::CdsJson::translateCluster(*json_cds->getObject("cluster"), eds_config, *cluster, + stats_options); Config::Utility::translateCdsConfig( *json_cds, *bootstrap.mutable_dynamic_resources()->mutable_cds_config()); } @@ -42,7 +44,7 @@ void BootstrapJson::translateClusterManagerBootstrap( for (const Json::ObjectSharedPtr& json_cluster : json_cluster_manager.getObjectArray("clusters")) { auto* cluster = bootstrap.mutable_static_resources()->mutable_clusters()->Add(); - Config::CdsJson::translateCluster(*json_cluster, eds_config, *cluster); + Config::CdsJson::translateCluster(*json_cluster, eds_config, *cluster, stats_options); } auto* cluster_manager = bootstrap.mutable_cluster_manager(); @@ -54,10 +56,12 @@ void BootstrapJson::translateClusterManagerBootstrap( } void BootstrapJson::translateBootstrap(const Json::Object& json_config, - envoy::config::bootstrap::v2::Bootstrap& bootstrap) { + envoy::config::bootstrap::v2::Bootstrap& bootstrap, + const Stats::StatsOptions& stats_options) { json_config.validateSchema(Json::Schema::TOP_LEVEL_CONFIG_SCHEMA); - translateClusterManagerBootstrap(*json_config.getObject("cluster_manager"), bootstrap); + translateClusterManagerBootstrap(*json_config.getObject("cluster_manager"), bootstrap, + stats_options); if (json_config.hasObject("lds")) { auto* lds_config = bootstrap.mutable_dynamic_resources()->mutable_lds_config(); @@ -66,7 +70,7 @@ void BootstrapJson::translateBootstrap(const Json::Object& json_config, for (const auto json_listener : json_config.getObjectArray("listeners")) { auto* listener = bootstrap.mutable_static_resources()->mutable_listeners()->Add(); - Config::LdsJson::translateListener(*json_listener, *listener); + Config::LdsJson::translateListener(*json_listener, *listener, stats_options); } JSON_UTIL_SET_STRING(json_config, bootstrap, flags_path); @@ -74,7 +78,7 @@ void BootstrapJson::translateBootstrap(const Json::Object& json_config, auto* stats_sinks = bootstrap.mutable_stats_sinks(); if (json_config.hasObject("statsd_udp_ip_address")) { auto* stats_sink = stats_sinks->Add(); - stats_sink->set_name(Extensions::StatSinks::StatsSinkNames::get().STATSD); + stats_sink->set_name(Extensions::StatSinks::StatsSinkNames::get().Statsd); envoy::config::metrics::v2::StatsdSink statsd_sink; AddressJson::translateAddress(json_config.getString("statsd_udp_ip_address"), false, true, *statsd_sink.mutable_address()); @@ -83,7 +87,7 @@ void BootstrapJson::translateBootstrap(const Json::Object& json_config, if (json_config.hasObject("statsd_tcp_cluster_name")) { auto* stats_sink = stats_sinks->Add(); - stats_sink->set_name(Extensions::StatSinks::StatsSinkNames::get().STATSD); + stats_sink->set_name(Extensions::StatSinks::StatsSinkNames::get().Statsd); envoy::config::metrics::v2::StatsdSink statsd_sink; statsd_sink.set_tcp_cluster_name(json_config.getString("statsd_tcp_cluster_name")); MessageUtil::jsonConvert(statsd_sink, *stats_sink->mutable_config()); diff --git a/source/common/config/bootstrap_json.h b/source/common/config/bootstrap_json.h index 80c567d7264ad..dd77bc1f9e705 100644 --- a/source/common/config/bootstrap_json.h +++ b/source/common/config/bootstrap_json.h @@ -2,6 +2,7 @@ #include "envoy/config/bootstrap/v2/bootstrap.pb.h" #include "envoy/json/json_object.h" +#include "envoy/stats/stats.h" namespace Envoy { namespace Config { @@ -14,7 +15,8 @@ class BootstrapJson { * @param bootstrap destination v2 envoy::config::bootstrap::v2::Bootstrap. */ static void translateClusterManagerBootstrap(const Json::Object& json_cluster_manager, - envoy::config::bootstrap::v2::Bootstrap& bootstrap); + envoy::config::bootstrap::v2::Bootstrap& bootstrap, + const Stats::StatsOptions& stats_options); /** * Translate a v1 JSON static config object to v2 envoy::config::bootstrap::v2::Bootstrap. @@ -22,7 +24,8 @@ class BootstrapJson { * @param bootstrap destination v2 envoy::config::bootstrap::v2::Bootstrap. */ static void translateBootstrap(const Json::Object& json_config, - envoy::config::bootstrap::v2::Bootstrap& bootstrap); + envoy::config::bootstrap::v2::Bootstrap& bootstrap, + const Stats::StatsOptions& stats_options); }; } // namespace Config diff --git a/source/common/config/cds_json.cc b/source/common/config/cds_json.cc index 3cf33249a7648..a7a684a590d62 100644 --- a/source/common/config/cds_json.cc +++ b/source/common/config/cds_json.cc @@ -53,9 +53,11 @@ void CdsJson::translateHealthCheck(const Json::Object& json_health_check, } } else { ASSERT(hc_type == "redis"); - auto* redis_health_check = health_check.mutable_redis_health_check(); + auto* redis_health_check = health_check.mutable_custom_health_check(); + redis_health_check->set_name("envoy.health_checkers.redis"); if (json_health_check.hasObject("redis_key")) { - redis_health_check->set_key(json_health_check.getString("redis_key")); + redis_health_check->mutable_config()->MergeFrom( + MessageUtil::keyValueStruct("key", json_health_check.getString("redis_key"))); } } } @@ -99,11 +101,12 @@ void CdsJson::translateOutlierDetection( void CdsJson::translateCluster(const Json::Object& json_cluster, const absl::optional& eds_config, - envoy::api::v2::Cluster& cluster) { + envoy::api::v2::Cluster& cluster, + const Stats::StatsOptions& stats_options) { json_cluster.validateSchema(Json::Schema::CLUSTER_SCHEMA); const std::string name = json_cluster.getString("name"); - Utility::checkObjNameLength("Invalid cluster name", name); + Utility::checkObjNameLength("Invalid cluster name", name, stats_options); cluster.set_name(name); const std::string string_type = json_cluster.getString("type"); diff --git a/source/common/config/cds_json.h b/source/common/config/cds_json.h index f2995074f79ac..8d48d02ac6d6d 100644 --- a/source/common/config/cds_json.h +++ b/source/common/config/cds_json.h @@ -3,6 +3,7 @@ #include "envoy/api/v2/cds.pb.h" #include "envoy/api/v2/cluster/circuit_breaker.pb.h" #include "envoy/json/json_object.h" +#include "envoy/stats/stats.h" #include "envoy/upstream/cluster_manager.h" #include "absl/types/optional.h" @@ -64,7 +65,8 @@ class CdsJson { */ static void translateCluster(const Json::Object& json_cluster, const absl::optional& eds_config, - envoy::api::v2::Cluster& cluster); + envoy::api::v2::Cluster& cluster, + const Stats::StatsOptions& stats_options); }; } // namespace Config diff --git a/source/common/config/filter_json.cc b/source/common/config/filter_json.cc index 4ef1eaedb3760..8069fea03b715 100644 --- a/source/common/config/filter_json.cc +++ b/source/common/config/filter_json.cc @@ -116,7 +116,7 @@ void FilterJson::translateAccessLog(const Json::Object& json_config, // Statically registered access logs are a v2-only feature, so use the standard internal file // access log for json config conversion. - proto_config.set_name(Extensions::AccessLoggers::AccessLogNames::get().FILE); + proto_config.set_name(Extensions::AccessLoggers::AccessLogNames::get().File); if (json_config.hasObject("filter")) { translateAccessLogFilter(*json_config.getObject("filter"), *proto_config.mutable_filter()); @@ -126,7 +126,8 @@ void FilterJson::translateAccessLog(const Json::Object& json_config, void FilterJson::translateHttpConnectionManager( const Json::Object& json_config, envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& - proto_config) { + proto_config, + const Stats::StatsOptions& stats_options) { json_config.validateSchema(Json::Schema::HTTP_CONN_NETWORK_FILTER_SCHEMA); envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager::CodecType @@ -138,7 +139,8 @@ void FilterJson::translateHttpConnectionManager( JSON_UTIL_SET_STRING(json_config, proto_config, stat_prefix); if (json_config.hasObject("rds")) { - Utility::translateRdsConfig(*json_config.getObject("rds"), *proto_config.mutable_rds()); + Utility::translateRdsConfig(*json_config.getObject("rds"), *proto_config.mutable_rds(), + stats_options); } if (json_config.hasObject("route_config")) { if (json_config.hasObject("rds")) { @@ -146,7 +148,7 @@ void FilterJson::translateHttpConnectionManager( "http connection manager must have either rds or route_config but not both"); } RdsJson::translateRouteConfiguration(*json_config.getObject("route_config"), - *proto_config.mutable_route_config()); + *proto_config.mutable_route_config(), stats_options); } for (const auto& json_filter : json_config.getObjectArray("filters", true)) { @@ -220,7 +222,7 @@ void FilterJson::translateHttpConnectionManager( proto_config.mutable_set_current_client_cert_details()->mutable_subject()->set_value(true); } else { ASSERT(detail == "SAN"); - proto_config.mutable_set_current_client_cert_details()->mutable_san()->set_value(true); + proto_config.mutable_set_current_client_cert_details()->set_uri(true); } } } @@ -300,7 +302,10 @@ void FilterJson::translateHealthCheckFilter( JSON_UTIL_SET_BOOL(json_config, proto_config, pass_through_mode); JSON_UTIL_SET_DURATION(json_config, proto_config, cache_time); - JSON_UTIL_SET_STRING(json_config, proto_config, endpoint); + std::string endpoint = json_config.getString("endpoint"); + auto& header = *proto_config.add_headers(); + header.set_name(":path"); + header.set_exact_match(endpoint); } void FilterJson::translateGrpcJsonTranscoder( diff --git a/source/common/config/filter_json.h b/source/common/config/filter_json.h index 73ca0f7085dc3..64ea883ac1651 100644 --- a/source/common/config/filter_json.h +++ b/source/common/config/filter_json.h @@ -15,6 +15,7 @@ #include "envoy/config/filter/network/redis_proxy/v2/redis_proxy.pb.h" #include "envoy/config/filter/network/tcp_proxy/v2/tcp_proxy.pb.h" #include "envoy/json/json_object.h" +#include "envoy/stats/stats.h" namespace Envoy { namespace Config { @@ -49,7 +50,8 @@ class FilterJson { static void translateHttpConnectionManager( const Json::Object& json_config, envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& - proto_config); + proto_config, + const Stats::StatsOptions& stats_options); /** * Translate a v1 JSON Redis proxy object to v2 diff --git a/source/common/config/grpc_mux_impl.cc b/source/common/config/grpc_mux_impl.cc index 0334ee09c6d35..77b9fc84dc345 100644 --- a/source/common/config/grpc_mux_impl.cc +++ b/source/common/config/grpc_mux_impl.cc @@ -12,10 +12,12 @@ namespace Config { GrpcMuxImpl::GrpcMuxImpl(const envoy::api::v2::core::Node& node, Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher, const Protobuf::MethodDescriptor& service_method, - MonotonicTimeSource& time_source) + Runtime::RandomGenerator& random, MonotonicTimeSource& time_source) : node_(node), async_client_(std::move(async_client)), service_method_(service_method), - time_source_(time_source) { + random_(random), time_source_(time_source) { retry_timer_ = dispatcher.createTimer([this]() -> void { establishNewStream(); }); + backoff_strategy_ = std::make_unique(RETRY_INITIAL_DELAY_MS, + RETRY_MAX_DELAY_MS, random_); } GrpcMuxImpl::~GrpcMuxImpl() { @@ -29,7 +31,7 @@ GrpcMuxImpl::~GrpcMuxImpl() { void GrpcMuxImpl::start() { establishNewStream(); } void GrpcMuxImpl::setRetryTimer() { - retry_timer_->enableTimer(std::chrono::milliseconds(RETRY_DELAY_MS)); + retry_timer_->enableTimer(std::chrono::milliseconds(backoff_strategy_->nextBackOffMs())); } void GrpcMuxImpl::establishNewStream() { @@ -156,6 +158,9 @@ void GrpcMuxImpl::onReceiveInitialMetadata(Http::HeaderMapPtr&& metadata) { } void GrpcMuxImpl::onReceiveMessage(std::unique_ptr&& message) { + // Reset here so that it starts with fresh backoff interval on next disconnect. + backoff_strategy_->reset(); + const std::string& type_url = message->type_url(); ENVOY_LOG(debug, "Received gRPC message for {} at version {}", type_url, message->version_info()); if (api_state_.count(type_url) == 0) { diff --git a/source/common/config/grpc_mux_impl.h b/source/common/config/grpc_mux_impl.h index 0eb1bfb90b151..5f2c99c2e6168 100644 --- a/source/common/config/grpc_mux_impl.h +++ b/source/common/config/grpc_mux_impl.h @@ -11,6 +11,7 @@ #include "envoy/grpc/status.h" #include "envoy/upstream/cluster_manager.h" +#include "common/common/backoff_strategy.h" #include "common/common/logger.h" namespace Envoy { @@ -25,6 +26,7 @@ class GrpcMuxImpl : public GrpcMux, public: GrpcMuxImpl(const envoy::api::v2::core::Node& node, Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher, const Protobuf::MethodDescriptor& service_method, + Runtime::RandomGenerator& random, MonotonicTimeSource& time_source = ProdMonotonicTimeSource::instance_); ~GrpcMuxImpl(); @@ -42,7 +44,8 @@ class GrpcMuxImpl : public GrpcMux, void onRemoteClose(Grpc::Status::GrpcStatus status, const std::string& message) override; // TODO(htuch): Make this configurable or some static. - const uint32_t RETRY_DELAY_MS = 5000; + const uint32_t RETRY_INITIAL_DELAY_MS = 500; + const uint32_t RETRY_MAX_DELAY_MS = 30000; // Do not cross more than 30s private: void setRetryTimer(); @@ -100,7 +103,9 @@ class GrpcMuxImpl : public GrpcMux, // Envoy's dependendency ordering. std::list subscriptions_; Event::TimerPtr retry_timer_; + Runtime::RandomGenerator& random_; MonotonicTimeSource& time_source_; + BackOffStrategyPtr backoff_strategy_; }; class NullGrpcMuxImpl : public GrpcMux { diff --git a/source/common/config/grpc_mux_subscription_impl.h b/source/common/config/grpc_mux_subscription_impl.h index 790bfc19d3d91..2bd0808118270 100644 --- a/source/common/config/grpc_mux_subscription_impl.h +++ b/source/common/config/grpc_mux_subscription_impl.h @@ -64,7 +64,7 @@ class GrpcMuxSubscriptionImpl : public Subscription, // TODO(htuch): Less fragile signal that this is failure vs. reject. if (e == nullptr) { stats_.update_failure_.inc(); - ENVOY_LOG(warn, "gRPC update for {} failed", type_url_); + ENVOY_LOG(debug, "gRPC update for {} failed", type_url_); } else { stats_.update_rejected_.inc(); ENVOY_LOG(warn, "gRPC config for {} rejected: {}", type_url_, e->what()); diff --git a/source/common/config/grpc_subscription_impl.h b/source/common/config/grpc_subscription_impl.h index 4b07f02239900..6117b06b7cd10 100644 --- a/source/common/config/grpc_subscription_impl.h +++ b/source/common/config/grpc_subscription_impl.h @@ -15,9 +15,9 @@ template class GrpcSubscriptionImpl : public Config::Subscription { public: GrpcSubscriptionImpl(const envoy::api::v2::core::Node& node, Grpc::AsyncClientPtr async_client, - Event::Dispatcher& dispatcher, + Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, const Protobuf::MethodDescriptor& service_method, SubscriptionStats stats) - : grpc_mux_(node, std::move(async_client), dispatcher, service_method), + : grpc_mux_(node, std::move(async_client), dispatcher, service_method, random), grpc_mux_subscription_(grpc_mux_, stats) {} // Config::Subscription diff --git a/source/common/config/lds_json.cc b/source/common/config/lds_json.cc index c2bc7af1761b7..5ca9f0f7cf8f0 100644 --- a/source/common/config/lds_json.cc +++ b/source/common/config/lds_json.cc @@ -14,11 +14,12 @@ namespace Envoy { namespace Config { void LdsJson::translateListener(const Json::Object& json_listener, - envoy::api::v2::Listener& listener) { + envoy::api::v2::Listener& listener, + const Stats::StatsOptions& stats_options) { json_listener.validateSchema(Json::Schema::LISTENER_SCHEMA); const std::string name = json_listener.getString("name", ""); - Utility::checkObjNameLength("Invalid listener name", name); + Utility::checkObjNameLength("Invalid listener name", name, stats_options); listener.set_name(name); AddressJson::translateAddress(json_listener.getString("address"), true, true, diff --git a/source/common/config/lds_json.h b/source/common/config/lds_json.h index 4848192acf570..d15eab2fb3c6e 100644 --- a/source/common/config/lds_json.h +++ b/source/common/config/lds_json.h @@ -3,6 +3,7 @@ #include "envoy/api/v2/lds.pb.h" #include "envoy/api/v2/listener/listener.pb.h" #include "envoy/json/json_object.h" +#include "envoy/stats/stats.h" namespace Envoy { namespace Config { @@ -15,7 +16,8 @@ class LdsJson { * @param listener destination v2 envoy::api::v2::Listener. */ static void translateListener(const Json::Object& json_listener, - envoy::api::v2::Listener& listener); + envoy::api::v2::Listener& listener, + const Stats::StatsOptions& stats_options); }; } // namespace Config diff --git a/source/common/config/rds_json.cc b/source/common/config/rds_json.cc index ceb892d69aaa6..8e91e524ecbf3 100644 --- a/source/common/config/rds_json.cc +++ b/source/common/config/rds_json.cc @@ -116,12 +116,13 @@ void RdsJson::translateQueryParameterMatcher( } void RdsJson::translateRouteConfiguration(const Json::Object& json_route_config, - envoy::api::v2::RouteConfiguration& route_config) { + envoy::api::v2::RouteConfiguration& route_config, + const Stats::StatsOptions& stats_options) { json_route_config.validateSchema(Json::Schema::ROUTE_CONFIGURATION_SCHEMA); for (const auto json_virtual_host : json_route_config.getObjectArray("virtual_hosts", true)) { auto* virtual_host = route_config.mutable_virtual_hosts()->Add(); - translateVirtualHost(*json_virtual_host, *virtual_host); + translateVirtualHost(*json_virtual_host, *virtual_host, stats_options); } for (const std::string& header : @@ -149,11 +150,12 @@ void RdsJson::translateRouteConfiguration(const Json::Object& json_route_config, } void RdsJson::translateVirtualHost(const Json::Object& json_virtual_host, - envoy::api::v2::route::VirtualHost& virtual_host) { + envoy::api::v2::route::VirtualHost& virtual_host, + const Stats::StatsOptions& stats_options) { json_virtual_host.validateSchema(Json::Schema::VIRTUAL_HOST_CONFIGURATION_SCHEMA); const std::string name = json_virtual_host.getString("name", ""); - Utility::checkObjNameLength("Invalid virtual host name", name); + Utility::checkObjNameLength("Invalid virtual host name", name, stats_options); virtual_host.set_name(name); for (const std::string& domain : json_virtual_host.getStringArray("domains", true)) { @@ -340,7 +342,7 @@ void RdsJson::translateRoute(const Json::Object& json_route, envoy::api::v2::rou const Json::ObjectSharedPtr obj = json_route.getObject("opaque_config"); auto& filter_metadata = (*route.mutable_metadata() - ->mutable_filter_metadata())[Extensions::HttpFilters::HttpFilterNames::get().ROUTER]; + ->mutable_filter_metadata())[Extensions::HttpFilters::HttpFilterNames::get().Router]; obj->iterate([&filter_metadata](const std::string& name, const Json::Object& value) { (*filter_metadata.mutable_fields())[name].set_string_value(value.asString()); return true; diff --git a/source/common/config/rds_json.h b/source/common/config/rds_json.h index dba4c8b5b6188..987a99cf5d9fd 100644 --- a/source/common/config/rds_json.h +++ b/source/common/config/rds_json.h @@ -3,6 +3,7 @@ #include "envoy/api/v2/rds.pb.h" #include "envoy/api/v2/route/route.pb.h" #include "envoy/json/json_object.h" +#include "envoy/stats/stats.h" namespace Envoy { namespace Config { @@ -64,7 +65,8 @@ class RdsJson { * @param route_config destination v2 envoy::api::v2::RouteConfiguration. */ static void translateRouteConfiguration(const Json::Object& json_route_config, - envoy::api::v2::RouteConfiguration& route_config); + envoy::api::v2::RouteConfiguration& route_config, + const Stats::StatsOptions& stats_options); /** * Translate a v1 JSON virtual host object to v2 envoy::api::v2::route::VirtualHost. @@ -72,7 +74,8 @@ class RdsJson { * @param virtual_host destination v2 envoy::api::v2::route::VirtualHost. */ static void translateVirtualHost(const Json::Object& json_virtual_host, - envoy::api::v2::route::VirtualHost& virtual_host); + envoy::api::v2::route::VirtualHost& virtual_host, + const Stats::StatsOptions& stats_options); /** * Translate a v1 JSON decorator object to v2 envoy::api::v2::route::Decorator. diff --git a/source/common/config/subscription_factory.h b/source/common/config/subscription_factory.h index f80cc09219f62..25cbe119860c1 100644 --- a/source/common/config/subscription_factory.h +++ b/source/common/config/subscription_factory.h @@ -68,12 +68,12 @@ class SubscriptionFactory { Config::Utility::factoryForGrpcApiConfigSource(cm.grpcAsyncClientManager(), config.api_config_source(), scope) ->create(), - dispatcher, *Protobuf::DescriptorPool::generated_pool()->FindMethodByName(grpc_method), - stats)); + dispatcher, random, + *Protobuf::DescriptorPool::generated_pool()->FindMethodByName(grpc_method), stats)); break; } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } break; } diff --git a/source/common/config/utility.cc b/source/common/config/utility.cc index aa3d8cd80fe00..37fac934f2ad5 100644 --- a/source/common/config/utility.cc +++ b/source/common/config/utility.cc @@ -25,15 +25,19 @@ void Utility::translateApiConfigSource(const std::string& cluster, uint32_t refr envoy::api::v2::core::ApiConfigSource& api_config_source) { // TODO(junr03): document the option to chose an api type once we have created // stronger constraints around v2. - if (api_type == ApiType::get().RestLegacy) { - api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::REST_LEGACY); - } else if (api_type == ApiType::get().Rest) { - api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::REST); - } else { - ASSERT(api_type == ApiType::get().Grpc); + if (api_type == ApiType::get().Grpc) { api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); + envoy::api::v2::core::GrpcService* grpc_service = api_config_source.add_grpc_services(); + grpc_service->mutable_envoy_grpc()->set_cluster_name(cluster); + } else { + if (api_type == ApiType::get().RestLegacy) { + api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::REST_LEGACY); + } else if (api_type == ApiType::get().Rest) { + api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::REST); + } + api_config_source.add_cluster_names(cluster); } - api_config_source.add_cluster_names(cluster); + api_config_source.mutable_refresh_delay()->CopyFrom( Protobuf::util::TimeUtil::MillisecondsToDuration(refresh_delay_ms)); } @@ -83,26 +87,21 @@ void Utility::checkApiConfigSourceNames( const bool is_grpc = (api_config_source.api_type() == envoy::api::v2::core::ApiConfigSource::GRPC); - if (api_config_source.cluster_names().size() == 0 && - api_config_source.grpc_services().size() == 0) { + if (api_config_source.cluster_names().empty() && api_config_source.grpc_services().empty()) { throw EnvoyException("API configs must have either a gRPC service or a cluster name defined"); } if (is_grpc) { - if (api_config_source.cluster_names().size() != 0) { - ENVOY_LOG_MISC(warn, "Setting a cluster name for API config source type " - "envoy::api::v2::core::ConfigSource::GRPC is deprecated"); - } - if (api_config_source.cluster_names().size() > 1) { + if (!api_config_source.cluster_names().empty()) { throw EnvoyException( - "envoy::api::v2::core::ConfigSource must have a singleton cluster name specified"); + "envoy::api::v2::core::ConfigSource::GRPC must not have a cluster name specified."); } if (api_config_source.grpc_services().size() > 1) { throw EnvoyException( "envoy::api::v2::core::ConfigSource::GRPC must have a single gRPC service specified"); } } else { - if (api_config_source.grpc_services().size() != 0) { + if (!api_config_source.grpc_services().empty()) { throw EnvoyException("envoy::api::v2::core::ConfigSource, if not of type gRPC, must not have " "a gRPC service specified"); } @@ -178,11 +177,12 @@ void Utility::translateCdsConfig(const Json::Object& json_config, void Utility::translateRdsConfig( const Json::Object& json_rds, - envoy::config::filter::network::http_connection_manager::v2::Rds& rds) { + envoy::config::filter::network::http_connection_manager::v2::Rds& rds, + const Stats::StatsOptions& stats_options) { json_rds.validateSchema(Json::Schema::RDS_CONFIGURATION_SCHEMA); const std::string name = json_rds.getString("route_config_name", ""); - checkObjNameLength("Invalid route_config name", name); + checkObjNameLength("Invalid route_config name", name, stats_options); rds.set_route_config_name(name); translateApiConfigSource(json_rds.getString("cluster"), @@ -205,11 +205,12 @@ Utility::createTagProducer(const envoy::config::bootstrap::v2::Bootstrap& bootst return std::make_unique(bootstrap.stats_config()); } -void Utility::checkObjNameLength(const std::string& error_prefix, const std::string& name) { - if (name.length() > Stats::RawStatData::maxObjNameLength()) { +void Utility::checkObjNameLength(const std::string& error_prefix, const std::string& name, + const Stats::StatsOptions& stats_options) { + if (name.length() > stats_options.maxNameLength()) { throw EnvoyException(fmt::format("{}: Length of {} ({}) exceeds allowed maximum length ({})", error_prefix, name, name.length(), - Stats::RawStatData::maxObjNameLength())); + stats_options.maxNameLength())); } } @@ -219,14 +220,25 @@ Grpc::AsyncClientFactoryPtr Utility::factoryForGrpcApiConfigSource( Utility::checkApiConfigSourceNames(api_config_source); envoy::api::v2::core::GrpcService grpc_service; - if (api_config_source.cluster_names().empty()) { - grpc_service.MergeFrom(api_config_source.grpc_services(0)); - } else { - grpc_service.mutable_envoy_grpc()->set_cluster_name(api_config_source.cluster_names(0)); - } + grpc_service.MergeFrom(api_config_source.grpc_services(0)); return async_client_manager.factoryForGrpcService(grpc_service, scope, false); } +envoy::api::v2::ClusterLoadAssignment Utility::translateClusterHosts( + const Protobuf::RepeatedPtrField& hosts) { + envoy::api::v2::ClusterLoadAssignment load_assignment; + envoy::api::v2::endpoint::LocalityLbEndpoints* locality_lb_endpoints = + load_assignment.add_endpoints(); + // Since this LocalityLbEndpoints is built from hosts list, set the default weight to 1. + locality_lb_endpoints->mutable_load_balancing_weight()->set_value(1); + for (const envoy::api::v2::core::Address& host : hosts) { + envoy::api::v2::endpoint::LbEndpoint* lb_endpoint = locality_lb_endpoints->add_lb_endpoints(); + lb_endpoint->mutable_endpoint()->mutable_address()->MergeFrom(host); + lb_endpoint->mutable_load_balancing_weight()->set_value(1); + } + return load_assignment; +} + } // namespace Config } // namespace Envoy diff --git a/source/common/config/utility.h b/source/common/config/utility.h index b6465238c009f..79086504ad94b 100644 --- a/source/common/config/utility.h +++ b/source/common/config/utility.h @@ -173,7 +173,8 @@ class Utility { */ static void translateRdsConfig(const Json::Object& json_rds, - envoy::config::filter::network::http_connection_manager::v2::Rds& rds); + envoy::config::filter::network::http_connection_manager::v2::Rds& rds, + const Stats::StatsOptions& stats_options); /** * Convert a v1 LDS JSON config to v2 LDS envoy::api::v2::core::ConfigSource. @@ -227,7 +228,7 @@ class Utility { ProtobufTypes::MessagePtr config = factory.createEmptyConfigProto(); // Fail in an obvious way if a plugin does not return a proto. - RELEASE_ASSERT(config != nullptr); + RELEASE_ASSERT(config != nullptr, ""); if (enclosing_message.has_config()) { MessageUtil::jsonConvert(enclosing_message.config(), *config); @@ -251,7 +252,7 @@ class Utility { ProtobufTypes::MessagePtr config = factory.createEmptyRouteConfigProto(); // Fail in an obvious way if a plugin does not return a proto. - RELEASE_ASSERT(config != nullptr); + RELEASE_ASSERT(config != nullptr, ""); MessageUtil::jsonConvert(source, *config); return config; @@ -271,8 +272,11 @@ class Utility { * It should be within the configured length limit. Throws on error. * @param error_prefix supplies the prefix to use in error messages. * @param name supplies the name to check for length limits. + * @param stats_options the top-level statsOptions struct, which contains the max stat name / + * suffix lengths for stats. */ - static void checkObjNameLength(const std::string& error_prefix, const std::string& name); + static void checkObjNameLength(const std::string& error_prefix, const std::string& name, + const Stats::StatsOptions& stats_options); /** * Obtain gRPC async client factory from a envoy::api::v2::core::ApiConfigSource. @@ -284,6 +288,14 @@ class Utility { factoryForGrpcApiConfigSource(Grpc::AsyncClientManager& async_client_manager, const envoy::api::v2::core::ApiConfigSource& api_config_source, Stats::Scope& scope); + + /** + * Translate a set of cluster's hosts into a load assignment configuration. + * @param hosts cluster's list of hosts. + * @return envoy::api::v2::ClusterLoadAssignment a load assignment configuration. + */ + static envoy::api::v2::ClusterLoadAssignment + translateClusterHosts(const Protobuf::RepeatedPtrField& hosts); }; } // namespace Config diff --git a/source/common/decompressor/zlib_decompressor_impl.cc b/source/common/decompressor/zlib_decompressor_impl.cc index b18270e48d6f5..c68c9c1c906a8 100644 --- a/source/common/decompressor/zlib_decompressor_impl.cc +++ b/source/common/decompressor/zlib_decompressor_impl.cc @@ -25,7 +25,7 @@ ZlibDecompressorImpl::ZlibDecompressorImpl(uint64_t chunk_size) void ZlibDecompressorImpl::init(int64_t window_bits) { ASSERT(initialized_ == false); const int result = inflateInit2(zstream_ptr_.get(), window_bits); - RELEASE_ASSERT(result >= 0); + RELEASE_ASSERT(result >= 0, ""); initialized_ = true; } @@ -70,7 +70,7 @@ bool ZlibDecompressorImpl::inflateNext() { return false; // This means that zlib needs more input, so stop here. } - RELEASE_ASSERT(result == Z_OK); + RELEASE_ASSERT(result == Z_OK, ""); return true; } diff --git a/source/common/event/BUILD b/source/common/event/BUILD index 68420b3312b53..fe1eb52e31092 100644 --- a/source/common/event/BUILD +++ b/source/common/event/BUILD @@ -59,7 +59,6 @@ envoy_cc_library( hdrs = ["libevent.h"], external_deps = [ "event", - "event_pthreads", ], deps = [ "//source/common/common:assert_lib", @@ -73,7 +72,6 @@ envoy_cc_library( hdrs = ["dispatched_thread.h"], external_deps = [ "event", - "event_pthreads", ], deps = [ ":dispatcher_lib", diff --git a/source/common/event/dispatcher_impl.cc b/source/common/event/dispatcher_impl.cc index 9044dc75b724d..ff8c3a1c76b2a 100644 --- a/source/common/event/dispatcher_impl.cc +++ b/source/common/event/dispatcher_impl.cc @@ -27,7 +27,7 @@ namespace Event { DispatcherImpl::DispatcherImpl() : DispatcherImpl(Buffer::WatermarkFactoryPtr{new Buffer::WatermarkBufferFactory}) { // The dispatcher won't work as expected if libevent hasn't been configured to use threads. - RELEASE_ASSERT(Libevent::Global::initialized()); + RELEASE_ASSERT(Libevent::Global::initialized(), ""); } DispatcherImpl::DispatcherImpl(Buffer::WatermarkFactoryPtr&& factory) @@ -35,7 +35,7 @@ DispatcherImpl::DispatcherImpl(Buffer::WatermarkFactoryPtr&& factory) deferred_delete_timer_(createTimer([this]() -> void { clearDeferredDeleteList(); })), post_timer_(createTimer([this]() -> void { runPostCallbacks(); })), current_to_delete_(&to_delete_1_) { - RELEASE_ASSERT(Libevent::Global::initialized()); + RELEASE_ASSERT(Libevent::Global::initialized(), ""); } DispatcherImpl::~DispatcherImpl() {} diff --git a/source/common/filesystem/BUILD b/source/common/filesystem/BUILD index b548a03213262..f357a1295769b 100644 --- a/source/common/filesystem/BUILD +++ b/source/common/filesystem/BUILD @@ -40,7 +40,9 @@ envoy_cc_library( "inotify/watcher_impl.h", ], }), - external_deps = ["event"], + external_deps = [ + "event", + ], strip_include_prefix = select({ "@bazel_tools//tools/osx:darwin": "kqueue", "//conditions:default": "inotify", diff --git a/source/common/filesystem/inotify/watcher_impl.cc b/source/common/filesystem/inotify/watcher_impl.cc index 1d9c65b72033b..f61ac242b5182 100644 --- a/source/common/filesystem/inotify/watcher_impl.cc +++ b/source/common/filesystem/inotify/watcher_impl.cc @@ -24,7 +24,7 @@ WatcherImpl::WatcherImpl(Event::Dispatcher& dispatcher) }, Event::FileTriggerType::Edge, Event::FileReadyType::Read)) { - RELEASE_ASSERT(inotify_fd_ >= 0); + RELEASE_ASSERT(inotify_fd_ >= 0, ""); } WatcherImpl::~WatcherImpl() { close(inotify_fd_); } @@ -57,7 +57,7 @@ void WatcherImpl::onInotifyEvent() { if (rc == -1 && errno == EAGAIN) { return; } - RELEASE_ASSERT(rc >= 0); + RELEASE_ASSERT(rc >= 0, ""); const size_t event_count = rc; size_t index = 0; diff --git a/source/common/grpc/async_client_manager_impl.cc b/source/common/grpc/async_client_manager_impl.cc index 606d550c2e30a..7a9693db19938 100644 --- a/source/common/grpc/async_client_manager_impl.cc +++ b/source/common/grpc/async_client_manager_impl.cc @@ -80,7 +80,7 @@ AsyncClientManagerImpl::factoryForGrpcService(const envoy::api::v2::core::GrpcSe return std::make_unique(tls_, google_tls_slot_.get(), scope, config); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } return nullptr; } diff --git a/source/common/grpc/google_async_client_impl.cc b/source/common/grpc/google_async_client_impl.cc index 53957f3f46823..7708feed95ba9 100644 --- a/source/common/grpc/google_async_client_impl.cc +++ b/source/common/grpc/google_async_client_impl.cc @@ -332,7 +332,7 @@ void GoogleAsyncStreamImpl::handleOpCompletion(GoogleAsyncTag::Operation op, boo break; } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/source/common/http/BUILD b/source/common/http/BUILD index 13ab86a359b86..002c311e35bd3 100644 --- a/source/common/http/BUILD +++ b/source/common/http/BUILD @@ -80,6 +80,7 @@ envoy_cc_library( "//include/envoy/stats:stats_interface", "//source/common/common:enum_to_int", "//source/common/common:utility_lib", + "@envoy_api//envoy/type:http_status_cc", ], ) diff --git a/source/common/http/async_client_impl.cc b/source/common/http/async_client_impl.cc index 6d8e2fbdbf242..e9e435cd7121b 100644 --- a/source/common/http/async_client_impl.cc +++ b/source/common/http/async_client_impl.cc @@ -13,6 +13,7 @@ namespace Envoy { namespace Http { const std::list AsyncStreamImpl::NullCorsPolicy::allow_origin_; +const std::list AsyncStreamImpl::NullCorsPolicy::allow_origin_regex_; const absl::optional AsyncStreamImpl::NullCorsPolicy::allow_credentials_; const std::vector> AsyncStreamImpl::NullRateLimitPolicy::rate_limit_policy_entry_; diff --git a/source/common/http/async_client_impl.h b/source/common/http/async_client_impl.h index 050a0731d074f..4632166133a7d 100644 --- a/source/common/http/async_client_impl.h +++ b/source/common/http/async_client_impl.h @@ -90,6 +90,9 @@ class AsyncStreamImpl : public AsyncClient::Stream, struct NullCorsPolicy : public Router::CorsPolicy { // Router::CorsPolicy const std::list& allowOrigins() const override { return allow_origin_; }; + const std::list& allowOriginRegexes() const override { + return allow_origin_regex_; + }; const std::string& allowMethods() const override { return EMPTY_STRING; }; const std::string& allowHeaders() const override { return EMPTY_STRING; }; const std::string& exposeHeaders() const override { return EMPTY_STRING; }; @@ -98,6 +101,7 @@ class AsyncStreamImpl : public AsyncClient::Stream, bool enabled() const override { return false; }; static const std::list allow_origin_; + static const std::list allow_origin_regex_; static const absl::optional allow_credentials_; }; @@ -191,6 +195,7 @@ class AsyncStreamImpl : public AsyncClient::Stream, return std::chrono::milliseconds(0); } } + absl::optional idleTimeout() const override { return absl::nullopt; } absl::optional maxGrpcTimeout() const override { return absl::nullopt; } @@ -207,7 +212,7 @@ class AsyncStreamImpl : public AsyncClient::Stream, Http::WebSocketProxyCallbacks&, Upstream::ClusterManager&, Network::ReadFilterCallbacks*) const override { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } bool includeVirtualHostRateLimits() const override { return true; } const envoy::api::v2::core::Metadata& metadata() const override { return metadata_; } @@ -262,8 +267,8 @@ class AsyncStreamImpl : public AsyncClient::Stream, RequestInfo::RequestInfo& requestInfo() override { return request_info_; } Tracing::Span& activeSpan() override { return active_span_; } const Tracing::Config& tracingConfig() override { return tracing_config_; } - void continueDecoding() override { NOT_IMPLEMENTED; } - void addDecodedData(Buffer::Instance&, bool) override { NOT_IMPLEMENTED; } + void continueDecoding() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void addDecodedData(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } const Buffer::Instance* decodingBuffer() override { return buffered_body_.get(); } void sendLocalReply(Code code, const std::string& body, std::function modify_headers) override { diff --git a/source/common/http/conn_manager_config.h b/source/common/http/conn_manager_config.h index 26ba715c5fe67..7c172903f7a44 100644 --- a/source/common/http/conn_manager_config.h +++ b/source/common/http/conn_manager_config.h @@ -54,6 +54,7 @@ namespace Http { COUNTER (downstream_rq_4xx) \ COUNTER (downstream_rq_5xx) \ HISTOGRAM(downstream_rq_time) \ + COUNTER (downstream_rq_idle_timeout) \ COUNTER (rs_too_large) // clang-format on @@ -138,7 +139,7 @@ enum class ForwardClientCertType { * Configuration for the fields of the client cert, used for populating the current client cert * information to the next hop. */ -enum class ClientCertDetailsType { Cert, Subject, SAN, URI, DNS }; +enum class ClientCertDetailsType { Cert, Subject, URI, DNS }; /** * Abstract configuration for the connection manager. @@ -191,7 +192,13 @@ class ConnectionManagerConfig { /** * @return optional idle timeout for incoming connection manager connections. */ - virtual const absl::optional& idleTimeout() PURE; + virtual absl::optional idleTimeout() const PURE; + + /** + * @return per-stream idle timeout for incoming connection manager connections. Zero indicates a + * disabled idle timeout. + */ + virtual std::chrono::milliseconds streamIdleTimeout() const PURE; /** * @return Router::RouteConfigProvider& the configuration provider used to acquire a route diff --git a/source/common/http/conn_manager_impl.cc b/source/common/http/conn_manager_impl.cc index e0975e4473f85..8475ec8c9a704 100644 --- a/source/common/http/conn_manager_impl.cc +++ b/source/common/http/conn_manager_impl.cc @@ -164,6 +164,10 @@ void ConnectionManagerImpl::doEndStream(ActiveStream& stream) { } void ConnectionManagerImpl::doDeferredStreamDestroy(ActiveStream& stream) { + if (stream.idle_timer_ != nullptr) { + stream.idle_timer_->disableTimer(); + stream.idle_timer_ = nullptr; + } stream.state_.destroyed_ = true; for (auto& filter : stream.decoder_filters_) { filter->handle_->onDestroy(); @@ -369,6 +373,13 @@ ConnectionManagerImpl::ActiveStream::ActiveStream(ConnectionManagerImpl& connect // prevents surprises for logging code in edge cases. request_info_.setDownstreamRemoteAddress( connection_manager_.read_callbacks_->connection().remoteAddress()); + + if (connection_manager_.config_.streamIdleTimeout().count()) { + idle_timeout_ms_ = connection_manager_.config_.streamIdleTimeout(); + idle_timer_ = connection_manager_.read_callbacks_->connection().dispatcher().createTimer( + [this]() -> void { onIdleTimeout(); }); + resetIdleTimer(); + } } ConnectionManagerImpl::ActiveStream::~ActiveStream() { @@ -396,6 +407,29 @@ ConnectionManagerImpl::ActiveStream::~ActiveStream() { ASSERT(state_.filter_call_state_ == 0); } +void ConnectionManagerImpl::ActiveStream::resetIdleTimer() { + if (idle_timer_ != nullptr) { + // TODO(htuch): If this shows up in performance profiles, optimize by only + // updating a timestamp here and doing periodic checks for idle timeouts + // instead, or reducing the accuracy of timers. + idle_timer_->enableTimer(idle_timeout_ms_); + } +} + +void ConnectionManagerImpl::ActiveStream::onIdleTimeout() { + connection_manager_.stats_.named_.downstream_rq_idle_timeout_.inc(); + // If headers have not been sent to the user, send a 408. + if (response_headers_ != nullptr) { + // TODO(htuch): We could send trailers here with an x-envoy timeout header + // or gRPC status code, and/or set H2 RST_STREAM error. + connection_manager_.doEndStream(*this); + } else { + sendLocalReply(request_headers_ != nullptr && + Grpc::Common::hasGrpcContentType(*request_headers_), + Http::Code::RequestTimeout, "stream timeout", nullptr); + } +} + void ConnectionManagerImpl::ActiveStream::addStreamDecoderFilterWorker( StreamDecoderFilterSharedPtr filter, bool dual_filter) { ActiveStreamDecoderFilterPtr wrapper(new ActiveStreamDecoderFilter(*this, filter, dual_filter)); @@ -447,7 +481,7 @@ const Network::Connection* ConnectionManagerImpl::ActiveStream::connection() { void ConnectionManagerImpl::ActiveStream::decodeHeaders(HeaderMapPtr&& headers, bool end_stream) { request_headers_ = std::move(headers); - createFilterChain(); + const bool upgrade_rejected = createFilterChain() == false; maybeEndDecode(end_stream); @@ -569,7 +603,7 @@ void ConnectionManagerImpl::ActiveStream::decodeHeaders(HeaderMapPtr&& headers, connection_manager_.stats_.named_.downstream_cx_http1_active_.dec(); connection_manager_.stats_.named_.downstream_cx_websocket_total_.inc(); return; - } else if (websocket_requested) { + } else if (upgrade_rejected) { // Do not allow WebSocket upgrades if the route does not support it. connection_manager_.stats_.named_.downstream_rq_ws_on_non_ws_route_.inc(); sendLocalReply(Grpc::Common::hasGrpcContentType(*request_headers_), Code::Forbidden, "", @@ -579,12 +613,34 @@ void ConnectionManagerImpl::ActiveStream::decodeHeaders(HeaderMapPtr&& headers, // Allow non websocket requests to go through websocket enabled routes. } + if (cached_route_.value()) { + const Router::RouteEntry* route_entry = cached_route_.value()->routeEntry(); + if (route_entry != nullptr && route_entry->idleTimeout()) { + idle_timeout_ms_ = route_entry->idleTimeout().value(); + if (idle_timeout_ms_.count()) { + // If we have a route-level idle timeout but no global stream idle timeout, create a timer. + if (idle_timer_ == nullptr) { + idle_timer_ = connection_manager_.read_callbacks_->connection().dispatcher().createTimer( + [this]() -> void { onIdleTimeout(); }); + } + } else if (idle_timer_ != nullptr) { + // If we had a global stream idle timeout but the route-level idle timeout is set to zero + // (to override), we disable the idle timer. + idle_timer_->disableTimer(); + idle_timer_ = nullptr; + } + } + } + // Check if tracing is enabled at all. if (connection_manager_.config_.tracingConfig()) { traceRequest(); } decodeHeaders(nullptr, *request_headers_, end_stream); + + // Reset it here for both global and overriden cases. + resetIdleTimer(); } void ConnectionManagerImpl::ActiveStream::traceRequest() { @@ -702,6 +758,8 @@ void ConnectionManagerImpl::ActiveStream::decodeData(Buffer::Instance& data, boo void ConnectionManagerImpl::ActiveStream::decodeData(ActiveStreamDecoderFilter* filter, Buffer::Instance& data, bool end_stream) { + resetIdleTimer(); + // If a response is complete or a reset has been sent, filters do not care about further body // data. Just drop it. if (state_.local_complete_) { @@ -745,11 +803,12 @@ void ConnectionManagerImpl::ActiveStream::addDecodedData(ActiveStreamDecoderFilt } else { // TODO(mattklein123): Formalize error handling for filters and add tests. Should probably // throw an exception here. - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } void ConnectionManagerImpl::ActiveStream::decodeTrailers(HeaderMapPtr&& trailers) { + resetIdleTimer(); maybeEndDecode(true); request_trailers_ = std::move(trailers); decodeTrailers(nullptr, *request_trailers_); @@ -846,6 +905,7 @@ void ConnectionManagerImpl::ActiveStream::sendLocalReply( void ConnectionManagerImpl::ActiveStream::encode100ContinueHeaders( ActiveStreamEncoderFilter* filter, HeaderMap& headers) { + resetIdleTimer(); ASSERT(connection_manager_.config_.proxy100Continue()); // Make sure commonContinue continues encode100ContinueHeaders. has_continue_headers_ = true; @@ -869,7 +929,7 @@ void ConnectionManagerImpl::ActiveStream::encode100ContinueHeaders( // Strip the T-E headers etc. Defer other header additions as well as drain-close logic to the // continuation headers. - ConnectionManagerUtility::mutateResponseHeaders(headers, *request_headers_, EMPTY_STRING); + ConnectionManagerUtility::mutateResponseHeaders(headers, request_headers_.get(), EMPTY_STRING); // Count both the 1xx and follow-up response code in stats. chargeStats(headers); @@ -882,6 +942,8 @@ void ConnectionManagerImpl::ActiveStream::encode100ContinueHeaders( void ConnectionManagerImpl::ActiveStream::encodeHeaders(ActiveStreamEncoderFilter* filter, HeaderMap& headers, bool end_stream) { + resetIdleTimer(); + std::list::iterator entry = commonEncodePrefix(filter, end_stream); std::list::iterator continue_data_entry = encoder_filters_.end(); @@ -908,7 +970,7 @@ void ConnectionManagerImpl::ActiveStream::encodeHeaders(ActiveStreamEncoderFilte connection_manager_.config_.dateProvider().setDateHeader(headers); // Following setReference() is safe because serverName() is constant for the life of the listener. headers.insertServer().value().setReference(connection_manager_.config_.serverName()); - ConnectionManagerUtility::mutateResponseHeaders(headers, *request_headers_, + ConnectionManagerUtility::mutateResponseHeaders(headers, request_headers_.get(), connection_manager_.config_.via()); // See if we want to drain/close the connection. Send the go away frame prior to encoding the @@ -942,7 +1004,12 @@ void ConnectionManagerImpl::ActiveStream::encodeHeaders(ActiveStreamEncoderFilte if (connection_manager_.drain_state_ == DrainState::Closing && connection_manager_.codec_->protocol() != Protocol::Http2) { - headers.insertConnection().value().setReference(Headers::get().ConnectionValues.Close); + // If the connection manager is draining send "Connection: Close" on HTTP/1.1 connections. + // Do not do this for H2 (which drains via GOAWA) or Upgrade (as the upgrade + // payload is no longer HTTP/1.1) + if (!Utility::isUpgrade(headers)) { + headers.insertConnection().value().setReference(Headers::get().ConnectionValues.Close); + } } if (connection_manager_.config_.tracingConfig()) { @@ -1008,12 +1075,13 @@ void ConnectionManagerImpl::ActiveStream::addEncodedData(ActiveStreamEncoderFilt } else { // TODO(mattklein123): Formalize error handling for filters and add tests. Should probably // throw an exception here. - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } void ConnectionManagerImpl::ActiveStream::encodeData(ActiveStreamEncoderFilter* filter, Buffer::Instance& data, bool end_stream) { + resetIdleTimer(); std::list::iterator entry = commonEncodePrefix(filter, end_stream); for (; entry != encoder_filters_.end(); entry++) { ASSERT(!(state_.filter_call_state_ & FilterCallState::EncodeData)); @@ -1037,6 +1105,7 @@ void ConnectionManagerImpl::ActiveStream::encodeData(ActiveStreamEncoderFilter* void ConnectionManagerImpl::ActiveStream::encodeTrailers(ActiveStreamEncoderFilter* filter, HeaderMap& trailers) { + resetIdleTimer(); std::list::iterator entry = commonEncodePrefix(filter, true); for (; entry != encoder_filters_.end(); entry++) { ASSERT(!(state_.filter_call_state_ & FilterCallState::EncodeTrailers)); @@ -1118,8 +1187,22 @@ void ConnectionManagerImpl::ActiveStream::setBufferLimit(uint32_t new_limit) { } } -void ConnectionManagerImpl::ActiveStream::createFilterChain() { +bool ConnectionManagerImpl::ActiveStream::createFilterChain() { + bool upgrade_rejected = false; + auto upgrade = request_headers_->Upgrade(); + if (upgrade != nullptr) { + if (connection_manager_.config_.filterFactory().createUpgradeFilterChain( + upgrade->value().c_str(), *this)) { + return true; + } else { + upgrade_rejected = true; + // Fall through to the default filter chain. The function calling this + // will send a local reply indicating that the upgrade failed. + } + } + connection_manager_.config_.filterFactory().createFilterChain(*this); + return !upgrade_rejected; } void ConnectionManagerImpl::ActiveStreamFilterBase::commonContinue() { diff --git a/source/common/http/conn_manager_impl.h b/source/common/http/conn_manager_impl.h index 44b6d39e6dfd7..672ff52433cb2 100644 --- a/source/common/http/conn_manager_impl.h +++ b/source/common/http/conn_manager_impl.h @@ -150,7 +150,7 @@ class ConnectionManagerImpl : Logger::Loggable, Buffer::WatermarkBufferPtr createBuffer() override; Buffer::WatermarkBufferPtr& bufferedData() override { return parent_.buffered_request_data_; } bool complete() override { return parent_.state_.remote_complete_; } - void do100ContinueHeaders() override { NOT_REACHED; } + void do100ContinueHeaders() override { NOT_REACHED_GCOVR_EXCL_LINE; } void doHeaders(bool end_stream) override { parent_.decodeHeaders(this, *parent_.request_headers_, end_stream); } @@ -287,7 +287,7 @@ class ConnectionManagerImpl : Logger::Loggable, void onBelowWriteBufferLowWatermark() override; // Http::StreamDecoder - void decode100ContinueHeaders(HeaderMapPtr&&) override { NOT_REACHED; } + void decode100ContinueHeaders(HeaderMapPtr&&) override { NOT_REACHED_GCOVR_EXCL_LINE; } void decodeHeaders(HeaderMapPtr&& headers, bool end_stream) override; void decodeData(Buffer::Instance& data, bool end_stream) override; void decodeTrailers(HeaderMapPtr&& trailers) override; @@ -361,7 +361,11 @@ class ConnectionManagerImpl : Logger::Loggable, // Possibly increases buffer_limit_ to the value of limit. void setBufferLimit(uint32_t limit); // Set up the Encoder/Decoder filter chain. - void createFilterChain(); + bool createFilterChain(); + // Per-stream idle timeout callback. + void onIdleTimeout(); + // Reset per-stream idle timer. + void resetIdleTimer(); ConnectionManagerImpl& connection_manager_; Router::ConfigConstSharedPtr snapped_route_config_; @@ -379,6 +383,9 @@ class ConnectionManagerImpl : Logger::Loggable, std::list encoder_filters_; std::list access_log_handlers_; Stats::TimespanPtr request_timer_; + // Per-stream idle timeout. + Event::TimerPtr idle_timer_; + std::chrono::milliseconds idle_timeout_ms_{}; State state_; RequestInfo::RequestInfoImpl request_info_; absl::optional cached_route_; diff --git a/source/common/http/conn_manager_utility.cc b/source/common/http/conn_manager_utility.cc index bc14af4821352..ebe58b59d1dbe 100644 --- a/source/common/http/conn_manager_utility.cc +++ b/source/common/http/conn_manager_utility.cc @@ -23,9 +23,9 @@ Network::Address::InstanceConstSharedPtr ConnectionManagerUtility::mutateRequest ConnectionManagerConfig& config, const Router::Config& route_config, Runtime::RandomGenerator& random, Runtime::Loader& runtime, const LocalInfo::LocalInfo& local_info) { - // If this is a WebSocket Upgrade request, do not remove the Connection and Upgrade headers, + // If this is a Upgrade request, do not remove the Connection and Upgrade headers, // as we forward them verbatim to the upstream hosts. - if (protocol == Protocol::Http11 && Utility::isWebSocketUpgradeRequest(request_headers)) { + if (protocol == Protocol::Http11 && Utility::isUpgrade(request_headers)) { // The current WebSocket implementation re-uses the HTTP1 codec to send upgrade headers to // the upstream host. This adds the "transfer-encoding: chunked" request header if the stream // has not ended and content-length does not exist. In HTTP1.1, if transfer-encoding and @@ -238,7 +238,7 @@ void ConnectionManagerUtility::mutateXfccRequestHeader(Http::HeaderMap& request_ return; } - // TODO(myidpt): Handle the special characters in By and SAN fields. + // TODO(myidpt): Handle the special characters in By and URI fields. // TODO: Optimize client_cert_details based on perf analysis (direct string appending may be more // preferable). std::vector client_cert_details; @@ -268,12 +268,8 @@ void ConnectionManagerUtility::mutateXfccRequestHeader(Http::HeaderMap& request_ client_cert_details.push_back("Subject=\"" + connection.ssl()->subjectPeerCertificate() + "\""); break; - case Http::ClientCertDetailsType::SAN: - // Currently, we only support a single SAN field with URI type. - // The "SAN" key still exists even if the SAN is empty. - client_cert_details.push_back("SAN=" + connection.ssl()->uriSanPeerCertificate()); - break; case Http::ClientCertDetailsType::URI: + // The "URI" key still exists even if the URI is empty. client_cert_details.push_back("URI=" + connection.ssl()->uriSanPeerCertificate()); break; case Http::ClientCertDetailsType::DNS: { @@ -296,20 +292,37 @@ void ConnectionManagerUtility::mutateXfccRequestHeader(Http::HeaderMap& request_ } else if (config.forwardClientCert() == Http::ForwardClientCertType::SanitizeSet) { request_headers.insertForwardedClientCert().value(client_cert_details_str); } else { - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } void ConnectionManagerUtility::mutateResponseHeaders(Http::HeaderMap& response_headers, - const Http::HeaderMap& request_headers, + const Http::HeaderMap* request_headers, const std::string& via) { - response_headers.removeConnection(); + if (request_headers != nullptr && Utility::isUpgrade(*request_headers) && + Utility::isUpgrade(response_headers)) { + // As in mutateRequestHeaders, Upgrade responses have special handling. + // + // Unlike mutateRequestHeaders there is no explicit protocol check. If Envoy is proxying an + // upgrade response it has already passed the protocol checks. + const bool no_body = + (!response_headers.TransferEncoding() && !response_headers.ContentLength()); + if (no_body) { + response_headers.insertContentLength().value(uint64_t(0)); + } + } else { + response_headers.removeConnection(); + } response_headers.removeTransferEncoding(); - if (request_headers.EnvoyForceTrace() && request_headers.RequestId()) { - response_headers.insertRequestId().value(*request_headers.RequestId()); + if (request_headers != nullptr && request_headers->EnvoyForceTrace() && + request_headers->RequestId()) { + response_headers.insertRequestId().value(*request_headers->RequestId()); } + response_headers.removeKeepAlive(); + response_headers.removeProxyConnection(); + if (!via.empty()) { Utility::appendVia(response_headers, via); } diff --git a/source/common/http/conn_manager_utility.h b/source/common/http/conn_manager_utility.h index 5d74cb0dfc950..c27886cda9936 100644 --- a/source/common/http/conn_manager_utility.h +++ b/source/common/http/conn_manager_utility.h @@ -34,7 +34,7 @@ class ConnectionManagerUtility { Runtime::Loader& runtime, const LocalInfo::LocalInfo& local_info); static void mutateResponseHeaders(Http::HeaderMap& response_headers, - const Http::HeaderMap& request_headers, const std::string& via); + const Http::HeaderMap* request_headers, const std::string& via); private: /** diff --git a/source/common/http/header_map_impl.cc b/source/common/http/header_map_impl.cc index 475092c913847..bb260620f57b8 100644 --- a/source/common/http/header_map_impl.cc +++ b/source/common/http/header_map_impl.cc @@ -91,7 +91,7 @@ void HeaderString::append(const char* data, uint32_t size) { const uint64_t new_capacity = (static_cast(string_length_) + size) * 2; // If the resizing will cause buffer overflow due to hitting uint32_t::max, an OOM is likely // imminent. Fast-fail rather than allow a buffer overflow attack (issue #1421) - RELEASE_ASSERT(new_capacity <= std::numeric_limits::max()); + RELEASE_ASSERT(new_capacity <= std::numeric_limits::max(), ""); buffer_.dynamic_ = static_cast(malloc(new_capacity)); memcpy(buffer_.dynamic_, inline_buffer_, string_length_); dynamic_capacity_ = new_capacity; diff --git a/source/common/http/header_utility.cc b/source/common/http/header_utility.cc index db42bfc35f9c4..3618a5a08f484 100644 --- a/source/common/http/header_utility.cc +++ b/source/common/http/header_utility.cc @@ -9,13 +9,8 @@ namespace Envoy { namespace Http { -// HeaderMatcher will consist of one of the below two options: -// 1.value (string) and regex (bool) -// An empty header value allows for matching to be only based on header presence. -// Regex is an opt-in. Unless explicitly mentioned, the header values will be used for -// exact string matching. -// This is now deprecated. -// 2.header_match_specifier which can be any one of exact_match, regex_match, range_match, +// HeaderMatcher will consist of: +// header_match_specifier which can be any one of exact_match, regex_match, range_match, // present_match, prefix_match or suffix_match. // Each of these also can be inverted with the invert_match option. // Absence of these options implies empty header value match based on header presence. @@ -56,15 +51,7 @@ HeaderUtility::HeaderData::HeaderData(const envoy::api::v2::route::HeaderMatcher case envoy::api::v2::route::HeaderMatcher::HEADER_MATCH_SPECIFIER_NOT_SET: FALLTHRU; default: - if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, regex, false)) { - header_match_type_ = HeaderMatchType::Regex; - regex_pattern_ = RegexUtil::parseRegex(config.value()); - } else if (config.value().empty()) { - header_match_type_ = HeaderMatchType::Present; - } else { - header_match_type_ = HeaderMatchType::Value; - value_ = config.value(); - } + header_match_type_ = HeaderMatchType::Present; break; } } @@ -122,7 +109,7 @@ bool HeaderUtility::matchHeaders(const Http::HeaderMap& request_headers, match = absl::EndsWith(header->value().getStringView(), header_data.value_); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } return match != header_data.invert_match_; diff --git a/source/common/http/headers.h b/source/common/http/headers.h index 4eadfb45b0eb0..09816aea98128 100644 --- a/source/common/http/headers.h +++ b/source/common/http/headers.h @@ -72,6 +72,7 @@ class HeaderValues { const LowerCaseString LastModified{"last-modified"}; const LowerCaseString Location{"location"}; const LowerCaseString Method{":method"}; + const LowerCaseString NoChunks{":no-chunks"}; const LowerCaseString Origin{"origin"}; const LowerCaseString OtSpanContext{"x-ot-span-context"}; const LowerCaseString Path{":path"}; @@ -107,12 +108,14 @@ class HeaderValues { } UpgradeValues; struct { + const std::string NoCache{"no-cache"}; const std::string NoCacheMaxAge0{"no-cache, max-age=0"}; const std::string NoTransform{"no-transform"}; } CacheControlValues; struct { const std::string Text{"text/plain"}; + const std::string TextEventStream{"text/event-stream"}; const std::string TextUtf8{"text/plain; charset=UTF-8"}; // TODO(jmarantz): fold this into Text const std::string Html{"text/html; charset=UTF-8"}; const std::string Grpc{"application/grpc"}; @@ -212,6 +215,10 @@ class HeaderValues { const std::string AcceptEncoding{"Accept-Encoding"}; const std::string Wildcard{"*"}; } VaryValues; + + struct { + const std::string All{"*"}; + } AccessControlAllowOriginValue; }; typedef ConstSingleton Headers; diff --git a/source/common/http/http1/codec_impl.cc b/source/common/http/http1/codec_impl.cc index 333e49675befe..a08878e609d1c 100644 --- a/source/common/http/http1/codec_impl.cc +++ b/source/common/http/http1/codec_impl.cc @@ -74,13 +74,23 @@ void StreamEncoderImpl::encodeHeaders(const HeaderMap& headers, bool end_stream) // Assume we are chunk encoding unless we are passed a content length or this is a header only // response. Upper layers generally should strip transfer-encoding since it only applies to // HTTP/1.1. The codec will infer it based on the type of response. - if (saw_content_length) { + // for streaming (e.g. SSE stream sent to hystrix dashboard), we do not want + // chunk transfer encoding but we don't have a content-length so we pass "envoy only" + // header to avoid adding chunks + // + // Note that for HEAD requests Envoy does best-effort guessing when there is no + // content-length. If a client makes a HEAD request for an upstream resource + // with no bytes but the upstream response doesn't include "Content-length: 0", + // Envoy will incorrectly assume a subsequent response to GET will be chunk encoded. + if (saw_content_length || headers.NoChunks()) { chunk_encoding_ = false; } else { if (processing_100_continue_) { // Make sure we don't serialize chunk information with 100-Continue headers. chunk_encoding_ = false; - } else if (end_stream) { + } else if (end_stream && !is_response_to_head_request_) { + // If this is a headers-only stream, append an explicit "Content-Length: 0" unless it's a + // response to a HEAD request. encodeHeader(Headers::get().ContentLength.get().c_str(), Headers::get().ContentLength.get().size(), "0", 1); chunk_encoding_ = false; @@ -91,7 +101,14 @@ void StreamEncoderImpl::encodeHeaders(const HeaderMap& headers, bool end_stream) Headers::get().TransferEncoding.get().size(), Headers::get().TransferEncodingValues.Chunked.c_str(), Headers::get().TransferEncodingValues.Chunked.size()); - chunk_encoding_ = true; + // We do not aply chunk encoding for HTTP upgrades. + // If there is a body in a WebSocket Upgrade response, the chunks will be + // passed through via maybeDirectDispatch so we need to avoid appending + // extra chunk boundaries. + // + // When sending a response to a HEAD request Envoy may send an informational + // "Transfer-Encoding: chunked" header, but should not send a chunk encoded body. + chunk_encoding_ = !Utility::isUpgrade(headers) && !is_response_to_head_request_; } } @@ -267,7 +284,7 @@ http_parser_settings ConnectionImpl::settings_{ return 0; }, [](http_parser* parser) -> int { - static_cast(parser->data)->onMessageComplete(); + static_cast(parser->data)->onMessageCompleteBase(); return 0; }, nullptr, // on_chunk_header @@ -301,9 +318,32 @@ void ConnectionImpl::completeLastHeader() { ASSERT(current_header_value_.empty()); } +bool ConnectionImpl::maybeDirectDispatch(Buffer::Instance& data) { + if (!handling_upgrade_) { + // Only direct dispatch for Upgrade requests. + return false; + } + + ssize_t total_parsed = 0; + uint64_t num_slices = data.getRawSlices(nullptr, 0); + Buffer::RawSlice slices[num_slices]; + data.getRawSlices(slices, num_slices); + for (Buffer::RawSlice& slice : slices) { + total_parsed += slice.len_; + onBody(static_cast(slice.mem_), slice.len_); + } + ENVOY_CONN_LOG(trace, "direct-dispatched {} bytes", connection_, total_parsed); + data.drain(total_parsed); + return true; +} + void ConnectionImpl::dispatch(Buffer::Instance& data) { ENVOY_CONN_LOG(trace, "parsing {} bytes", connection_, data.length()); + if (maybeDirectDispatch(data)) { + return; + } + // Always unpause before dispatch. http_parser_pause(&parser_, 0); @@ -321,6 +361,10 @@ void ConnectionImpl::dispatch(Buffer::Instance& data) { ENVOY_CONN_LOG(trace, "parsed {} bytes", connection_, total_parsed); data.drain(total_parsed); + + // If an upgrade has been handled and there is body data or early upgrade + // payload to send on, send it on. + maybeDirectDispatch(data); } size_t ConnectionImpl::dispatchSlice(const char* slice, size_t len) { @@ -365,14 +409,34 @@ int ConnectionImpl::onHeadersCompleteBase() { // HTTP/1.1 or not. protocol_ = Protocol::Http10; } + if (Utility::isUpgrade(*current_header_map_)) { + ENVOY_CONN_LOG(trace, "codec entering upgrade mode.", connection_); + handling_upgrade_ = true; + } int rc = onHeadersComplete(std::move(current_header_map_)); current_header_map_.reset(); header_parsing_state_ = HeaderParsingState::Done; - return rc; + + // Returning 2 informs http_parser to not expect a body or further data on this connection. + return handling_upgrade_ ? 2 : rc; +} + +void ConnectionImpl::onMessageCompleteBase() { + ENVOY_CONN_LOG(trace, "message complete", connection_); + if (handling_upgrade_) { + // If this is an upgrade request, swallow the onMessageComplete. The + // upgrade payload will be treated as stream body. + ASSERT(!deferred_end_stream_headers_); + ENVOY_CONN_LOG(trace, "Pausing parser due to upgrade.", connection_); + http_parser_pause(&parser_, 1); + return; + } + onMessageComplete(); } void ConnectionImpl::onMessageBeginBase() { + ENVOY_CONN_LOG(trace, "message begin", connection_); ASSERT(!current_header_map_); current_header_map_.reset(new HeaderMapImpl()); header_parsing_state_ = HeaderParsingState::Field; @@ -484,6 +548,10 @@ int ServerConnectionImpl::onHeadersComplete(HeaderMapImplPtr&& headers) { if (active_request_) { const char* method_string = http_method_str(static_cast(parser_.method)); + // Inform the response encoder about any HEAD method, so it can set content + // length and transfer encoding headers correctly. + active_request_->response_encoder_.isResponseToHeadRequest(parser_.method == HTTP_HEAD); + // Currently, CONNECT is not supported, however; http_parser_parse_url needs to know about // CONNECT handlePath(*headers, parser_.method); @@ -498,7 +566,7 @@ int ServerConnectionImpl::onHeadersComplete(HeaderMapImplPtr&& headers) { // scenario where the higher layers stream through and implicitly switch to chunked transfer // encoding because end stream with zero body length has not yet been indicated. if (parser_.flags & F_CHUNKED || - (parser_.content_length > 0 && parser_.content_length != ULLONG_MAX)) { + (parser_.content_length > 0 && parser_.content_length != ULLONG_MAX) || handling_upgrade_) { active_request_->request_decoder_->decodeHeaders(std::move(headers), false); // If the connection has been closed (or is closing) after decoding headers, pause the parser @@ -540,7 +608,6 @@ void ServerConnectionImpl::onBody(const char* data, size_t length) { void ServerConnectionImpl::onMessageComplete() { if (active_request_) { - ENVOY_CONN_LOG(trace, "message complete", connection_); Buffer::OwnedImpl buffer; active_request_->remote_complete_ = true; diff --git a/source/common/http/http1/codec_impl.h b/source/common/http/http1/codec_impl.h index 6002059aad25c..63a8d5540db3c 100644 --- a/source/common/http/http1/codec_impl.h +++ b/source/common/http/http1/codec_impl.h @@ -46,6 +46,8 @@ class StreamEncoderImpl : public StreamEncoder, void readDisable(bool disable) override; uint32_t bufferLimit() override; + void isResponseToHeadRequest(bool value) { is_response_to_head_request_ = value; } + protected: StreamEncoderImpl(ConnectionImpl& connection) : connection_(connection) {} @@ -71,6 +73,7 @@ class StreamEncoderImpl : public StreamEncoder, bool chunk_encoding_{true}; bool processing_100_continue_{false}; + bool is_response_to_head_request_{false}; }; /** @@ -152,6 +155,8 @@ class ConnectionImpl : public virtual Connection, protected Logger::Loggableheaders_.reset(); @@ -723,7 +722,7 @@ ConnectionImpl::Http2Callbacks::Http2Callbacks() { ConnectionImpl::Http2Callbacks::~Http2Callbacks() { nghttp2_session_callbacks_del(callbacks_); } -ConnectionImpl::Http2Options::Http2Options() { +ConnectionImpl::Http2Options::Http2Options(const Http2Settings& http2_settings) { nghttp2_option_new(&options_); // Currently we do not do anything with stream priority. Setting the following option prevents // nghttp2 from keeping around closed streams for use during stream priority dependency graph @@ -731,16 +730,31 @@ ConnectionImpl::Http2Options::Http2Options() { // of kept alive HTTP/2 connections. nghttp2_option_set_no_closed_streams(options_, 1); nghttp2_option_set_no_auto_window_update(options_, 1); + + if (http2_settings.hpack_table_size_ != NGHTTP2_DEFAULT_HEADER_TABLE_SIZE) { + nghttp2_option_set_max_deflate_dynamic_table_size(options_, http2_settings.hpack_table_size_); + } } ConnectionImpl::Http2Options::~Http2Options() { nghttp2_option_del(options_); } +ConnectionImpl::ClientHttp2Options::ClientHttp2Options(const Http2Settings& http2_settings) + : Http2Options(http2_settings) { + // Temporarily disable initial max streams limit/protection, since we might want to create + // more than 100 streams before receiving the HTTP/2 SETTINGS frame from the server. + // + // TODO(PiotrSikora): remove this once multiple upstream connections or queuing are implemented. + nghttp2_option_set_peer_max_concurrent_streams(options_, + Http2Settings::DEFAULT_MAX_CONCURRENT_STREAMS); +} + ClientConnectionImpl::ClientConnectionImpl(Network::Connection& connection, Http::ConnectionCallbacks& callbacks, Stats::Scope& stats, const Http2Settings& http2_settings) : ConnectionImpl(connection, stats, http2_settings), callbacks_(callbacks) { + ClientHttp2Options client_http2_options(http2_settings); nghttp2_session_client_new2(&session_, http2_callbacks_.callbacks(), base(), - http2_options_.options()); + client_http2_options.options()); sendSettings(http2_settings, true); } @@ -759,9 +773,10 @@ Http::StreamEncoder& ClientConnectionImpl::newStream(StreamDecoder& decoder) { int ClientConnectionImpl::onBeginHeaders(const nghttp2_frame* frame) { // The client code explicitly does not currently suport push promise. - RELEASE_ASSERT(frame->hd.type == NGHTTP2_HEADERS); + RELEASE_ASSERT(frame->hd.type == NGHTTP2_HEADERS, ""); RELEASE_ASSERT(frame->headers.cat == NGHTTP2_HCAT_RESPONSE || - frame->headers.cat == NGHTTP2_HCAT_HEADERS); + frame->headers.cat == NGHTTP2_HCAT_HEADERS, + ""); if (frame->headers.cat == NGHTTP2_HCAT_HEADERS) { StreamImpl* stream = getStream(frame->hd.stream_id); ASSERT(!stream->headers_); @@ -783,8 +798,9 @@ ServerConnectionImpl::ServerConnectionImpl(Network::Connection& connection, Http::ServerConnectionCallbacks& callbacks, Stats::Scope& scope, const Http2Settings& http2_settings) : ConnectionImpl(connection, scope, http2_settings), callbacks_(callbacks) { + Http2Options http2_options(http2_settings); nghttp2_session_server_new2(&session_, http2_callbacks_.callbacks(), base(), - http2_options_.options()); + http2_options.options()); sendSettings(http2_settings, false); } diff --git a/source/common/http/http2/codec_impl.h b/source/common/http/http2/codec_impl.h index 43cbba9efbbcf..6b4e0016a72c0 100644 --- a/source/common/http/http2/codec_impl.h +++ b/source/common/http/http2/codec_impl.h @@ -118,15 +118,20 @@ class ConnectionImpl : public virtual Connection, protected Logger::Loggable active_streams_; nghttp2_session* session_{}; diff --git a/source/common/http/utility.cc b/source/common/http/utility.cc index afd0510dbf77c..e3ff89ad85949 100644 --- a/source/common/http/utility.cc +++ b/source/common/http/utility.cc @@ -201,15 +201,18 @@ uint64_t Utility::getResponseStatus(const HeaderMap& headers) { return response_code; } -bool Utility::isWebSocketUpgradeRequest(const HeaderMap& headers) { +bool Utility::isUpgrade(const HeaderMap& headers) { // In firefox the "Connection" request header value is "keep-alive, Upgrade", // we should check if it contains the "Upgrade" token. return (headers.Connection() && headers.Upgrade() && - headers.Connection()->value().caseInsensitiveContains( - Http::Headers::get().ConnectionValues.Upgrade.c_str()) && - (0 == StringUtil::caseInsensitiveCompare( - headers.Upgrade()->value().c_str(), - Http::Headers::get().UpgradeValues.WebSocket.c_str()))); + Envoy::StringUtil::caseFindToken(headers.Connection()->value().getStringView(), ",", + Http::Headers::get().ConnectionValues.Upgrade.c_str())); +} + +bool Utility::isWebSocketUpgradeRequest(const HeaderMap& headers) { + return (isUpgrade(headers) && (0 == StringUtil::caseInsensitiveCompare( + headers.Upgrade()->value().c_str(), + Http::Headers::get().UpgradeValues.WebSocket.c_str()))); } Http2Settings @@ -335,7 +338,7 @@ const std::string& Utility::getProtocolString(const Protocol protocol) { return Headers::get().ProtocolStrings.Http2String; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void Utility::extractHostPathFromUri(const absl::string_view& uri, absl::string_view& host, @@ -345,7 +348,7 @@ void Utility::extractHostPathFromUri(const absl::string_view& uri, absl::string_ * * Example: * uri = "https://example.com:8443/certs" - * pos: ^ + * pos: ^ * host_pos: ^ * path_pos: ^ * host = "example.com:8443" diff --git a/source/common/http/utility.h b/source/common/http/utility.h index 2e00748d6b698..8a11ae5129f2b 100644 --- a/source/common/http/utility.h +++ b/source/common/http/utility.h @@ -93,6 +93,14 @@ std::string makeSetCookieValue(const std::string& key, const std::string& value, */ uint64_t getResponseStatus(const HeaderMap& headers); +/** + * Determine whether these headers are a valid Upgrade request or response. + * This function returns true if the following HTTP headers and values are present: + * - Connection: Upgrade + * - Upgrade: [any value] + */ +bool isUpgrade(const HeaderMap& headers); + /** * Determine whether this is a WebSocket Upgrade request. * This function returns true if the following HTTP headers and values are present: diff --git a/source/common/http/websocket/ws_handler_impl.cc b/source/common/http/websocket/ws_handler_impl.cc index 541f085155dfe..5906c54b5d92c 100644 --- a/source/common/http/websocket/ws_handler_impl.cc +++ b/source/common/http/websocket/ws_handler_impl.cc @@ -131,7 +131,8 @@ void WsHandlerImpl::onConnectionSuccess() { // the connection pool. The current approach is a stop gap solution, where // we put the onus on the user to tell us if a route (and corresponding upstream) // is supposed to allow websocket upgrades or not. - Http1::ClientConnectionImpl upstream_http(*upstream_connection_, http_conn_callbacks_); + Http1::ClientConnectionImpl upstream_http(upstream_conn_data_->connection(), + http_conn_callbacks_); Http1::RequestStreamEncoderImpl upstream_request = Http1::RequestStreamEncoderImpl(upstream_http); upstream_request.encodeHeaders(request_headers_, false); ASSERT(state_ == ConnectState::PreConnect); diff --git a/source/common/json/json_loader.cc b/source/common/json/json_loader.cc index 8a4a79d9cafb1..9a6b24dead315 100644 --- a/source/common/json/json_loader.cc +++ b/source/common/json/json_loader.cc @@ -118,7 +118,7 @@ class Field : public Object { return "String"; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } struct Value { @@ -312,7 +312,7 @@ void Field::buildRapidJsonDocument(const Field& field, rapidjson::Value& value, break; } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -557,7 +557,7 @@ bool ObjectHandler::StartObject() { state_ = expectKeyOrEndObject; return true; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -576,7 +576,7 @@ bool ObjectHandler::EndObject(rapidjson::SizeType) { } return true; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -587,7 +587,7 @@ bool ObjectHandler::Key(const char* value, rapidjson::SizeType size, bool) { state_ = expectValueOrStartObjectArray; return true; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -611,7 +611,7 @@ bool ObjectHandler::StartArray() { state_ = expectArrayValueOrEndArray; return true; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -631,7 +631,7 @@ bool ObjectHandler::EndArray(rapidjson::SizeType) { return true; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -661,7 +661,7 @@ bool ObjectHandler::String(const char* value, rapidjson::SizeType size, bool) { bool ObjectHandler::RawNumber(const char*, rapidjson::SizeType, bool) { // Only called if kParseNumbersAsStrings is set as a parse flag, which it is not. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } bool ObjectHandler::handleValueEvent(FieldSharedPtr ptr) { @@ -737,7 +737,7 @@ FieldSharedPtr parseYamlNode(YAML::Node node) { case YAML::NodeType::Undefined: throw EnvoyException("Undefined YAML value"); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } // namespace diff --git a/source/common/memory/stats.cc b/source/common/memory/stats.cc index 38fc68fe6dee6..f73930bc81cef 100644 --- a/source/common/memory/stats.cc +++ b/source/common/memory/stats.cc @@ -21,6 +21,12 @@ uint64_t Stats::totalCurrentlyReserved() { return value; } +uint64_t Stats::totalPageHeapUnmapped() { + size_t value = 0; + MallocExtension::instance()->GetNumericProperty("tcmalloc.pageheap_unmapped_bytes", &value); + return value; +} + } // namespace Memory } // namespace Envoy @@ -31,6 +37,7 @@ namespace Memory { uint64_t Stats::totalCurrentlyAllocated() { return 0; } uint64_t Stats::totalCurrentlyReserved() { return 0; } +uint64_t Stats::totalPageHeapUnmapped() { return 0; } } // namespace Memory } // namespace Envoy diff --git a/source/common/memory/stats.h b/source/common/memory/stats.h index 7dba0850e337a..ccfe9785bed4f 100644 --- a/source/common/memory/stats.h +++ b/source/common/memory/stats.h @@ -20,6 +20,11 @@ class Stats { * allocated. */ static uint64_t totalCurrentlyReserved(); + + /** + * @return uint64_t the number of bytes in free, unmapped pages in the page heap. + */ + static uint64_t totalPageHeapUnmapped(); }; } // namespace Memory diff --git a/source/common/network/BUILD b/source/common/network/BUILD index 1ae34e6438b14..94f5dd4c18ef3 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -15,6 +15,7 @@ envoy_cc_library( hdrs = ["address_impl.h"], deps = [ "//include/envoy/network:address_interface", + "//source/common/api:os_sys_calls_lib", "//source/common/common:assert_lib", "//source/common/common:utility_lib", ], @@ -96,7 +97,6 @@ envoy_cc_library( envoy_cc_library( name = "lc_trie_lib", - srcs = ["lc_trie.cc"], hdrs = ["lc_trie.h"], external_deps = ["abseil_int128"], deps = [ @@ -182,6 +182,7 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ ":address_lib", + "//include/envoy/api:os_sys_calls_interface", "//include/envoy/network:listen_socket_interface", "//source/common/api:os_sys_calls_lib", "//source/common/common:assert_lib", @@ -227,6 +228,7 @@ envoy_cc_library( ":address_lib", "//include/envoy/network:connection_interface", "//include/envoy/stats:stats_interface", + "//source/common/api:os_sys_calls_lib", "//source/common/common:assert_lib", "//source/common/common:utility_lib", "//source/common/protobuf", diff --git a/source/common/network/address_impl.cc b/source/common/network/address_impl.cc index 99a2db5bd576b..6f1c5f18c75d0 100644 --- a/source/common/network/address_impl.cc +++ b/source/common/network/address_impl.cc @@ -11,6 +11,7 @@ #include "envoy/common/exception.h" +#include "common/api/os_sys_calls_impl.h" #include "common/common/assert.h" #include "common/common/fmt.h" #include "common/common/utility.h" @@ -21,19 +22,10 @@ namespace Address { namespace { -// Check if an IP family is supported on this machine. -bool ipFamilySupported(int domain) { - const int fd = ::socket(domain, SOCK_STREAM, 0); - if (fd >= 0) { - RELEASE_ASSERT(::close(fd) == 0); - } - return fd != -1; -} - // Validate that IPv4 is supported on this platform, raise an exception for the // given address if not. void validateIpv4Supported(const std::string& address) { - static const bool supported = ipFamilySupported(AF_INET); + static const bool supported = Network::Address::ipFamilySupported(AF_INET); if (!supported) { throw EnvoyException( fmt::format("IPv4 addresses are not supported on this machine: {}", address)); @@ -43,7 +35,7 @@ void validateIpv4Supported(const std::string& address) { // Validate that IPv6 is supported on this platform, raise an exception for the // given address if not. void validateIpv6Supported(const std::string& address) { - static const bool supported = ipFamilySupported(AF_INET6); + static const bool supported = Network::Address::ipFamilySupported(AF_INET6); if (!supported) { throw EnvoyException( fmt::format("IPv6 addresses are not supported on this machine: {}", address)); @@ -52,18 +44,28 @@ void validateIpv6Supported(const std::string& address) { } // namespace +// Check if an IP family is supported on this machine. +bool ipFamilySupported(int domain) { + Api::OsSysCalls& os_sys_calls = Api::OsSysCallsSingleton::get(); + const int fd = os_sys_calls.socket(domain, SOCK_STREAM, 0); + if (fd >= 0) { + RELEASE_ASSERT(os_sys_calls.close(fd) == 0, ""); + } + return fd != -1; +} + Address::InstanceConstSharedPtr addressFromSockAddr(const sockaddr_storage& ss, socklen_t ss_len, bool v6only) { - RELEASE_ASSERT(ss_len == 0 || ss_len >= sizeof(sa_family_t)); + RELEASE_ASSERT(ss_len == 0 || ss_len >= sizeof(sa_family_t), ""); switch (ss.ss_family) { case AF_INET: { - RELEASE_ASSERT(ss_len == 0 || ss_len == sizeof(sockaddr_in)); + RELEASE_ASSERT(ss_len == 0 || ss_len == sizeof(sockaddr_in), ""); const struct sockaddr_in* sin = reinterpret_cast(&ss); ASSERT(AF_INET == sin->sin_family); return std::make_shared(sin); } case AF_INET6: { - RELEASE_ASSERT(ss_len == 0 || ss_len == sizeof(sockaddr_in6)); + RELEASE_ASSERT(ss_len == 0 || ss_len == sizeof(sockaddr_in6), ""); const struct sockaddr_in6* sin6 = reinterpret_cast(&ss); ASSERT(AF_INET6 == sin6->sin6_family); if (!v6only && IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr)) { @@ -84,13 +86,13 @@ Address::InstanceConstSharedPtr addressFromSockAddr(const sockaddr_storage& ss, case AF_UNIX: { const struct sockaddr_un* sun = reinterpret_cast(&ss); ASSERT(AF_UNIX == sun->sun_family); - RELEASE_ASSERT(ss_len == 0 || ss_len >= offsetof(struct sockaddr_un, sun_path) + 1); + RELEASE_ASSERT(ss_len == 0 || ss_len >= offsetof(struct sockaddr_un, sun_path) + 1, ""); return std::make_shared(sun, ss_len); } default: throw EnvoyException(fmt::format("Unexpected sockaddr family: {}", ss.ss_family)); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } InstanceConstSharedPtr addressFromFd(int fd) { @@ -104,7 +106,7 @@ InstanceConstSharedPtr addressFromFd(int fd) { int socket_v6only = 0; if (ss.ss_family == AF_INET6) { socklen_t size_int = sizeof(socket_v6only); - RELEASE_ASSERT(::getsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &socket_v6only, &size_int) == 0); + RELEASE_ASSERT(::getsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &socket_v6only, &size_int) == 0, ""); } return addressFromSockAddr(ss, ss_len, rc == 0 && socket_v6only); } @@ -161,11 +163,11 @@ int InstanceBase::socketFromSocketType(SocketType socketType) const { } int fd = ::socket(domain, flags, 0); - RELEASE_ASSERT(fd != -1); + RELEASE_ASSERT(fd != -1, ""); #ifdef __APPLE__ // Cannot set SOCK_NONBLOCK as a ::socket flag. - RELEASE_ASSERT(fcntl(fd, F_SETFL, O_NONBLOCK) != -1); + RELEASE_ASSERT(fcntl(fd, F_SETFL, O_NONBLOCK) != -1, ""); #endif return fd; @@ -212,14 +214,16 @@ bool Ipv4Instance::operator==(const Instance& rhs) const { (ip_.port() == rhs_casted->ip_.port())); } -int Ipv4Instance::bind(int fd) const { - return ::bind(fd, reinterpret_cast(&ip_.ipv4_.address_), - sizeof(ip_.ipv4_.address_)); +Api::SysCallResult Ipv4Instance::bind(int fd) const { + const int rc = ::bind(fd, reinterpret_cast(&ip_.ipv4_.address_), + sizeof(ip_.ipv4_.address_)); + return {rc, errno}; } -int Ipv4Instance::connect(int fd) const { - return ::connect(fd, reinterpret_cast(&ip_.ipv4_.address_), - sizeof(ip_.ipv4_.address_)); +Api::SysCallResult Ipv4Instance::connect(int fd) const { + const int rc = ::connect(fd, reinterpret_cast(&ip_.ipv4_.address_), + sizeof(ip_.ipv4_.address_)); + return {rc, errno}; } int Ipv4Instance::socket(SocketType type) const { return socketFromSocketType(type); } @@ -275,22 +279,23 @@ bool Ipv6Instance::operator==(const Instance& rhs) const { (ip_.port() == rhs_casted->ip_.port())); } -int Ipv6Instance::bind(int fd) const { - return ::bind(fd, reinterpret_cast(&ip_.ipv6_.address_), - sizeof(ip_.ipv6_.address_)); +Api::SysCallResult Ipv6Instance::bind(int fd) const { + const int rc = ::bind(fd, reinterpret_cast(&ip_.ipv6_.address_), + sizeof(ip_.ipv6_.address_)); + return {rc, errno}; } -int Ipv6Instance::connect(int fd) const { - return ::connect(fd, reinterpret_cast(&ip_.ipv6_.address_), - sizeof(ip_.ipv6_.address_)); +Api::SysCallResult Ipv6Instance::connect(int fd) const { + const int rc = ::connect(fd, reinterpret_cast(&ip_.ipv6_.address_), + sizeof(ip_.ipv6_.address_)); + return {rc, errno}; } int Ipv6Instance::socket(SocketType type) const { const int fd = socketFromSocketType(type); - // Setting IPV6_V6ONLY resticts the IPv6 socket to IPv6 connections only. const int v6only = ip_.v6only_; - RELEASE_ASSERT(::setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only)) != -1); + RELEASE_ASSERT(::setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only)) != -1, ""); return fd; } @@ -300,7 +305,7 @@ PipeInstance::PipeInstance(const sockaddr_un* address, socklen_t ss_len) #if !defined(__linux__) throw EnvoyException("Abstract AF_UNIX sockets are only supported on linux."); #endif - RELEASE_ASSERT(ss_len >= offsetof(struct sockaddr_un, sun_path) + 1); + RELEASE_ASSERT(ss_len >= offsetof(struct sockaddr_un, sun_path) + 1, ""); abstract_namespace_ = true; address_length_ = ss_len - offsetof(struct sockaddr_un, sun_path); } @@ -333,24 +338,28 @@ PipeInstance::PipeInstance(const std::string& pipe_path) : InstanceBase(Type::Pi bool PipeInstance::operator==(const Instance& rhs) const { return asString() == rhs.asString(); } -int PipeInstance::bind(int fd) const { +Api::SysCallResult PipeInstance::bind(int fd) const { if (abstract_namespace_) { - return ::bind(fd, reinterpret_cast(&address_), - offsetof(struct sockaddr_un, sun_path) + address_length_); + const int rc = ::bind(fd, reinterpret_cast(&address_), + offsetof(struct sockaddr_un, sun_path) + address_length_); + return {rc, errno}; } // Try to unlink an existing filesystem object at the requested path. Ignore // errors -- it's fine if the path doesn't exist, and if it exists but can't // be unlinked then `::bind()` will generate a reasonable errno. unlink(address_.sun_path); - return ::bind(fd, reinterpret_cast(&address_), sizeof(address_)); + const int rc = ::bind(fd, reinterpret_cast(&address_), sizeof(address_)); + return {rc, errno}; } -int PipeInstance::connect(int fd) const { +Api::SysCallResult PipeInstance::connect(int fd) const { if (abstract_namespace_) { - return ::connect(fd, reinterpret_cast(&address_), - offsetof(struct sockaddr_un, sun_path) + address_length_); + const int rc = ::connect(fd, reinterpret_cast(&address_), + offsetof(struct sockaddr_un, sun_path) + address_length_); + return {rc, errno}; } - return ::connect(fd, reinterpret_cast(&address_), sizeof(address_)); + const int rc = ::connect(fd, reinterpret_cast(&address_), sizeof(address_)); + return {rc, errno}; } int PipeInstance::socket(SocketType type) const { return socketFromSocketType(type); } diff --git a/source/common/network/address_impl.h b/source/common/network/address_impl.h index c48003624c137..1e691ddee48ad 100644 --- a/source/common/network/address_impl.h +++ b/source/common/network/address_impl.h @@ -15,6 +15,12 @@ namespace Envoy { namespace Network { namespace Address { +/** + * Returns true if the given family is supported on this machine. + * @param domain the IP family. + */ +bool ipFamilySupported(int domain); + /** * Convert an address in the form of the socket address struct defined by Posix, Linux, etc. into * a Network::Address::Instance and return a pointer to it. Raises an EnvoyException on failure. @@ -91,8 +97,8 @@ class Ipv4Instance : public InstanceBase { // Network::Address::Instance bool operator==(const Instance& rhs) const override; - int bind(int fd) const override; - int connect(int fd) const override; + Api::SysCallResult bind(int fd) const override; + Api::SysCallResult connect(int fd) const override; const Ip* ip() const override { return &ip_; } int socket(SocketType type) const override; @@ -151,8 +157,8 @@ class Ipv6Instance : public InstanceBase { // Network::Address::Instance bool operator==(const Instance& rhs) const override; - int bind(int fd) const override; - int connect(int fd) const override; + Api::SysCallResult bind(int fd) const override; + Api::SysCallResult connect(int fd) const override; const Ip* ip() const override { return &ip_; } int socket(SocketType type) const override; @@ -208,8 +214,8 @@ class PipeInstance : public InstanceBase { // Network::Address::Instance bool operator==(const Instance& rhs) const override; - int bind(int fd) const override; - int connect(int fd) const override; + Api::SysCallResult bind(int fd) const override; + Api::SysCallResult connect(int fd) const override; const Ip* ip() const override { return nullptr; } int socket(SocketType type) const override; diff --git a/source/common/network/cidr_range.cc b/source/common/network/cidr_range.cc index 05bc7225275e4..0ff7ce010c890 100644 --- a/source/common/network/cidr_range.cc +++ b/source/common/network/cidr_range.cc @@ -194,7 +194,7 @@ InstanceConstSharedPtr CidrRange::truncateIpAddressAndLength(InstanceConstShared return std::make_shared(sa6); } } - NOT_REACHED + NOT_REACHED_GCOVR_EXCL_LINE; } IpList::IpList(const std::vector& subnets) { diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index a5e5be1699c66..073cead4dd664 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -52,7 +52,7 @@ ConnectionImpl::ConnectionImpl(Event::Dispatcher& dispatcher, ConnectionSocketPt dispatcher_(dispatcher), id_(next_global_id_++) { // Treat the lack of a valid fd (which in practice only happens if we run out of FDs) as an OOM // condition and just crash. - RELEASE_ASSERT(fd() != -1); + RELEASE_ASSERT(fd() != -1, ""); if (!connected) { connecting_ = true; @@ -162,7 +162,7 @@ void ConnectionImpl::noDelay(bool enable) { sockaddr addr; socklen_t len = sizeof(addr); int rc = getsockname(fd(), &addr, &len); - RELEASE_ASSERT(rc == 0); + RELEASE_ASSERT(rc == 0, ""); if (addr.sa_family == AF_UNIX) { return; @@ -179,7 +179,7 @@ void ConnectionImpl::noDelay(bool enable) { } #endif - RELEASE_ASSERT(0 == rc); + RELEASE_ASSERT(0 == rc, ""); } uint64_t ConnectionImpl::id() const { return id_; } @@ -551,10 +551,10 @@ ClientConnectionImpl::ClientConnectionImpl( } if (source_address != nullptr) { - const int rc = source_address->bind(fd()); - if (rc < 0) { + const Api::SysCallResult result = source_address->bind(fd()); + if (result.rc_ < 0) { ENVOY_LOG_MISC(debug, "Bind failure. Failed to bind to {}: {}", source_address->asString(), - strerror(errno)); + strerror(result.errno_)); bind_error_ = true; // Set a special error state to ensure asynchronous close to give the owner of the // ConnectionImpl a chance to add callbacks and detect the "disconnect". @@ -568,19 +568,19 @@ ClientConnectionImpl::ClientConnectionImpl( void ClientConnectionImpl::connect() { ENVOY_CONN_LOG(debug, "connecting to {}", *this, socket_->remoteAddress()->asString()); - const int rc = socket_->remoteAddress()->connect(fd()); - if (rc == 0) { + const Api::SysCallResult result = socket_->remoteAddress()->connect(fd()); + if (result.rc_ == 0) { // write will become ready. ASSERT(connecting_); } else { - ASSERT(rc == -1); - if (errno == EINPROGRESS) { + ASSERT(result.rc_ == -1); + if (result.errno_ == EINPROGRESS) { ASSERT(connecting_); ENVOY_CONN_LOG(debug, "connection in progress", *this); } else { immediate_error_event_ = ConnectionEvent::RemoteClose; connecting_ = false; - ENVOY_CONN_LOG(debug, "immediate connection error: {}", *this, errno); + ENVOY_CONN_LOG(debug, "immediate connection error: {}", *this, result.errno_); // Trigger a write event. This is needed on OSX and seems harmless on Linux. file_event_->activate(Event::FileReadyType::Write); diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index f0f32516a6bd7..a9e796e3b02f6 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -89,6 +89,7 @@ class ConnectionImpl : public virtual Connection, const ConnectionSocket::OptionsSharedPtr& socketOptions() const override { return socket_->options(); } + absl::string_view requestedServerName() const override { return socket_->requestedServerName(); } // Network::BufferSource BufferSource::StreamBuffer getReadBuffer() override { return {read_buffer_, read_end_stream_}; } diff --git a/source/common/network/dns_impl.cc b/source/common/network/dns_impl.cc index 6bf2c88a6f9fb..33b721e59155e 100644 --- a/source/common/network/dns_impl.cc +++ b/source/common/network/dns_impl.cc @@ -56,7 +56,7 @@ DnsResolverImpl::DnsResolverImpl( } const std::string resolvers_csv = StringUtil::join(resolver_addrs, ","); int result = ares_set_servers_ports_csv(channel_, resolvers_csv.c_str()); - RELEASE_ASSERT(result == ARES_SUCCESS); + RELEASE_ASSERT(result == ARES_SUCCESS, ""); } } diff --git a/source/common/network/lc_trie.cc b/source/common/network/lc_trie.cc deleted file mode 100644 index 2323065cf199b..0000000000000 --- a/source/common/network/lc_trie.cc +++ /dev/null @@ -1,138 +0,0 @@ -#include "common/network/lc_trie.h" - -namespace Envoy { -namespace Network { -namespace LcTrie { - -LcTrie::LcTrie(const std::vector>>& tag_data, - double fill_factor, uint32_t root_branching_factor) { - - // The LcTrie implementation uses 20-bit "pointers" in its compact internal representation, - // so it cannot hold more than 2^20 nodes. But the number of nodes can be greater than the - // number of supported prefixes. Given N prefixes in the tag_data input list, step 2 below - // can produce a new list of up to 2*N prefixes to insert in the LC trie. And the LC trie - // can use up to 2*N/fill_factor nodes. - size_t num_prefixes = 0; - for (const auto& tag : tag_data) { - num_prefixes += tag.second.size(); - } - const size_t max_prefixes = MaxLcTrieNodes * fill_factor / 2; - if (num_prefixes > max_prefixes) { - throw EnvoyException(fmt::format("The input vector has '{0}' CIDR range entries. LC-Trie " - "can only support '{1}' CIDR ranges with the specified " - "fill factor.", - num_prefixes, max_prefixes)); - } - - // Step 1: separate the provided prefixes by protocol (IPv4 vs IPv6), - // and build a Binary Trie per protocol. - // - // For example, if the input prefixes are - // A: 0.0.0.0/0 - // B: 128.0.0.0/2 (10000000.0.0.0/2 in binary) - // C: 192.0.0.0/2 (11000000.0.0.0/2) - // the Binary Trie for IPv4 will look like this at the end of step 1: - // +---+ - // | A | - // +---+ - // \ 1 - // +---+ - // | | - // +---+ - // 0/ \1 - // +---+ +---+ - // | B | | C | - // +---+ +---+ - // - // Note that the prefixes in this example are nested: any IPv4 address - // that matches B or C will also match A. Unfortunately, the classic LC Trie - // algorithm does not support nested prefixes. The next step will solve that - // problem. - - BinaryTrie ipv4_temp; - BinaryTrie ipv6_temp; - for (const auto& pair_data : tag_data) { - for (const auto& cidr_range : pair_data.second) { - if (cidr_range.ip()->version() == Address::IpVersion::v4) { - IpPrefix ip_prefix(ntohl(cidr_range.ip()->ipv4()->address()), cidr_range.length(), - pair_data.first); - ipv4_temp.insert(ip_prefix); - } else { - IpPrefix ip_prefix(Utility::Ip6ntohl(cidr_range.ip()->ipv6()->address()), - cidr_range.length(), pair_data.first); - ipv6_temp.insert(ip_prefix); - } - } - } - - // Step 2: push each Binary Trie's prefixes to its leaves. - // - // Continuing the previous example, the Binary Trie will look like this - // at the end of step 2: - // +---+ - // | | - // +---+ - // 0/ \ 1 - // +---+ +---+ - // | A | | | - // +---+ +---+ - // 0/ \1 - // +---+ +---+ - // |A,B| |A,C| - // +---+ +---+ - // - // This trie yields the same match results as the original trie from - // step 1. But it has a useful new property: now that all the prefixes - // are at the leaves, they are disjoint: no prefix is nested under another. - - std::vector> ipv4_prefixes = ipv4_temp.push_leaves(); - std::vector> ipv6_prefixes = ipv6_temp.push_leaves(); - - // Step 3: take the disjoint prefixes from the leaves of each Binary Trie - // and use them to construct an LC Trie. - // - // Example inputs (from the leaves of the Binary Trie at the end of step 2) - // A: 0.0.0.0/1 - // A,B: 128.0.0.0/2 - // A,C: 192.0.0.0/2 - // - // The LC Trie generated from these inputs with fill_factor=0.5 and root_branching_factor=0 - // will be: - // - // +---------------------------+ - // | branch_factor=2, skip = 0 | - // +---------------------------+ - // 00/ 01| |10 \11 - // +---+ +---+ +---+ +---+ - // | A | | A | |A,B| |A,C| - // +---+ +---+ +---+ +---+ - // - // Or, in the internal vector form that the LcTrie class uses for memory-efficiency, - // # | branch | skip | first_child | tags | note - // ---+--------+------+-------------+------+-------------------------------------------------- - // 0 | 2 | 0 | 1 | - | (1 << branch) == 4 children, starting at offset 1 - // 1 | - | 0 | - | A | 1st child of node 0, reached if next bits are 00 - // 2 | - | 0 | - | A | . - // 3 | - | 0 | - | A,B | . - // 4 | - | 0 | - | A,C | 4th child of node 0, reached if next bits are 11 - // - // The Nilsson and Karlsson paper linked in lc_trie.h has a more thorough example. - - ipv4_trie_.reset(new LcTrieInternal(ipv4_prefixes, fill_factor, root_branching_factor)); - ipv6_trie_.reset(new LcTrieInternal(ipv6_prefixes, fill_factor, root_branching_factor)); -} - -std::vector -LcTrie::getTags(const Network::Address::InstanceConstSharedPtr& ip_address) const { - if (ip_address->ip()->version() == Address::IpVersion::v4) { - Ipv4 ip = ntohl(ip_address->ip()->ipv4()->address()); - return ipv4_trie_->getTags(ip); - } else { - Ipv6 ip = Utility::Ip6ntohl(ip_address->ip()->ipv6()->address()); - return ipv6_trie_->getTags(ip); - } -} - -} // namespace LcTrie -} // namespace Network -} // namespace Envoy diff --git a/source/common/network/lc_trie.h b/source/common/network/lc_trie.h index 3a94c5c100f72..3670eedd7c338 100644 --- a/source/common/network/lc_trie.h +++ b/source/common/network/lc_trie.h @@ -29,8 +29,8 @@ namespace LcTrie { constexpr size_t MaxLcTrieNodes = (1 << 20); /** - * Level Compressed Trie for tagging IP addresses. Both IPv4 and IPv6 addresses are supported - * within this class with no calling pattern changes. + * Level Compressed Trie for associating data with CIDR ranges. Both IPv4 and IPv6 addresses are + * supported within this class with no calling pattern changes. * * The algorithm to build the LC-Trie is desribed in the paper 'IP-address lookup using LC-tries' * by 'S. Nilsson' and 'G. Karlsson'. The paper and reference C implementation can be found here: @@ -38,10 +38,12 @@ constexpr size_t MaxLcTrieNodes = (1 << 20); * * Refer to LcTrieInternal for implementation and algorithm details. */ -class LcTrie { +template class LcTrie { public: /** - * @param tag_data supplies a vector of tag and CIDR ranges. + * @param data supplies a vector of data and CIDR ranges. + * @param exclusive if true then only data for the most specific subnet will be returned + (i.e. data isn't inherited from wider ranges). * @param fill_factor supplies the fraction of completeness to use when calculating the branch * value for a sub-trie. * @param root_branching_factor supplies the branching factor at the root. @@ -52,19 +54,141 @@ class LcTrie { * get this data for smaller LC-Tries. Another option is to expose this in the configuration and * let consumers decide. */ - LcTrie(const std::vector>>& tag_data, - double fill_factor = 0.5, uint32_t root_branching_factor = 0); + LcTrie(const std::vector>>& data, + bool exclusive = false, double fill_factor = 0.5, uint32_t root_branching_factor = 0) { + + // The LcTrie implementation uses 20-bit "pointers" in its compact internal representation, + // so it cannot hold more than 2^20 nodes. But the number of nodes can be greater than the + // number of supported prefixes. Given N prefixes in the data input list, step 2 below can + // produce a new list of up to 2*N prefixes to insert in the LC trie. And the LC trie can + // use up to 2*N/fill_factor nodes. + size_t num_prefixes = 0; + for (const auto& pair_data : data) { + num_prefixes += pair_data.second.size(); + } + const size_t max_prefixes = MaxLcTrieNodes * fill_factor / 2; + if (num_prefixes > max_prefixes) { + throw EnvoyException(fmt::format("The input vector has '{0}' CIDR range entries. LC-Trie " + "can only support '{1}' CIDR ranges with the specified " + "fill factor.", + num_prefixes, max_prefixes)); + } + + // Step 1: separate the provided prefixes by protocol (IPv4 vs IPv6), + // and build a Binary Trie per protocol. + // + // For example, if the input prefixes are + // A: 0.0.0.0/0 + // B: 128.0.0.0/2 (10000000.0.0.0/2 in binary) + // C: 192.0.0.0/2 (11000000.0.0.0/2) + // the Binary Trie for IPv4 will look like this at the end of step 1: + // +---+ + // | A | + // +---+ + // \ 1 + // +---+ + // | | + // +---+ + // 0/ \1 + // +---+ +---+ + // | B | | C | + // +---+ +---+ + // + // Note that the prefixes in this example are nested: any IPv4 address + // that matches B or C will also match A. Unfortunately, the classic LC Trie + // algorithm does not support nested prefixes. The next step will solve that + // problem. + + BinaryTrie ipv4_temp(exclusive); + BinaryTrie ipv6_temp(exclusive); + for (const auto& pair_data : data) { + for (const auto& cidr_range : pair_data.second) { + if (cidr_range.ip()->version() == Address::IpVersion::v4) { + IpPrefix ip_prefix(ntohl(cidr_range.ip()->ipv4()->address()), cidr_range.length(), + pair_data.first); + ipv4_temp.insert(ip_prefix); + } else { + IpPrefix ip_prefix(Utility::Ip6ntohl(cidr_range.ip()->ipv6()->address()), + cidr_range.length(), pair_data.first); + ipv6_temp.insert(ip_prefix); + } + } + } + + // Step 2: push each Binary Trie's prefixes to its leaves. + // + // Continuing the previous example, the Binary Trie will look like this + // at the end of step 2: + // +---+ + // | | + // +---+ + // 0/ \ 1 + // +---+ +---+ + // | A | | | + // +---+ +---+ + // 0/ \1 + // +---+ +---+ + // |A,B| |A,C| + // +---+ +---+ + // + // This trie yields the same match results as the original trie from + // step 1. But it has a useful new property: now that all the prefixes + // are at the leaves, they are disjoint: no prefix is nested under another. + + std::vector> ipv4_prefixes = ipv4_temp.push_leaves(); + std::vector> ipv6_prefixes = ipv6_temp.push_leaves(); + + // Step 3: take the disjoint prefixes from the leaves of each Binary Trie + // and use them to construct an LC Trie. + // + // Example inputs (from the leaves of the Binary Trie at the end of step 2) + // A: 0.0.0.0/1 + // A,B: 128.0.0.0/2 + // A,C: 192.0.0.0/2 + // + // The LC Trie generated from these inputs with fill_factor=0.5 and root_branching_factor=0 + // will be: + // + // +---------------------------+ + // | branch_factor=2, skip = 0 | + // +---------------------------+ + // 00/ 01| |10 \11 + // +---+ +---+ +---+ +---+ + // | A | | A | |A,B| |A,C| + // +---+ +---+ +---+ +---+ + // + // Or, in the internal vector form that the LcTrie class uses for memory-efficiency, + // # | branch | skip | first_child | data | note + // ---+--------+------+-------------+------+-------------------------------------------------- + // 0 | 2 | 0 | 1 | - | (1 << branch) == 4 children, starting at offset 1 + // 1 | - | 0 | - | A | 1st child of node 0, reached if next bits are 00 + // 2 | - | 0 | - | A | . + // 3 | - | 0 | - | A,B | . + // 4 | - | 0 | - | A,C | 4th child of node 0, reached if next bits are 11 + // + // The Nilsson and Karlsson paper linked in lc_trie.h has a more thorough example. + + ipv4_trie_.reset(new LcTrieInternal(ipv4_prefixes, fill_factor, root_branching_factor)); + ipv6_trie_.reset(new LcTrieInternal(ipv6_prefixes, fill_factor, root_branching_factor)); + } /** - * Retrieve the tag associated with the CIDR range that contains `ip_address`. Both IPv4 and IPv6 + * Retrieve data associated with the CIDR range that contains `ip_address`. Both IPv4 and IPv6 * addresses are supported. * @param ip_address supplies the IP address. - * @return a vector of tags from the CIDR ranges and IP addresses that contains 'ip_address'. An + * @return a vector of data from the CIDR ranges and IP addresses that contains 'ip_address'. An * empty vector is returned if no prefix contains 'ip_address' or there is no data for the IP * version of the ip_address. */ - std::vector - getTags(const Network::Address::InstanceConstSharedPtr& ip_address) const; + std::vector getData(const Network::Address::InstanceConstSharedPtr& ip_address) const { + if (ip_address->ip()->version() == Address::IpVersion::v4) { + Ipv4 ip = ntohl(ip_address->ip()->ipv4()->address()); + return ipv4_trie_->getData(ip); + } else { + Ipv6 ip = Utility::Ip6ntohl(ip_address->ip()->ipv6()->address()); + return ipv6_trie_->getData(ip); + } + } private: /** @@ -78,9 +202,9 @@ class LcTrie { template static IpType extractBits(uint32_t p, uint32_t n, IpType input) { // The IP's are stored in host byte order. - // By shifting the value to the left by p bits(and back), the bits between 0 and p-1 are zero'd - // out. Then to get the n bits, shift the IP back by the address_size minus the number of - // desired bits. + // By shifting the value to the left by p bits(and back), the bits between 0 and p-1 are + // zero'd out. Then to get the n bits, shift the IP back by the address_size minus the number + // of desired bits. if (n == 0) { return IpType(0); } @@ -106,22 +230,22 @@ class LcTrie { typedef uint32_t Ipv4; typedef absl::uint128 Ipv6; - typedef std::unordered_set TagSet; - typedef std::shared_ptr> TagSetSharedPtr; + typedef std::unordered_set DataSet; + typedef std::shared_ptr DataSetSharedPtr; /** - * Structure to hold a CIDR range and the tag associated with it. + * Structure to hold a CIDR range and the data associated with it. */ template struct IpPrefix { IpPrefix() {} - IpPrefix(const IpType& ip, uint32_t length, const std::string& tag) : ip_(ip), length_(length) { - tags_.insert(tag); + IpPrefix(const IpType& ip, uint32_t length, const T& data) : ip_(ip), length_(length) { + data_.insert(data); } - IpPrefix(const IpType& ip, int length, const TagSet& tags) - : ip_(ip), length_(length), tags_(tags) {} + IpPrefix(const IpType& ip, int length, const DataSet& data) + : ip_(ip), length_(length), data_(data) {} /** * @return -1 if the current object is less than other. 0 if they are the same. 1 @@ -169,8 +293,8 @@ class LcTrie { IpType ip_{0}; // Length of the cidr range. uint32_t length_{0}; - // Tag(s) for this entry. - TagSet tags_; + // Data for this entry. + DataSet data_; }; /** @@ -187,11 +311,11 @@ class LcTrie { */ template class BinaryTrie { public: - BinaryTrie() : root_(std::make_unique()) {} + BinaryTrie(bool exclusive) : root_(std::make_unique()), exclusive_(exclusive) {} /** - * Add a CIDR prefix and associated tag to the binary trie. If an entry already - * exists for the prefix, merge the tag into the existing entry. + * Add a CIDR prefix and associated data to the binary trie. If an entry already + * exists for the prefix, merge the data into the existing entry. */ void insert(const IpPrefix& prefix) { Node* node = root_.get(); @@ -203,14 +327,14 @@ class LcTrie { } node = next_node.get(); } - if (node->tags == nullptr) { - node->tags = std::make_shared(); + if (node->data == nullptr) { + node->data = std::make_shared(); } - node->tags->insert(prefix.tags_.begin(), prefix.tags_.end()); + node->data->insert(prefix.data_.begin(), prefix.data_.end()); } /** - * Update each node in the trie to inherit/override its ancestors' tags, + * Update each node in the trie to inherit/override its ancestors' data, * and then push the prefixes in the binary trie to the leaves so that: * 1) each leaf contains a prefix, and * 2) given the set of prefixes now located at the leaves, a useful @@ -221,18 +345,18 @@ class LcTrie { */ std::vector> push_leaves() { std::vector> prefixes; - std::function visit = - [&](Node* node, TagSetSharedPtr tags, unsigned depth, IpType prefix) { - // Inherit any tags set by ancestor nodes. - if (tags != nullptr) { - if (node->tags == nullptr) { - node->tags = tags; - } else { - node->tags->insert(tags->begin(), tags->end()); + std::function visit = + [&](Node* node, DataSetSharedPtr data, unsigned depth, IpType prefix) { + // Inherit any data set by ancestor nodes. + if (data != nullptr) { + if (node->data == nullptr) { + node->data = data; + } else if (!exclusive_) { + node->data->insert(data->begin(), data->end()); } } // If a node has exactly one child, create a second child node - // that inherits the union of all tags set by any ancestor nodes. + // that inherits the union of all data set by any ancestor nodes. // This gives the trie an important new property: all the configured // prefixes end up at the leaves of the trie. As no leaf is nested // under another leaf (or one of them would not be a leaf!), the @@ -245,17 +369,17 @@ class LcTrie { node->children[0] = std::make_unique(); } if (node->children[0] != nullptr) { - visit(node->children[0].get(), node->tags, depth + 1, (prefix << 1) + IpType(0)); - visit(node->children[1].get(), node->tags, depth + 1, (prefix << 1) + IpType(1)); + visit(node->children[0].get(), node->data, depth + 1, (prefix << 1) + IpType(0)); + visit(node->children[1].get(), node->data, depth + 1, (prefix << 1) + IpType(1)); } else { - if (node->tags != nullptr) { + if (node->data != nullptr) { // Compute the CIDR prefix from the path we've taken to get to this point in the // tree. IpType ip = prefix; if (depth != 0) { ip <<= (address_size - depth); } - prefixes.emplace_back(IpPrefix(ip, depth, *node->tags)); + prefixes.emplace_back(IpPrefix(ip, depth, *node->data)); } } }; @@ -266,14 +390,15 @@ class LcTrie { private: struct Node { std::unique_ptr children[2]; - TagSetSharedPtr tags; + DataSetSharedPtr data; }; typedef std::unique_ptr NodePtr; NodePtr root_; + bool exclusive_; }; /** - * Level Compressed Trie (LC-Trie) that contains CIDR ranges and its corresponding tags. + * Level Compressed Trie (LC-Trie) that contains CIDR ranges and its corresponding data. * * The following is an implementation of the algorithm described in the paper * 'IP-address lookup using LC-tries' by'S. Nilsson' and 'G. Karlsson'. @@ -289,35 +414,35 @@ class LcTrie { public: /** * Construct a LC-Trie for IpType. - * @param tag_data supplies a vector of tag and CIDR ranges (in IpPrefix format). + * @param data supplies a vector of data and CIDR ranges (in IpPrefix format). * @param fill_factor supplies the fraction of completeness to use when calculating the branch * value for a sub-trie. * @param root_branching_factor supplies the branching factor at the root. The paper suggests - * for large LC-Tries to use the value '16' for the root branching - * factor. It reduces the depth of the trie. + * for large LC-Tries to use the value '16' for the root + * branching factor. It reduces the depth of the trie. */ - LcTrieInternal(std::vector>& tag_data, double fill_factor, + LcTrieInternal(std::vector>& data, double fill_factor, uint32_t root_branching_factor); /** - * Retrieve the tag associated with the CIDR range that contains `ip_address`. + * Retrieve the data associated with the CIDR range that contains `ip_address`. * @param ip_address supplies the IP address in host byte order. - * @return a vector of tags from the CIDR ranges and IP addresses that encompasses the input. An - * empty vector is returned if the LC Trie is empty. + * @return a vector of data from the CIDR ranges and IP addresses that encompasses the input. + * An empty vector is returned if the LC Trie is empty. */ - std::vector getTags(const IpType& ip_address) const; + std::vector getData(const IpType& ip_address) const; private: /** - * Builds the Level Compresesed Trie, by first sorting the tag data, removing duplicated + * Builds the Level Compresesed Trie, by first sorting the data, removing duplicated * prefixes and invoking buildRecursive() to build the trie. */ - void build(std::vector>& tag_data) { - if (tag_data.empty()) { + void build(std::vector>& data) { + if (data.empty()) { return; } - ip_prefixes_ = tag_data; + ip_prefixes_ = data; std::sort(ip_prefixes_.begin(), ip_prefixes_.end()); // In theory, the trie_ vector can have at most twice the number of ip_prefixes entries - 1. @@ -334,6 +459,7 @@ class LcTrie { // The value of next_free_index is the final size of the trie_. trie_.resize(next_free_index); + trie_.shrink_to_fit(); } // Thin wrapper around computeBranch output to facilitate code readability. @@ -477,8 +603,8 @@ class LcTrie { trie_[position].skip_ = output.prefix_ - prefix; trie_[position].address_ = address; - // The next available free index to populate in the trie_ is at next_free_index + 2^(branching - // factor). + // The next available free index to populate in the trie_ is at next_free_index + + // 2^(branching factor). next_free_index += 1 << output.branch_; uint32_t new_position = first; @@ -487,8 +613,8 @@ class LcTrie { for (uint32_t bit_pattern = 0; bit_pattern < static_cast(1 << output.branch_); ++bit_pattern) { - // count is the number of entries in the ip_prefixes_ vector that have the same bit pattern - // as the ip_prefixes_[new_position]. + // count is the number of entries in the ip_prefixes_ vector that have the same bit + // pattern as the ip_prefixes_[new_position]. int count = 0; while (new_position + count < first + n && static_cast(extractBits( @@ -534,13 +660,13 @@ class LcTrie { } /** - * LcNode is a uint32_t. A wrapper is provided to simplify getting/setting the branch, the skip - * and the address values held within the structure. + * LcNode is a uint32_t. A wrapper is provided to simplify getting/setting the branch, the + * skip and the address values held within the structure. * * The LcNode has three parts to it * - Branch: the first 5 bits represent the branching factor. The branching factor is used to - * determine the number of descendants for the current node. The number represents a power of 2, - * so there can be at most 2^31 descendant nodes. + * determine the number of descendants for the current node. The number represents a power of + * 2, so there can be at most 2^31 descendant nodes. * - Skip: the next 7 bits represent the number of bits to skip when looking at an IP address. * This value can be between 0 and 127, so IPv6 is supported. * - Address: the remaining 20 bits represent an index either into the trie_ or the @@ -556,14 +682,15 @@ class LcTrie { uint32_t address_ : 20; // If this 20-bit size changes, please change MaxLcTrieNodes too. }; - // During build(), an estimate of the number of nodes required will be made and set this value. - // This is used to ensure no out_of_range exception is thrown. + // During build(), an estimate of the number of nodes required will be made and set this + // value. This is used to ensure no out_of_range exception is thrown. uint32_t maximum_trie_node_size; - // The CIDR range and tags needs to be maintained separately from the LC-Trie. A LC-Trie skips - // chunks of data while searching for a match. This means that the node found in the LC-Trie is - // not guaranteed to have the IP address in range. The last step prior to returning a tag is to - // check the CIDR range pointed to by the node in the LC-Trie has the IP address in range. + // The CIDR range and data needs to be maintained separately from the LC-Trie. A LC-Trie skips + // chunks of data while searching for a match. This means that the node found in the LC-Trie + // is not guaranteed to have the IP address in range. The last step prior to returning + // associated data is to check the CIDR range pointed to by the node in the LC-Trie has + // the IP address in range. std::vector> ip_prefixes_; // Main trie search structure. @@ -577,17 +704,20 @@ class LcTrie { std::unique_ptr> ipv6_trie_; }; +template template -LcTrie::LcTrieInternal::LcTrieInternal( - std::vector>& tag_data, double fill_factor, uint32_t root_branching_factor) +LcTrie::LcTrieInternal::LcTrieInternal(std::vector>& data, + double fill_factor, + uint32_t root_branching_factor) : fill_factor_(fill_factor), root_branching_factor_(root_branching_factor) { - build(tag_data); + build(data); } +template template -std::vector -LcTrie::LcTrieInternal::getTags(const IpType& ip_address) const { - std::vector return_vector; +std::vector +LcTrie::LcTrieInternal::getData(const IpType& ip_address) const { + std::vector return_vector; if (trie_.empty()) { return return_vector; } @@ -613,9 +743,9 @@ LcTrie::LcTrieInternal::getTags(const IpType& ip_address) // ip_address. const auto& prefix = ip_prefixes_[address]; if (prefix.contains(ip_address)) { - return std::vector(prefix.tags_.begin(), prefix.tags_.end()); + return std::vector(prefix.data_.begin(), prefix.data_.end()); } - return std::vector(); + return std::vector(); } } // namespace LcTrie diff --git a/source/common/network/listen_socket_impl.cc b/source/common/network/listen_socket_impl.cc index 85039407051c0..bc7e8a7858d97 100644 --- a/source/common/network/listen_socket_impl.cc +++ b/source/common/network/listen_socket_impl.cc @@ -16,11 +16,11 @@ namespace Envoy { namespace Network { void ListenSocketImpl::doBind() { - int rc = local_address_->bind(fd_); - if (rc == -1) { + const Api::SysCallResult result = local_address_->bind(fd_); + if (result.rc_ == -1) { close(); throw EnvoyException( - fmt::format("cannot bind '{}': {}", local_address_->asString(), strerror(errno))); + fmt::format("cannot bind '{}': {}", local_address_->asString(), strerror(result.errno_))); } if (local_address_->type() == Address::Type::Ip && local_address_->ip()->port() == 0) { // If the port we bind is zero, then the OS will pick a free port for us (assuming there are @@ -40,12 +40,12 @@ TcpListenSocket::TcpListenSocket(const Address::InstanceConstSharedPtr& address, const Network::Socket::OptionsSharedPtr& options, bool bind_to_port) : ListenSocketImpl(address->socket(Address::SocketType::Stream), address) { - RELEASE_ASSERT(fd_ != -1); + RELEASE_ASSERT(fd_ != -1, ""); // TODO(htuch): This might benefit from moving to SocketOptionImpl. int on = 1; int rc = setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); setListenSocketOptions(options); @@ -62,7 +62,7 @@ TcpListenSocket::TcpListenSocket(int fd, const Address::InstanceConstSharedPtr& UdsListenSocket::UdsListenSocket(const Address::InstanceConstSharedPtr& address) : ListenSocketImpl(address->socket(Address::SocketType::Stream), address) { - RELEASE_ASSERT(fd_ != -1); + RELEASE_ASSERT(fd_ != -1, ""); doBind(); } diff --git a/source/common/network/raw_buffer_socket.cc b/source/common/network/raw_buffer_socket.cc index a9c1630748a43..699313042386e 100644 --- a/source/common/network/raw_buffer_socket.cc +++ b/source/common/network/raw_buffer_socket.cc @@ -17,23 +17,22 @@ IoResult RawBufferSocket::doRead(Buffer::Instance& buffer) { bool end_stream = false; do { // 16K read is arbitrary. TODO(mattklein123) PERF: Tune the read size. - int rc = buffer.read(callbacks_->fd(), 16384); - ENVOY_CONN_LOG(trace, "read returns: {}", callbacks_->connection(), rc); + Api::SysCallResult result = buffer.read(callbacks_->fd(), 16384); + ENVOY_CONN_LOG(trace, "read returns: {}", callbacks_->connection(), result.rc_); - if (rc == 0) { + if (result.rc_ == 0) { // Remote close. end_stream = true; break; - } else if (rc == -1) { + } else if (result.rc_ == -1) { // Remote error (might be no data). - ENVOY_CONN_LOG(trace, "read error: {}", callbacks_->connection(), errno); - if (errno != EAGAIN) { + ENVOY_CONN_LOG(trace, "read error: {}", callbacks_->connection(), result.errno_); + if (result.errno_ != EAGAIN) { action = PostIoAction::Close; } - break; } else { - bytes_read += rc; + bytes_read += result.rc_; if (callbacks_->shouldDrainReadBuffer()) { callbacks_->setReadBufferReady(); break; @@ -59,20 +58,20 @@ IoResult RawBufferSocket::doWrite(Buffer::Instance& buffer, bool end_stream) { action = PostIoAction::KeepOpen; break; } - int rc = buffer.write(callbacks_->fd()); - ENVOY_CONN_LOG(trace, "write returns: {}", callbacks_->connection(), rc); - if (rc == -1) { - ENVOY_CONN_LOG(trace, "write error: {} ({})", callbacks_->connection(), errno, - strerror(errno)); - if (errno == EAGAIN) { + Api::SysCallResult result = buffer.write(callbacks_->fd()); + ENVOY_CONN_LOG(trace, "write returns: {}", callbacks_->connection(), result.rc_); + + if (result.rc_ == -1) { + ENVOY_CONN_LOG(trace, "write error: {} ({})", callbacks_->connection(), result.errno_, + strerror(result.errno_)); + if (result.errno_ == EAGAIN) { action = PostIoAction::KeepOpen; } else { action = PostIoAction::Close; } - break; } else { - bytes_written += rc; + bytes_written += result.rc_; } } while (true); diff --git a/source/common/network/socket_option_impl.cc b/source/common/network/socket_option_impl.cc index d4c35fd8e0831..a0b8c3f6cff98 100644 --- a/source/common/network/socket_option_impl.cc +++ b/source/common/network/socket_option_impl.cc @@ -13,9 +13,9 @@ namespace Network { bool SocketOptionImpl::setOption(Socket& socket, envoy::api::v2::core::SocketOption::SocketState state) const { if (in_state_ == state) { - const int error = SocketOptionImpl::setSocketOption(socket, optname_, value_); - if (error != 0) { - ENVOY_LOG(warn, "Setting option on socket failed: {}", strerror(errno)); + const Api::SysCallResult result = SocketOptionImpl::setSocketOption(socket, optname_, value_); + if (result.rc_ != 0) { + ENVOY_LOG(warn, "Setting option on socket failed: {}", strerror(result.errno_)); return false; } } @@ -24,16 +24,17 @@ bool SocketOptionImpl::setOption(Socket& socket, bool SocketOptionImpl::isSupported() const { return optname_.has_value(); } -int SocketOptionImpl::setSocketOption(Socket& socket, Network::SocketOptionName optname, - const absl::string_view value) { +Api::SysCallResult SocketOptionImpl::setSocketOption(Socket& socket, + Network::SocketOptionName optname, + const absl::string_view value) { if (!optname.has_value()) { - errno = ENOTSUP; - return -1; + return {-1, ENOTSUP}; } auto& os_syscalls = Api::OsSysCallsSingleton::get(); - return os_syscalls.setsockopt(socket.fd(), optname.value().first, optname.value().second, - value.data(), value.size()); + const int rc = os_syscalls.setsockopt(socket.fd(), optname.value().first, optname.value().second, + value.data(), value.size()); + return {rc, errno}; } } // namespace Network diff --git a/source/common/network/socket_option_impl.h b/source/common/network/socket_option_impl.h index 1b7f8d5b8f800..c40b530c0628b 100644 --- a/source/common/network/socket_option_impl.h +++ b/source/common/network/socket_option_impl.h @@ -5,6 +5,7 @@ #include #include +#include "envoy/api/os_sys_calls.h" #include "envoy/network/listen_socket.h" #include "common/common/logger.h" @@ -102,8 +103,16 @@ class SocketOptionImpl : public Socket::Option, Logger::Loggable #endif +#ifndef IP6T_SO_ORIGINAL_DST +// From linux/netfilter_ipv6/ip6_tables.h +#define IP6T_SO_ORIGINAL_DST 80 +#endif + #include #include @@ -20,6 +25,7 @@ #include "envoy/network/connection.h" #include "envoy/stats/stats.h" +#include "common/api/os_sys_calls_impl.h" #include "common/common/assert.h" #include "common/common/utility.h" #include "common/network/address_impl.h" @@ -98,7 +104,7 @@ Address::InstanceConstSharedPtr Utility::parseInternetAddress(const std::string& return std::make_shared(sa6, v6only); } throwWithMalformedIp(ip_address); - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } Address::InstanceConstSharedPtr Utility::parseInternetAddressAndPort(const std::string& ip_address, @@ -168,7 +174,7 @@ Address::InstanceConstSharedPtr Utility::getLocalAddress(const Address::IpVersio Address::InstanceConstSharedPtr ret; int rc = getifaddrs(&ifaddr); - RELEASE_ASSERT(!rc); + RELEASE_ASSERT(!rc, ""); // man getifaddrs(3) for (ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) { @@ -250,7 +256,7 @@ bool Utility::isLoopbackAddress(const Address::Instance& address) { absl::uint128 addr = address.ip()->ipv6()->address(); return 0 == memcmp(&addr, &in6addr_loopback, sizeof(in6addr_loopback)); } - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Address::InstanceConstSharedPtr Utility::getCanonicalIpv4LoopbackAddress() { @@ -285,21 +291,42 @@ Address::InstanceConstSharedPtr Utility::getAddressWithPort(const Address::Insta case Network::Address::IpVersion::v6: return std::make_shared(address.ip()->addressAsString(), port); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } Address::InstanceConstSharedPtr Utility::getOriginalDst(int fd) { #ifdef SOL_IP sockaddr_storage orig_addr; socklen_t addr_len = sizeof(sockaddr_storage); - int status = getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, &orig_addr, &addr_len); + int socket_domain; + socklen_t domain_len = sizeof(socket_domain); + auto& os_syscalls = Api::OsSysCallsSingleton::get(); + int status = os_syscalls.getsockopt(fd, SOL_SOCKET, SO_DOMAIN, &socket_domain, &domain_len); + + if (status != 0) { + return nullptr; + } - if (status == 0) { - // TODO(mattklein123): IPv6 support. See github issue #1094. - ASSERT(orig_addr.ss_family == AF_INET); + if (socket_domain == AF_INET) { + status = os_syscalls.getsockopt(fd, SOL_IP, SO_ORIGINAL_DST, &orig_addr, &addr_len); + } else if (socket_domain == AF_INET6) { + status = os_syscalls.getsockopt(fd, SOL_IPV6, IP6T_SO_ORIGINAL_DST, &orig_addr, &addr_len); + } else { + return nullptr; + } + + if (status != 0) { + return nullptr; + } + + switch (orig_addr.ss_family) { + case AF_INET: return Address::InstanceConstSharedPtr{ new Address::Ipv4Instance(reinterpret_cast(&orig_addr))}; - } else { + case AF_INET6: + return Address::InstanceConstSharedPtr{ + new Address::Ipv6Instance(reinterpret_cast(orig_addr))}; + default: return nullptr; } #else @@ -384,7 +411,7 @@ Utility::protobufAddressToAddress(const envoy::api::v2::core::Address& proto_add case envoy::api::v2::core::Address::kPipe: return std::make_shared(proto_address.pipe().path()); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/source/common/protobuf/utility.cc b/source/common/protobuf/utility.cc index 3c93c42992d34..c9be39db0d0ce 100644 --- a/source/common/protobuf/utility.cc +++ b/source/common/protobuf/utility.cc @@ -30,7 +30,7 @@ uint64_t fractionalPercentDenominatorToInt(const envoy::type::FractionalPercent& return 1000000; default: // Checked by schema. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -49,7 +49,9 @@ ProtoValidationException::ProtoValidationException(const std::string& validation } void MessageUtil::loadFromJson(const std::string& json, Protobuf::Message& message) { - const auto status = Protobuf::util::JsonStringToMessage(json, &message); + Protobuf::util::JsonParseOptions options; + options.ignore_unknown_fields = true; + const auto status = Protobuf::util::JsonStringToMessage(json, &message, options); if (!status.ok()) { throw EnvoyException("Unable to parse JSON as proto (" + status.ToString() + "): " + json); } @@ -87,7 +89,8 @@ void MessageUtil::loadFromFile(const std::string& path, Protobuf::Message& messa } std::string MessageUtil::getJsonStringFromMessage(const Protobuf::Message& message, - const bool pretty_print) { + const bool pretty_print, + const bool always_print_primitive_fields) { Protobuf::util::JsonPrintOptions json_options; // By default, proto field names are converted to camelCase when the message is converted to JSON. // Setting this option makes debugging easier because it keeps field names consistent in JSON @@ -96,10 +99,15 @@ std::string MessageUtil::getJsonStringFromMessage(const Protobuf::Message& messa if (pretty_print) { json_options.add_whitespace = true; } + // Primitive types such as int32s and enums will not be serialized if they have the default value. + // This flag disables that behavior. + if (always_print_primitive_fields) { + json_options.always_print_primitive_fields = true; + } ProtobufTypes::String json; const auto status = Protobuf::util::MessageToJsonString(message, &json, json_options); // This should always succeed unless something crash-worthy such as out-of-memory. - RELEASE_ASSERT(status.ok()); + RELEASE_ASSERT(status.ok(), ""); return json; } @@ -180,7 +188,7 @@ bool ValueUtil::equal(const ProtobufWkt::Value& v1, const ProtobufWkt::Value& v2 } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/source/common/protobuf/utility.h b/source/common/protobuf/utility.h index bbe25895f10d2..93d098c75d739 100644 --- a/source/common/protobuf/utility.h +++ b/source/common/protobuf/utility.h @@ -71,11 +71,16 @@ uint64_t fractionalPercentDenominatorToInt(const envoy::type::FractionalPercent& // @param field_name supplies the field name in the message. // @param max_value supplies the maximum allowed integral value (e.g., 100, 10000, etc.). // @param default_value supplies the default if the field is not present. +// +// TODO(anirudhmurali): Recommended to capture and validate NaN values in PGV +// Issue: https://github.com/lyft/protoc-gen-validate/issues/85 #define PROTOBUF_PERCENT_TO_ROUNDED_INTEGER_OR_DEFAULT(message, field_name, max_value, \ default_value) \ - ((message).has_##field_name() \ - ? ProtobufPercentHelper::convertPercent((message).field_name().value(), max_value) \ - : ProtobufPercentHelper::checkAndReturnDefault(default_value, max_value)) + (!std::isnan((message).field_name().value()) \ + ? (message).has_##field_name() \ + ? ProtobufPercentHelper::convertPercent((message).field_name().value(), max_value) \ + : ProtobufPercentHelper::checkAndReturnDefault(default_value, max_value) \ + : throw EnvoyException(fmt::format("Value not in the range of 0..100 range."))) namespace Envoy { @@ -226,10 +231,13 @@ class MessageUtil { * Extract JSON as string from a google.protobuf.Message. * @param message message of type type.googleapis.com/google.protobuf.Message. * @param pretty_print whether the returned JSON should be formatted. + * @param always_print_primitive_fields whether to include primitive fields set to their default + * values, e.g. an int32 set to 0 or a bool set to false. * @return std::string of formatted JSON object. */ static std::string getJsonStringFromMessage(const Protobuf::Message& message, - bool pretty_print = false); + bool pretty_print = false, + bool always_print_primitive_fields = false); /** * Extract JSON object from a google.protobuf.Message. diff --git a/source/common/request_info/request_info_impl.h b/source/common/request_info/request_info_impl.h index 6adced2357e91..f56f34ebc4e52 100644 --- a/source/common/request_info/request_info_impl.h +++ b/source/common/request_info/request_info_impl.h @@ -125,7 +125,13 @@ struct RequestInfoImpl : public RequestInfo { void setResponseFlag(ResponseFlag response_flag) override { response_flags_ |= response_flag; } - bool getResponseFlag(ResponseFlag flag) const override { return response_flags_ & flag; } + bool intersectResponseFlags(uint64_t response_flags) const override { + return (response_flags_ & response_flags) != 0; + } + + bool hasResponseFlag(ResponseFlag flag) const override { return response_flags_ & flag; } + + bool hasAnyResponseFlag() const override { return response_flags_ != 0; } void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host) override { upstream_host_ = host; diff --git a/source/common/request_info/utility.cc b/source/common/request_info/utility.cc index b8138833068b2..27a01a6b24453 100644 --- a/source/common/request_info/utility.cc +++ b/source/common/request_info/utility.cc @@ -33,61 +33,85 @@ const std::string ResponseFlagUtils::toShortString(const RequestInfo& request_in static_assert(ResponseFlag::LastFlag == 0x1000, "A flag has been added. Fix this code."); - if (request_info.getResponseFlag(ResponseFlag::FailedLocalHealthCheck)) { + if (request_info.hasResponseFlag(ResponseFlag::FailedLocalHealthCheck)) { appendString(result, FAILED_LOCAL_HEALTH_CHECK); } - if (request_info.getResponseFlag(ResponseFlag::NoHealthyUpstream)) { + if (request_info.hasResponseFlag(ResponseFlag::NoHealthyUpstream)) { appendString(result, NO_HEALTHY_UPSTREAM); } - if (request_info.getResponseFlag(ResponseFlag::UpstreamRequestTimeout)) { + if (request_info.hasResponseFlag(ResponseFlag::UpstreamRequestTimeout)) { appendString(result, UPSTREAM_REQUEST_TIMEOUT); } - if (request_info.getResponseFlag(ResponseFlag::LocalReset)) { + if (request_info.hasResponseFlag(ResponseFlag::LocalReset)) { appendString(result, LOCAL_RESET); } - if (request_info.getResponseFlag(ResponseFlag::UpstreamRemoteReset)) { + if (request_info.hasResponseFlag(ResponseFlag::UpstreamRemoteReset)) { appendString(result, UPSTREAM_REMOTE_RESET); } - if (request_info.getResponseFlag(ResponseFlag::UpstreamConnectionFailure)) { + if (request_info.hasResponseFlag(ResponseFlag::UpstreamConnectionFailure)) { appendString(result, UPSTREAM_CONNECTION_FAILURE); } - if (request_info.getResponseFlag(ResponseFlag::UpstreamConnectionTermination)) { + if (request_info.hasResponseFlag(ResponseFlag::UpstreamConnectionTermination)) { appendString(result, UPSTREAM_CONNECTION_TERMINATION); } - if (request_info.getResponseFlag(ResponseFlag::UpstreamOverflow)) { + if (request_info.hasResponseFlag(ResponseFlag::UpstreamOverflow)) { appendString(result, UPSTREAM_OVERFLOW); } - if (request_info.getResponseFlag(ResponseFlag::NoRouteFound)) { + if (request_info.hasResponseFlag(ResponseFlag::NoRouteFound)) { appendString(result, NO_ROUTE_FOUND); } - if (request_info.getResponseFlag(ResponseFlag::DelayInjected)) { + if (request_info.hasResponseFlag(ResponseFlag::DelayInjected)) { appendString(result, DELAY_INJECTED); } - if (request_info.getResponseFlag(ResponseFlag::FaultInjected)) { + if (request_info.hasResponseFlag(ResponseFlag::FaultInjected)) { appendString(result, FAULT_INJECTED); } - if (request_info.getResponseFlag(ResponseFlag::RateLimited)) { + if (request_info.hasResponseFlag(ResponseFlag::RateLimited)) { appendString(result, RATE_LIMITED); } - if (request_info.getResponseFlag(ResponseFlag::UnauthorizedExternalService)) { + if (request_info.hasResponseFlag(ResponseFlag::UnauthorizedExternalService)) { appendString(result, UNAUTHORIZED_EXTERNAL_SERVICE); } return result.empty() ? NONE : result; } +absl::optional ResponseFlagUtils::toResponseFlag(const std::string& flag) { + static const std::map map = { + {ResponseFlagUtils::FAILED_LOCAL_HEALTH_CHECK, ResponseFlag::FailedLocalHealthCheck}, + {ResponseFlagUtils::NO_HEALTHY_UPSTREAM, ResponseFlag::NoHealthyUpstream}, + {ResponseFlagUtils::UPSTREAM_REQUEST_TIMEOUT, ResponseFlag::UpstreamRequestTimeout}, + {ResponseFlagUtils::LOCAL_RESET, ResponseFlag::LocalReset}, + {ResponseFlagUtils::UPSTREAM_REMOTE_RESET, ResponseFlag::UpstreamRemoteReset}, + {ResponseFlagUtils::UPSTREAM_CONNECTION_FAILURE, ResponseFlag::UpstreamConnectionFailure}, + {ResponseFlagUtils::UPSTREAM_CONNECTION_TERMINATION, + ResponseFlag::UpstreamConnectionTermination}, + {ResponseFlagUtils::UPSTREAM_OVERFLOW, ResponseFlag::UpstreamOverflow}, + {ResponseFlagUtils::NO_ROUTE_FOUND, ResponseFlag::NoRouteFound}, + {ResponseFlagUtils::DELAY_INJECTED, ResponseFlag::DelayInjected}, + {ResponseFlagUtils::FAULT_INJECTED, ResponseFlag::FaultInjected}, + {ResponseFlagUtils::RATE_LIMITED, ResponseFlag::RateLimited}, + {ResponseFlagUtils::UNAUTHORIZED_EXTERNAL_SERVICE, ResponseFlag::UnauthorizedExternalService}, + }; + const auto& it = map.find(flag); + if (it != map.end()) { + return absl::make_optional(it->second); + } + return absl::nullopt; +} + const std::string& Utility::formatDownstreamAddressNoPort(const Network::Address::Instance& address) { if (address.type() == Network::Address::Type::Ip) { diff --git a/source/common/request_info/utility.h b/source/common/request_info/utility.h index 15b34fb83c42a..879bccd196cff 100644 --- a/source/common/request_info/utility.h +++ b/source/common/request_info/utility.h @@ -14,6 +14,7 @@ namespace RequestInfo { class ResponseFlagUtils { public: static const std::string toShortString(const RequestInfo& request_info); + static absl::optional toResponseFlag(const std::string& response_flag); private: ResponseFlagUtils(); diff --git a/source/common/router/BUILD b/source/common/router/BUILD index 1960b0c117c6f..d9313aa228cc1 100644 --- a/source/common/router/BUILD +++ b/source/common/router/BUILD @@ -100,6 +100,7 @@ envoy_cc_library( hdrs = ["rds_subscription.h"], deps = [ "//include/envoy/config:subscription_interface", + "//include/envoy/stats:stats_interface", "//source/common/common:assert_lib", "//source/common/config:rds_json_lib", "//source/common/config:utility_lib", @@ -123,6 +124,7 @@ envoy_cc_library( "//include/envoy/runtime:runtime_interface", "//include/envoy/upstream:upstream_interface", "//source/common/common:assert_lib", + "//source/common/common:backoff_lib", "//source/common/common:utility_lib", "//source/common/grpc:common_lib", "//source/common/http:codes_lib", diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index 114d415ae6f08..d5301073fd1ac 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -54,6 +54,9 @@ CorsPolicyImpl::CorsPolicyImpl(const envoy::api::v2::route::CorsPolicy& config) for (const auto& origin : config.allow_origin()) { allow_origin_.push_back(origin); } + for (const auto& regex : config.allow_origin_regex()) { + allow_origin_regex_.push_back(RegexUtil::parseRegex(regex)); + } allow_methods_ = config.allow_methods(); allow_headers_ = config.allow_headers(); expose_headers_ = config.expose_headers(); @@ -255,6 +258,7 @@ RouteEntryImplBase::RouteEntryImplBase(const VirtualHostImpl& vhost, cluster_not_found_response_code_(ConfigUtility::parseClusterNotFoundResponseCode( route.route().cluster_not_found_response_code())), timeout_(PROTOBUF_GET_MS_OR_DEFAULT(route.route(), timeout, DEFAULT_ROUTE_TIMEOUT_MS)), + idle_timeout_(PROTOBUF_GET_OPTIONAL_MS(route.route(), idle_timeout)), max_grpc_timeout_(PROTOBUF_GET_OPTIONAL_MS(route.route(), max_grpc_timeout)), runtime_(loadRuntimeData(route.match())), loader_(factory_context.runtime()), host_redirect_(route.redirect().host_redirect()), @@ -266,9 +270,13 @@ RouteEntryImplBase::RouteEntryImplBase(const VirtualHostImpl& vhost, priority_(ConfigUtility::parsePriority(route.route().priority())), total_cluster_weight_( PROTOBUF_GET_WRAPPED_OR_DEFAULT(route.route().weighted_clusters(), total_weight, 100UL)), - request_headers_parser_(HeaderParser::configure(route.route().request_headers_to_add())), - response_headers_parser_(HeaderParser::configure(route.route().response_headers_to_add(), - route.route().response_headers_to_remove())), + route_action_request_headers_parser_( + HeaderParser::configure(route.route().request_headers_to_add())), + route_action_response_headers_parser_(HeaderParser::configure( + route.route().response_headers_to_add(), route.route().response_headers_to_remove())), + request_headers_parser_(HeaderParser::configure(route.request_headers_to_add())), + response_headers_parser_(HeaderParser::configure(route.response_headers_to_add(), + route.response_headers_to_remove())), opaque_config_(parseOpaqueConfig(route)), decorator_(parseDecorator(route)), direct_response_code_(ConfigUtility::parseDirectResponseCode(route)), direct_response_body_(ConfigUtility::parseDirectResponseBody(route)), @@ -365,8 +373,10 @@ Http::WebSocketProxyPtr RouteEntryImplBase::createWebSocketProxy( void RouteEntryImplBase::finalizeRequestHeaders(Http::HeaderMap& headers, const RequestInfo::RequestInfo& request_info, bool insert_envoy_original_path) const { - // Append user-specified request headers in the following order: route-level headers, - // virtual host level headers and finally global connection manager level headers. + // Append user-specified request headers in the following order: route-action-level headers, + // route-level headers, virtual host level headers and finally global connection manager level + // headers. + route_action_request_headers_parser_->evaluateHeaders(headers, request_info); request_headers_parser_->evaluateHeaders(headers, request_info); vhost_.requestHeaderParser().evaluateHeaders(headers, request_info); vhost_.globalRouteConfig().requestHeaderParser().evaluateHeaders(headers, request_info); @@ -382,6 +392,10 @@ void RouteEntryImplBase::finalizeRequestHeaders(Http::HeaderMap& headers, void RouteEntryImplBase::finalizeResponseHeaders( Http::HeaderMap& headers, const RequestInfo::RequestInfo& request_info) const { + // Append user-specified response headers in the following order: route-action-level headers, + // route-level headers, virtual host level headers and finally global connection manager level + // headers. + route_action_response_headers_parser_->evaluateHeaders(headers, request_info); response_headers_parser_->evaluateHeaders(headers, request_info); vhost_.responseHeaderParser().evaluateHeaders(headers, request_info); vhost_.globalRouteConfig().responseHeaderParser().evaluateHeaders(headers, request_info); @@ -457,7 +471,7 @@ RouteEntryImplBase::parseOpaqueConfig(const envoy::api::v2::route::Route& route) std::multimap ret; if (route.has_metadata()) { const auto filter_metadata = route.metadata().filter_metadata().find( - Extensions::HttpFilters::HttpFilterNames::get().ROUTER); + Extensions::HttpFilters::HttpFilterNames::get().Router); if (filter_metadata == route.metadata().filter_metadata().end()) { return ret; } @@ -537,7 +551,7 @@ RouteConstSharedPtr RouteEntryImplBase::clusterEntry(const Http::HeaderMap& head } begin = end; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void RouteEntryImplBase::validateClusters(Upstream::ClusterManager& cm) const { @@ -710,7 +724,7 @@ VirtualHostImpl::VirtualHostImpl(const envoy::api::v2::route::VirtualHost& virtu ssl_requirements_ = SslRequirements::ALL; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } for (const auto& route : virtual_host.routes()) { diff --git a/source/common/router/config_impl.h b/source/common/router/config_impl.h index a8df5a153b6c5..238f54b99d51b 100644 --- a/source/common/router/config_impl.h +++ b/source/common/router/config_impl.h @@ -98,6 +98,7 @@ class CorsPolicyImpl : public CorsPolicy { // Router::CorsPolicy const std::list& allowOrigins() const override { return allow_origin_; }; + const std::list& allowOriginRegexes() const override { return allow_origin_regex_; } const std::string& allowMethods() const override { return allow_methods_; }; const std::string& allowHeaders() const override { return allow_headers_; }; const std::string& exposeHeaders() const override { return expose_headers_; }; @@ -107,6 +108,7 @@ class CorsPolicyImpl : public CorsPolicy { private: std::list allow_origin_; + std::list allow_origin_regex_; std::string allow_methods_; std::string allow_headers_; std::string expose_headers_; @@ -309,6 +311,7 @@ class RouteEntryImplBase : public RouteEntry, return vhost_.virtualClusterFromEntries(headers); } std::chrono::milliseconds timeout() const override { return timeout_; } + absl::optional idleTimeout() const override { return idle_timeout_; } absl::optional maxGrpcTimeout() const override { return max_grpc_timeout_; } @@ -356,8 +359,6 @@ class RouteEntryImplBase : public RouteEntry, void finalizePathHeader(Http::HeaderMap& headers, const std::string& matched_path, bool insert_envoy_original_path) const; - const HeaderParser& requestHeaderParser() const { return *request_headers_parser_; }; - const HeaderParser& responseHeaderParser() const { return *response_headers_parser_; }; private: struct RuntimeData { @@ -393,6 +394,9 @@ class RouteEntryImplBase : public RouteEntry, const RetryPolicy& retryPolicy() const override { return parent_->retryPolicy(); } const ShadowPolicy& shadowPolicy() const override { return parent_->shadowPolicy(); } std::chrono::milliseconds timeout() const override { return parent_->timeout(); } + absl::optional idleTimeout() const override { + return parent_->idleTimeout(); + } absl::optional maxGrpcTimeout() const override { return parent_->maxGrpcTimeout(); } @@ -510,6 +514,7 @@ class RouteEntryImplBase : public RouteEntry, const Http::LowerCaseString cluster_header_name_; const Http::Code cluster_not_found_response_code_; const std::chrono::milliseconds timeout_; + const absl::optional idle_timeout_; const absl::optional max_grpc_timeout_; const absl::optional runtime_; Runtime::Loader& loader_; @@ -528,6 +533,8 @@ class RouteEntryImplBase : public RouteEntry, const uint64_t total_cluster_weight_; std::unique_ptr hash_policy_; MetadataMatchCriteriaConstPtr metadata_match_criteria_; + HeaderParserPtr route_action_request_headers_parser_; + HeaderParserPtr route_action_response_headers_parser_; HeaderParserPtr request_headers_parser_; HeaderParserPtr response_headers_parser_; envoy::api::v2::core::Metadata metadata_; diff --git a/source/common/router/config_utility.cc b/source/common/router/config_utility.cc index 5cc0c900a3938..b5d670b339b60 100644 --- a/source/common/router/config_utility.cc +++ b/source/common/router/config_utility.cc @@ -32,7 +32,7 @@ ConfigUtility::parsePriority(const envoy::api::v2::core::RoutingPriority& priori case envoy::api::v2::core::RoutingPriority::HIGH: return Upstream::ResourcePriority::High; default: - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } @@ -62,7 +62,7 @@ Http::Code ConfigUtility::parseRedirectResponseCode( case envoy::api::v2::route::RedirectAction::PERMANENT_REDIRECT: return Http::Code::PermanentRedirect; default: - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } @@ -114,7 +114,7 @@ Http::Code ConfigUtility::parseClusterNotFoundResponseCode( case envoy::api::v2::route::RouteAction::NOT_FOUND: return Http::Code::NotFound; default: - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } diff --git a/source/common/router/header_formatter.cc b/source/common/router/header_formatter.cc index b691159879e36..b76fd90f5bbe2 100644 --- a/source/common/router/header_formatter.cc +++ b/source/common/router/header_formatter.cc @@ -67,7 +67,7 @@ parseUpstreamMetadataField(absl::string_view params_str) { } const ProtobufWkt::Value* value = - &Config::Metadata::metadataValue(host->metadata(), params[0], params[1]); + &Config::Metadata::metadataValue(*host->metadata(), params[0], params[1]); if (value->kind_case() == ProtobufWkt::Value::KIND_NOT_SET) { // No kind indicates default ProtobufWkt::Value which means namespace or key not // found. diff --git a/source/common/router/header_parser.cc b/source/common/router/header_parser.cc index c5ccd74bb576e..97c7dc4b7b37d 100644 --- a/source/common/router/header_parser.cc +++ b/source/common/router/header_parser.cc @@ -179,7 +179,7 @@ parseInternal(const envoy::api::v2::core::HeaderValueOption& header_value_option break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } while (++pos < format.size()); diff --git a/source/common/router/rds_impl.cc b/source/common/router/rds_impl.cc index 5d2a000f6d769..6a3e8c545cae5 100644 --- a/source/common/router/rds_impl.cc +++ b/source/common/router/rds_impl.cc @@ -21,7 +21,7 @@ namespace Envoy { namespace Router { -RouteConfigProviderSharedPtr RouteConfigProviderUtil::create( +RouteConfigProviderPtr RouteConfigProviderUtil::create( const envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& config, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix, @@ -29,41 +29,45 @@ RouteConfigProviderSharedPtr RouteConfigProviderUtil::create( switch (config.route_specifier_case()) { case envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager:: kRouteConfig: - return route_config_provider_manager.getStaticRouteConfigProvider(config.route_config(), - factory_context); + return route_config_provider_manager.createStaticRouteConfigProvider(config.route_config(), + factory_context); case envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager::kRds: - return route_config_provider_manager.getRdsRouteConfigProvider(config.rds(), factory_context, - stat_prefix); + return route_config_provider_manager.createRdsRouteConfigProvider(config.rds(), factory_context, + stat_prefix); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } StaticRouteConfigProviderImpl::StaticRouteConfigProviderImpl( const envoy::api::v2::RouteConfiguration& config, - Server::Configuration::FactoryContext& factory_context) + Server::Configuration::FactoryContext& factory_context, + RouteConfigProviderManagerImpl& route_config_provider_manager) : config_(new ConfigImpl(config, factory_context, true)), route_config_proto_{config}, - last_updated_(factory_context.systemTimeSource().currentTime()) {} + last_updated_(factory_context.systemTimeSource().currentTime()), + route_config_provider_manager_(route_config_provider_manager) { + route_config_provider_manager_.static_route_config_providers_.insert(this); +} + +StaticRouteConfigProviderImpl::~StaticRouteConfigProviderImpl() { + route_config_provider_manager_.static_route_config_providers_.erase(this); +} // TODO(htuch): If support for multiple clusters is added per #1170 cluster_name_ // initialization needs to be fixed. -RdsRouteConfigProviderImpl::RdsRouteConfigProviderImpl( +RdsRouteConfigSubscription::RdsRouteConfigSubscription( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, const std::string& manager_identifier, Server::Configuration::FactoryContext& factory_context, - const std::string& stat_prefix, RouteConfigProviderManagerImpl& route_config_provider_manager) - : factory_context_(factory_context), tls_(factory_context.threadLocal().allocateSlot()), - route_config_name_(rds.route_config_name()), + const std::string& stat_prefix, + Envoy::Router::RouteConfigProviderManagerImpl& route_config_provider_manager) + : route_config_name_(rds.route_config_name()), scope_(factory_context.scope().createScope(stat_prefix + "rds." + route_config_name_ + ".")), stats_({ALL_RDS_STATS(POOL_COUNTER(*scope_))}), route_config_provider_manager_(route_config_provider_manager), - manager_identifier_(manager_identifier), + manager_identifier_(manager_identifier), time_source_(factory_context.systemTimeSource()), last_updated_(factory_context.systemTimeSource().currentTime()) { ::Envoy::Config::Utility::checkLocalInfo("rds", factory_context.localInfo()); - ConfigConstSharedPtr initial_config(new NullConfigImpl()); - tls_->set([initial_config](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { - return std::make_shared(initial_config); - }); subscription_ = Envoy::Config::SubscriptionFactory::subscriptionFromConfigSource< envoy::api::v2::RouteConfiguration>( rds.config_source(), factory_context.localInfo().node(), factory_context.dispatcher(), @@ -72,14 +76,14 @@ RdsRouteConfigProviderImpl::RdsRouteConfigProviderImpl( &factory_context]() -> Envoy::Config::Subscription* { return new RdsSubscription(Envoy::Config::Utility::generateStats(*scope_), rds, factory_context.clusterManager(), factory_context.dispatcher(), - factory_context.random(), factory_context.localInfo()); + factory_context.random(), factory_context.localInfo(), + factory_context.scope()); }, "envoy.api.v2.RouteDiscoveryService.FetchRoutes", "envoy.api.v2.RouteDiscoveryService.StreamRoutes"); - config_source_ = MessageUtil::getJsonStringFromMessage(rds.config_source(), true); } -RdsRouteConfigProviderImpl::~RdsRouteConfigProviderImpl() { +RdsRouteConfigSubscription::~RdsRouteConfigSubscription() { // If we get destroyed during initialization, make sure we signal that we "initialized". runInitializeCallbackIfAny(); @@ -87,16 +91,12 @@ RdsRouteConfigProviderImpl::~RdsRouteConfigProviderImpl() { // hold a shared_ptr to it. The RouteConfigProviderManager holds weak_ptrs to the // RdsRouteConfigProviders. Therefore, the map entry for the RdsRouteConfigProvider has to get // cleaned by the RdsRouteConfigProvider's destructor. - route_config_provider_manager_.route_config_providers_.erase(manager_identifier_); + route_config_provider_manager_.route_config_subscriptions_.erase(manager_identifier_); } -Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { - return tls_->getTyped().config_; -} - -void RdsRouteConfigProviderImpl::onConfigUpdate(const ResourceVector& resources, +void RdsRouteConfigSubscription::onConfigUpdate(const ResourceVector& resources, const std::string& version_info) { - last_updated_ = factory_context_.systemTimeSource().currentTime(); + last_updated_ = time_source_.currentTime(); if (resources.empty()) { ENVOY_LOG(debug, "Missing RouteConfiguration for {} in onConfigUpdate()", route_config_name_); @@ -114,35 +114,79 @@ void RdsRouteConfigProviderImpl::onConfigUpdate(const ResourceVector& resources, throw EnvoyException(fmt::format("Unexpected RDS configuration (expecting {}): {}", route_config_name_, route_config.name())); } + const uint64_t new_hash = MessageUtil::hash(route_config); if (!config_info_ || new_hash != config_info_.value().last_config_hash_) { - ConfigConstSharedPtr new_config(new ConfigImpl(route_config, factory_context_, false)); config_info_ = {new_hash, version_info}; + route_config_proto_ = route_config; stats_.config_reload_.inc(); ENVOY_LOG(debug, "rds: loading new configuration: config_name={} hash={}", route_config_name_, new_hash); - tls_->runOnAllThreads( - [this, new_config]() -> void { tls_->getTyped().config_ = new_config; }); - route_config_proto_ = route_config; + for (auto* provider : route_config_providers_) { + provider->onConfigUpdate(); + } } + runInitializeCallbackIfAny(); } -void RdsRouteConfigProviderImpl::onConfigUpdateFailed(const EnvoyException*) { +void RdsRouteConfigSubscription::onConfigUpdateFailed(const EnvoyException*) { // We need to allow server startup to continue, even if we have a bad // config. runInitializeCallbackIfAny(); } -void RdsRouteConfigProviderImpl::runInitializeCallbackIfAny() { +void RdsRouteConfigSubscription::registerInitTarget(Init::Manager& init_manager) { + init_manager.registerTarget(*this); +} + +void RdsRouteConfigSubscription::runInitializeCallbackIfAny() { if (initialize_callback_) { initialize_callback_(); initialize_callback_ = nullptr; } } -void RdsRouteConfigProviderImpl::registerInitTarget(Init::Manager& init_manager) { - init_manager.registerTarget(*this); +RdsRouteConfigProviderImpl::RdsRouteConfigProviderImpl( + RdsRouteConfigSubscriptionSharedPtr&& subscription, + Server::Configuration::FactoryContext& factory_context) + : subscription_(std::move(subscription)), factory_context_(factory_context), + tls_(factory_context.threadLocal().allocateSlot()) { + ConfigConstSharedPtr initial_config; + if (subscription_->config_info_.has_value()) { + initial_config = + std::make_shared(subscription_->route_config_proto_, factory_context_, false); + } else { + initial_config = std::make_shared(); + } + tls_->set([initial_config](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { + return std::make_shared(initial_config); + }); + subscription_->route_config_providers_.insert(this); +} + +RdsRouteConfigProviderImpl::~RdsRouteConfigProviderImpl() { + subscription_->route_config_providers_.erase(this); +} + +Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() { + return tls_->getTyped().config_; +} + +absl::optional RdsRouteConfigProviderImpl::configInfo() const { + if (!subscription_->config_info_) { + return {}; + } else { + return ConfigInfo{subscription_->route_config_proto_, + subscription_->config_info_.value().last_config_version_}; + } +} + +void RdsRouteConfigProviderImpl::onConfigUpdate() { + ConfigConstSharedPtr new_config( + new ConfigImpl(subscription_->route_config_proto_, factory_context_, false)); + tls_->runOnAllThreads( + [this, new_config]() -> void { tls_->getTyped().config_ = new_config; }); } RouteConfigProviderManagerImpl::RouteConfigProviderManagerImpl(Server::Admin& admin) { @@ -150,105 +194,83 @@ RouteConfigProviderManagerImpl::RouteConfigProviderManagerImpl(Server::Admin& ad admin.getConfigTracker().add("routes", [this] { return dumpRouteConfigs(); }); // ConfigTracker keys must be unique. We are asserting that no one has stolen the "routes" key // from us, since the returned entry will be nullptr if the key already exists. - RELEASE_ASSERT(config_tracker_entry_); + RELEASE_ASSERT(config_tracker_entry_, ""); } -std::vector -RouteConfigProviderManagerImpl::getRdsRouteConfigProviders() { - std::vector ret; - ret.reserve(route_config_providers_.size()); - for (const auto& element : route_config_providers_) { - // Because the RouteConfigProviderManager's weak_ptrs only get cleaned up - // in the RdsRouteConfigProviderImpl destructor, and the single threaded nature - // of this code, locking the weak_ptr will not fail. - RouteConfigProviderSharedPtr provider = element.second.lock(); - ASSERT(provider); - ret.push_back(provider); - } - return ret; -}; - -std::vector -RouteConfigProviderManagerImpl::getStaticRouteConfigProviders() { - std::vector providers_strong; - // Collect non-expired providers. - for (const auto& weak_provider : static_route_config_providers_) { - const auto strong_provider = weak_provider.lock(); - if (strong_provider != nullptr) { - providers_strong.push_back(strong_provider); - } - } - - // Replace our stored list of weak_ptrs with the filtered list. - static_route_config_providers_.assign(providers_strong.begin(), providers_strong.end()); - - return providers_strong; -}; - -Router::RouteConfigProviderSharedPtr RouteConfigProviderManagerImpl::getRdsRouteConfigProvider( +Router::RouteConfigProviderPtr RouteConfigProviderManagerImpl::createRdsRouteConfigProvider( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix) { - // RdsRouteConfigProviders are unique based on their serialized RDS config. + // RdsRouteConfigSubscriptions are unique based on their serialized RDS config. // TODO(htuch): Full serialization here gives large IDs, could get away with a // strong hash instead. const std::string manager_identifier = rds.SerializeAsString(); - auto it = route_config_providers_.find(manager_identifier); - if (it == route_config_providers_.end()) { + RdsRouteConfigSubscriptionSharedPtr subscription; + + auto it = route_config_subscriptions_.find(manager_identifier); + if (it == route_config_subscriptions_.end()) { // std::make_shared does not work for classes with private constructors. There are ways // around it. However, since this is not a performance critical path we err on the side // of simplicity. - std::shared_ptr new_provider{new RdsRouteConfigProviderImpl( - rds, manager_identifier, factory_context, stat_prefix, *this)}; + subscription.reset(new RdsRouteConfigSubscription(rds, manager_identifier, factory_context, + stat_prefix, *this)); - new_provider->registerInitTarget(factory_context.initManager()); + subscription->registerInitTarget(factory_context.initManager()); - route_config_providers_.insert({manager_identifier, new_provider}); - - return new_provider; + route_config_subscriptions_.insert({manager_identifier, subscription}); + } else { + // Because the RouteConfigProviderManager's weak_ptrs only get cleaned up + // in the RdsRouteConfigSubscription destructor, and the single threaded nature + // of this code, locking the weak_ptr will not fail. + subscription = it->second.lock(); } + ASSERT(subscription); - // Because the RouteConfigProviderManager's weak_ptrs only get cleaned up - // in the RdsRouteConfigProviderImpl destructor, and the single threaded nature - // of this code, locking the weak_ptr will not fail. - Router::RouteConfigProviderSharedPtr new_provider = it->second.lock(); - ASSERT(new_provider); + Router::RouteConfigProviderPtr new_provider{ + new RdsRouteConfigProviderImpl(std::move(subscription), factory_context)}; return new_provider; -}; +} -RouteConfigProviderSharedPtr RouteConfigProviderManagerImpl::getStaticRouteConfigProvider( +RouteConfigProviderPtr RouteConfigProviderManagerImpl::createStaticRouteConfigProvider( const envoy::api::v2::RouteConfiguration& route_config, Server::Configuration::FactoryContext& factory_context) { auto provider = - std::make_shared(std::move(route_config), factory_context); - static_route_config_providers_.push_back(provider); + absl::make_unique(route_config, factory_context, *this); + static_route_config_providers_.insert(provider.get()); return provider; } -ProtobufTypes::MessagePtr RouteConfigProviderManagerImpl::dumpRouteConfigs() { +std::unique_ptr +RouteConfigProviderManagerImpl::dumpRouteConfigs() const { auto config_dump = std::make_unique(); - for (const auto& provider : getRdsRouteConfigProviders()) { - auto config_info = provider->configInfo(); - if (config_info) { + for (const auto& element : route_config_subscriptions_) { + // Because the RouteConfigProviderManager's weak_ptrs only get cleaned up + // in the RdsRouteConfigSubscription destructor, and the single threaded nature + // of this code, locking the weak_ptr will not fail. + auto subscription = element.second.lock(); + ASSERT(subscription); + ASSERT(subscription->route_config_providers_.size() > 0); + + if (subscription->config_info_) { auto* dynamic_config = config_dump->mutable_dynamic_route_configs()->Add(); - dynamic_config->set_version_info(config_info.value().version_); - dynamic_config->mutable_route_config()->MergeFrom(config_info.value().config_); - TimestampUtil::systemClockToTimestamp(provider->lastUpdated(), - *(dynamic_config->mutable_last_updated())); + dynamic_config->set_version_info(subscription->config_info_.value().last_config_version_); + dynamic_config->mutable_route_config()->MergeFrom(subscription->route_config_proto_); + TimestampUtil::systemClockToTimestamp(subscription->last_updated_, + *dynamic_config->mutable_last_updated()); } } - for (const auto& provider : getStaticRouteConfigProviders()) { + for (const auto& provider : static_route_config_providers_) { ASSERT(provider->configInfo()); auto* static_config = config_dump->mutable_static_route_configs()->Add(); static_config->mutable_route_config()->MergeFrom(provider->configInfo().value().config_); TimestampUtil::systemClockToTimestamp(provider->lastUpdated(), - *(static_config->mutable_last_updated())); + *static_config->mutable_last_updated()); } - return ProtobufTypes::MessagePtr{std::move(config_dump)}; + return config_dump; } } // namespace Router diff --git a/source/common/router/rds_impl.h b/source/common/router/rds_impl.h index 60f39aa2a809a..b859b0bca41ae 100644 --- a/source/common/router/rds_impl.h +++ b/source/common/router/rds_impl.h @@ -4,7 +4,9 @@ #include #include #include +#include +#include "envoy/admin/v2alpha/config_dump.pb.h" #include "envoy/api/v2/rds.pb.h" #include "envoy/api/v2/route/route.pb.h" #include "envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.pb.h" @@ -34,20 +36,24 @@ class RouteConfigProviderUtil { * @return RouteConfigProviderPtr a new route configuration provider based on the supplied proto * configuration. */ - static RouteConfigProviderSharedPtr + static RouteConfigProviderPtr create(const envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& config, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix, RouteConfigProviderManager& route_config_provider_manager); }; +class RouteConfigProviderManagerImpl; + /** * Implementation of RouteConfigProvider that holds a static route configuration. */ class StaticRouteConfigProviderImpl : public RouteConfigProvider { public: StaticRouteConfigProviderImpl(const envoy::api::v2::RouteConfiguration& config, - Server::Configuration::FactoryContext& factory_context); + Server::Configuration::FactoryContext& factory_context, + RouteConfigProviderManagerImpl& route_config_provider_manager); + ~StaticRouteConfigProviderImpl(); // Router::RouteConfigProvider Router::ConfigConstSharedPtr config() override { return config_; } @@ -60,6 +66,7 @@ class StaticRouteConfigProviderImpl : public RouteConfigProvider { ConfigConstSharedPtr config_; envoy::api::v2::RouteConfiguration route_config_proto_; SystemTime last_updated_; + RouteConfigProviderManagerImpl& route_config_provider_manager_; }; /** @@ -79,19 +86,18 @@ struct RdsStats { ALL_RDS_STATS(GENERATE_COUNTER_STRUCT) }; -class RouteConfigProviderManagerImpl; +class RdsRouteConfigProviderImpl; /** - * Implementation of RouteConfigProvider that fetches the route configuration dynamically using - * the RDS API. + * A class that fetches the route configuration dynamically using the RDS API and updates them to + * RDS config providers. */ -class RdsRouteConfigProviderImpl - : public RouteConfigProvider, - public Init::Target, +class RdsRouteConfigSubscription + : public Init::Target, Envoy::Config::SubscriptionCallbacks, Logger::Loggable { public: - ~RdsRouteConfigProviderImpl(); + ~RdsRouteConfigSubscription(); // Init::Target void initialize(std::function callback) override { @@ -99,17 +105,6 @@ class RdsRouteConfigProviderImpl subscription_->start({route_config_name_}, *this); } - // Router::RouteConfigProvider - Router::ConfigConstSharedPtr config() override; - absl::optional configInfo() const override { - if (!config_info_) { - return {}; - } else { - return ConfigInfo{route_config_proto_, config_info_.value().last_config_version_}; - } - } - SystemTime lastUpdated() const override { return last_updated_; } - // Config::SubscriptionCallbacks void onConfigUpdate(const ResourceVector& resources, const std::string& version_info) override; void onConfigUpdateFailed(const EnvoyException* e) override; @@ -118,18 +113,12 @@ class RdsRouteConfigProviderImpl } private: - struct ThreadLocalConfig : public ThreadLocal::ThreadLocalObject { - ThreadLocalConfig(ConfigConstSharedPtr initial_config) : config_(initial_config) {} - - ConfigConstSharedPtr config_; - }; - struct LastConfigInfo { uint64_t last_config_hash_; std::string last_config_version_; }; - RdsRouteConfigProviderImpl( + RdsRouteConfigSubscription( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, const std::string& manager_identifier, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix, @@ -138,19 +127,55 @@ class RdsRouteConfigProviderImpl void registerInitTarget(Init::Manager& init_manager); void runInitializeCallbackIfAny(); - Server::Configuration::FactoryContext& factory_context_; std::unique_ptr> subscription_; - ThreadLocal::SlotPtr tls_; - std::string config_source_; + std::function initialize_callback_; const std::string route_config_name_; - absl::optional config_info_; Stats::ScopePtr scope_; RdsStats stats_; - std::function initialize_callback_; RouteConfigProviderManagerImpl& route_config_provider_manager_; const std::string manager_identifier_; - envoy::api::v2::RouteConfiguration route_config_proto_; + SystemTimeSource& time_source_; SystemTime last_updated_; + absl::optional config_info_; + envoy::api::v2::RouteConfiguration route_config_proto_; + std::unordered_set route_config_providers_; + + friend class RouteConfigProviderManagerImpl; + friend class RdsRouteConfigProviderImpl; +}; + +typedef std::shared_ptr RdsRouteConfigSubscriptionSharedPtr; + +/** + * Implementation of RouteConfigProvider that fetches the route configuration dynamically using + * the subscription. + */ +class RdsRouteConfigProviderImpl : public RouteConfigProvider, + Logger::Loggable { +public: + ~RdsRouteConfigProviderImpl(); + + RdsRouteConfigSubscription& subscription() { return *subscription_; } + void onConfigUpdate(); + + // Router::RouteConfigProvider + Router::ConfigConstSharedPtr config() override; + absl::optional configInfo() const override; + SystemTime lastUpdated() const override { return subscription_->last_updated_; } + +private: + struct ThreadLocalConfig : public ThreadLocal::ThreadLocalObject { + ThreadLocalConfig(ConfigConstSharedPtr initial_config) : config_(initial_config) {} + + ConfigConstSharedPtr config_; + }; + + RdsRouteConfigProviderImpl(RdsRouteConfigSubscriptionSharedPtr&& subscription, + Server::Configuration::FactoryContext& factory_context); + + RdsRouteConfigSubscriptionSharedPtr subscription_; + Server::Configuration::FactoryContext& factory_context_; + ThreadLocal::SlotPtr tls_; friend class RouteConfigProviderManagerImpl; }; @@ -160,32 +185,29 @@ class RouteConfigProviderManagerImpl : public RouteConfigProviderManager, public: RouteConfigProviderManagerImpl(Server::Admin& admin); - // RouteConfigProviderManager - std::vector getRdsRouteConfigProviders() override; - std::vector getStaticRouteConfigProviders() override; + std::unique_ptr dumpRouteConfigs() const; - RouteConfigProviderSharedPtr getRdsRouteConfigProvider( + // RouteConfigProviderManager + RouteConfigProviderPtr createRdsRouteConfigProvider( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix) override; - RouteConfigProviderSharedPtr - getStaticRouteConfigProvider(const envoy::api::v2::RouteConfiguration& route_config, - Server::Configuration::FactoryContext& factory_context) override; + RouteConfigProviderPtr + createStaticRouteConfigProvider(const envoy::api::v2::RouteConfiguration& route_config, + Server::Configuration::FactoryContext& factory_context) override; private: - ProtobufTypes::MessagePtr dumpRouteConfigs(); - // TODO(jsedgwick) These two members are prime candidates for the owned-entry list/map // as in ConfigTracker. I.e. the ProviderImpls would have an EntryOwner for these lists - // Then the lifetime management stuff is centralized and opaque. Plus the copypasta - // in getRdsRouteConfigProviders()/getStaticRouteConfigProviders() goes away. - std::unordered_map> - route_config_providers_; - std::vector> static_route_config_providers_; + // Then the lifetime management stuff is centralized and opaque. + std::unordered_map> + route_config_subscriptions_; + std::unordered_set static_route_config_providers_; Server::ConfigTracker::EntryOwnerPtr config_tracker_entry_; - friend class RdsRouteConfigProviderImpl; + friend class RdsRouteConfigSubscription; + friend class StaticRouteConfigProviderImpl; }; } // namespace Router diff --git a/source/common/router/rds_subscription.cc b/source/common/router/rds_subscription.cc index fcfebd95e8727..53d8b172a069b 100644 --- a/source/common/router/rds_subscription.cc +++ b/source/common/router/rds_subscription.cc @@ -1,5 +1,7 @@ #include "common/router/rds_subscription.h" +#include "envoy/stats/stats.h" + #include "common/common/assert.h" #include "common/common/fmt.h" #include "common/config/rds_json.h" @@ -14,12 +16,12 @@ RdsSubscription::RdsSubscription( Envoy::Config::SubscriptionStats stats, const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Upstream::ClusterManager& cm, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, - const LocalInfo::LocalInfo& local_info) + const LocalInfo::LocalInfo& local_info, const Stats::Scope& scope) : RestApiFetcher(cm, rds.config_source().api_config_source().cluster_names()[0], dispatcher, random, Envoy::Config::Utility::apiConfigSourceRefreshDelay( rds.config_source().api_config_source())), - local_info_(local_info), stats_(stats) { + local_info_(local_info), stats_(stats), scope_(scope) { const auto& api_config_source = rds.config_source().api_config_source(); UNREFERENCED_PARAMETER(api_config_source); // If we are building an RdsSubscription, the ConfigSource should be REST_LEGACY. @@ -45,7 +47,8 @@ void RdsSubscription::parseResponse(const Http::Message& response) { const std::string response_body = response.bodyAsString(); Json::ObjectSharedPtr response_json = Json::Factory::loadFromString(response_body); Protobuf::RepeatedPtrField resources; - Envoy::Config::RdsJson::translateRouteConfiguration(*response_json, *resources.Add()); + Envoy::Config::RdsJson::translateRouteConfiguration(*response_json, *resources.Add(), + scope_.statsOptions()); resources[0].set_name(route_config_name_); std::pair hash = Envoy::Config::Utility::computeHashedVersion(response_body); diff --git a/source/common/router/rds_subscription.h b/source/common/router/rds_subscription.h index d27217af5616a..a44e77d152a5c 100644 --- a/source/common/router/rds_subscription.h +++ b/source/common/router/rds_subscription.h @@ -5,6 +5,7 @@ #include "envoy/config/filter/network/http_connection_manager/v2/http_connection_manager.pb.h" #include "envoy/config/subscription.h" +#include "envoy/stats/stats.h" #include "common/common/assert.h" #include "common/http/rest_api_fetcher.h" @@ -23,7 +24,8 @@ class RdsSubscription : public Http::RestApiFetcher, RdsSubscription(Envoy::Config::SubscriptionStats stats, const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Upstream::ClusterManager& cm, Event::Dispatcher& dispatcher, - Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info); + Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info, + const Stats::Scope& scope); private: // Config::Subscription @@ -42,7 +44,7 @@ class RdsSubscription : public Http::RestApiFetcher, // We should never hit this at runtime, since this legacy adapter is only used by HTTP // connection manager that doesn't do dynamic modification of resources. UNREFERENCED_PARAMETER(resources); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } // Http::RestApiFetcher @@ -55,6 +57,7 @@ class RdsSubscription : public Http::RestApiFetcher, const LocalInfo::LocalInfo& local_info_; Envoy::Config::SubscriptionCallbacks* callbacks_ = nullptr; Envoy::Config::SubscriptionStats stats_; + const Stats::Scope& scope_; }; } // namespace Router diff --git a/source/common/router/retry_state_impl.cc b/source/common/router/retry_state_impl.cc index e925da9193f8c..1aba35700e1aa 100644 --- a/source/common/router/retry_state_impl.cc +++ b/source/common/router/retry_state_impl.cc @@ -71,22 +71,20 @@ RetryStateImpl::RetryStateImpl(const RetryPolicy& route_policy, Http::HeaderMap& // Merge in the route policy. retry_on_ |= route_policy.retryOn(); retries_remaining_ = std::max(retries_remaining_, route_policy.numRetries()); + const uint32_t base = runtime_.snapshot().getInteger("upstream.base_retry_backoff_ms", 25); + // Cap the max interval to 10 times the base interval to ensure reasonable backoff intervals. + backoff_strategy_ = std::make_unique(base, base * 10, random_); } RetryStateImpl::~RetryStateImpl() { resetRetry(); } void RetryStateImpl::enableBackoffTimer() { - // We use a fully jittered exponential backoff algorithm. - current_retry_++; - uint32_t multiplier = (1 << current_retry_) - 1; - uint64_t base = runtime_.snapshot().getInteger("upstream.base_retry_backoff_ms", 25); - uint64_t timeout = random_.random() % (base * multiplier); - if (!retry_timer_) { retry_timer_ = dispatcher_.createTimer([this]() -> void { callback_(); }); } - retry_timer_->enableTimer(std::chrono::milliseconds(timeout)); + // We use a fully jittered exponential backoff algorithm. + retry_timer_->enableTimer(std::chrono::milliseconds(backoff_strategy_->nextBackOffMs())); } uint32_t RetryStateImpl::parseRetryOn(absl::string_view config) { diff --git a/source/common/router/retry_state_impl.h b/source/common/router/retry_state_impl.h index 8f01320b83a3a..71d7bf38cb5a4 100644 --- a/source/common/router/retry_state_impl.h +++ b/source/common/router/retry_state_impl.h @@ -10,6 +10,8 @@ #include "envoy/runtime/runtime.h" #include "envoy/upstream/upstream.h" +#include "common/common/backoff_strategy.h" + #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -55,10 +57,10 @@ class RetryStateImpl : public RetryState { Event::Dispatcher& dispatcher_; uint32_t retry_on_{}; uint32_t retries_remaining_{1}; - uint32_t current_retry_{}; DoRetryCallback callback_; Event::TimerPtr retry_timer_; Upstream::ResourcePriority priority_; + BackOffStrategyPtr backoff_strategy_; }; } // namespace Router diff --git a/source/common/router/router_ratelimit.cc b/source/common/router/router_ratelimit.cc index 4713f6e467e0f..f1f084265cd5a 100644 --- a/source/common/router/router_ratelimit.cc +++ b/source/common/router/router_ratelimit.cc @@ -108,7 +108,7 @@ RateLimitPolicyEntryImpl::RateLimitPolicyEntryImpl(const envoy::api::v2::route:: actions_.emplace_back(new HeaderValueMatchAction(action.header_value_match())); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } } diff --git a/source/common/ssl/BUILD b/source/common/ssl/BUILD index 486d0da347dca..fe3c6db0d1824 100644 --- a/source/common/ssl/BUILD +++ b/source/common/ssl/BUILD @@ -16,6 +16,7 @@ envoy_cc_library( deps = [ ":context_config_lib", ":context_lib", + ":utility_lib", "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", "//source/common/common:assert_lib", @@ -57,6 +58,7 @@ envoy_cc_library( ], external_deps = ["ssl"], deps = [ + ":utility_lib", "//include/envoy/runtime:runtime_interface", "//include/envoy/ssl:context_config_interface", "//include/envoy/ssl:context_interface", @@ -79,3 +81,12 @@ envoy_cc_library( "@envoy_api//envoy/api/v2/auth:cert_cc", ], ) + +envoy_cc_library( + name = "utility_lib", + srcs = ["utility.cc"], + hdrs = ["utility.h"], + external_deps = [ + "ssl", + ], +) diff --git a/source/common/ssl/context_config_impl.cc b/source/common/ssl/context_config_impl.cc index 374cd2945b50a..1b2fee09b383d 100644 --- a/source/common/ssl/context_config_impl.cc +++ b/source/common/ssl/context_config_impl.cc @@ -132,10 +132,10 @@ unsigned ContextConfigImpl::tlsVersionFromProto( case envoy::api::v2::auth::TlsParameters::TLSv1_3: return TLS1_3_VERSION; default: - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } ClientContextConfigImpl::ClientContextConfigImpl( diff --git a/source/common/ssl/context_impl.cc b/source/common/ssl/context_impl.cc index ec9ca6dff93eb..d2960feebab8a 100644 --- a/source/common/ssl/context_impl.cc +++ b/source/common/ssl/context_impl.cc @@ -12,6 +12,7 @@ #include "common/common/base64.h" #include "common/common/fmt.h" #include "common/common/hex.h" +#include "common/ssl/utility.h" #include "openssl/hmac.h" #include "openssl/rand.h" @@ -23,25 +24,23 @@ namespace Ssl { int ContextImpl::sslContextIndex() { CONSTRUCT_ON_FIRST_USE(int, []() -> int { int ssl_context_index = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); - RELEASE_ASSERT(ssl_context_index >= 0); + RELEASE_ASSERT(ssl_context_index >= 0, ""); return ssl_context_index; }()); } -ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ContextConfig& config) - : parent_(parent), ctx_(SSL_CTX_new(TLS_method())), scope_(scope), - stats_(generateStats(scope)) { - RELEASE_ASSERT(ctx_); +ContextImpl::ContextImpl(Stats::Scope& scope, const ContextConfig& config) + : ctx_(SSL_CTX_new(TLS_method())), scope_(scope), stats_(generateStats(scope)) { + RELEASE_ASSERT(ctx_, ""); int rc = SSL_CTX_set_ex_data(ctx_.get(), sslContextIndex(), this); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); rc = SSL_CTX_set_min_proto_version(ctx_.get(), config.minProtocolVersion()); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); rc = SSL_CTX_set_max_proto_version(ctx_.get(), config.maxProtocolVersion()); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); if (!SSL_CTX_set_strict_cipher_list(ctx_.get(), config.cipherSuites().c_str())) { throw EnvoyException( @@ -58,7 +57,7 @@ ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, ca_file_path_ = config.caCertPath(); bssl::UniquePtr bio( BIO_new_mem_buf(const_cast(config.caCert().data()), config.caCert().size())); - RELEASE_ASSERT(bio != nullptr); + RELEASE_ASSERT(bio != nullptr, ""); // Based on BoringSSL's X509_load_cert_crl_file(). bssl::UniquePtr list( PEM_X509_INFO_read_bio(bio.get(), nullptr, nullptr, nullptr)); @@ -100,7 +99,7 @@ ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, bssl::UniquePtr bio( BIO_new_mem_buf(const_cast(config.certificateRevocationList().data()), config.certificateRevocationList().size())); - RELEASE_ASSERT(bio != nullptr); + RELEASE_ASSERT(bio != nullptr, ""); // Based on BoringSSL's X509_load_cert_crl_file(). bssl::UniquePtr list( @@ -167,7 +166,7 @@ ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, cert_chain_file_path_ = config.certChainPath(); bssl::UniquePtr bio( BIO_new_mem_buf(const_cast(config.certChain().data()), config.certChain().size())); - RELEASE_ASSERT(bio != nullptr); + RELEASE_ASSERT(bio != nullptr, ""); cert_chain_.reset(PEM_read_bio_X509_AUX(bio.get(), nullptr, nullptr, nullptr)); if (cert_chain_ == nullptr || !SSL_CTX_use_certificate(ctx_.get(), cert_chain_.get())) { throw EnvoyException( @@ -198,7 +197,7 @@ ContextImpl::ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, // Load private key. bio.reset( BIO_new_mem_buf(const_cast(config.privateKey().data()), config.privateKey().size())); - RELEASE_ASSERT(bio != nullptr); + RELEASE_ASSERT(bio != nullptr, ""); bssl::UniquePtr pkey(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); if (pkey == nullptr || !SSL_CTX_use_PrivateKey(ctx_.get(), pkey.get())) { throw EnvoyException( @@ -380,7 +379,7 @@ bool ContextImpl::verifyCertificateHashList( std::vector computed_hash(SHA256_DIGEST_LENGTH); unsigned int n; X509_digest(cert, EVP_sha256(), computed_hash.data(), &n); - RELEASE_ASSERT(n == computed_hash.size()); + RELEASE_ASSERT(n == computed_hash.size(), ""); for (const auto& expected_hash : expected_hashes) { if (computed_hash == expected_hash) { @@ -446,7 +445,7 @@ std::string ContextImpl::getCaCertInformation() const { return ""; } return fmt::format("Certificate Path: {}, Serial Number: {}, Days until Expiration: {}", - getCaFileName(), getSerialNumber(ca_cert_.get()), + getCaFileName(), Utility::getSerialNumberFromCertificate(*ca_cert_.get()), getDaysUntilExpiration(ca_cert_.get())); } @@ -455,34 +454,18 @@ std::string ContextImpl::getCertChainInformation() const { return ""; } return fmt::format("Certificate Path: {}, Serial Number: {}, Days until Expiration: {}", - getCertChainFileName(), getSerialNumber(cert_chain_.get()), + getCertChainFileName(), + Utility::getSerialNumberFromCertificate(*cert_chain_.get()), getDaysUntilExpiration(cert_chain_.get())); } -std::string ContextImpl::getSerialNumber(const X509* cert) { - ASSERT(cert); - ASN1_INTEGER* serial_number = X509_get_serialNumber(const_cast(cert)); - BIGNUM num_bn; - BN_init(&num_bn); - ASN1_INTEGER_to_BN(serial_number, &num_bn); - char* char_serial_number = BN_bn2hex(&num_bn); - BN_free(&num_bn); - if (char_serial_number != nullptr) { - std::string serial_number(char_serial_number); - OPENSSL_free(char_serial_number); - return serial_number; - } - return ""; -} - -ClientContextImpl::ClientContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ClientContextConfig& config) - : ContextImpl(parent, scope, config), server_name_indication_(config.serverNameIndication()), +ClientContextImpl::ClientContextImpl(Stats::Scope& scope, const ClientContextConfig& config) + : ContextImpl(scope, config), server_name_indication_(config.serverNameIndication()), allow_renegotiation_(config.allowRenegotiation()) { if (!parsed_alpn_protocols_.empty()) { int rc = SSL_CTX_set_alpn_protos(ctx_.get(), &parsed_alpn_protocols_[0], parsed_alpn_protocols_.size()); - RELEASE_ASSERT(rc == 0); + RELEASE_ASSERT(rc == 0, ""); } } @@ -491,7 +474,7 @@ bssl::UniquePtr ClientContextImpl::newSsl() const { if (!server_name_indication_.empty()) { int rc = SSL_set_tlsext_host_name(ssl_con.get(), server_name_indication_.c_str()); - RELEASE_ASSERT(rc); + RELEASE_ASSERT(rc, ""); } if (allow_renegotiation_) { @@ -501,11 +484,10 @@ bssl::UniquePtr ClientContextImpl::newSsl() const { return ssl_con; } -ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ServerContextConfig& config, +ServerContextImpl::ServerContextImpl(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names, Runtime::Loader& runtime) - : ContextImpl(parent, scope, config), runtime_(runtime), + : ContextImpl(scope, config), runtime_(runtime), session_ticket_keys_(config.sessionTicketKeys()) { if (config.certChain().empty()) { throw EnvoyException("Server TlsCertificates must have a certificate specified"); @@ -513,11 +495,11 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s if (!config.caCert().empty()) { bssl::UniquePtr bio( BIO_new_mem_buf(const_cast(config.caCert().data()), config.caCert().size())); - RELEASE_ASSERT(bio != nullptr); + RELEASE_ASSERT(bio != nullptr, ""); // Based on BoringSSL's SSL_add_file_cert_subjects_to_stack(). bssl::UniquePtr list(sk_X509_NAME_new( [](const X509_NAME** a, const X509_NAME** b) -> int { return X509_NAME_cmp(*a, *b); })); - RELEASE_ASSERT(list != nullptr); + RELEASE_ASSERT(list != nullptr, ""); for (;;) { bssl::UniquePtr cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); if (cert == nullptr) { @@ -574,7 +556,7 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s ContextImpl* context_impl = static_cast( SSL_CTX_get_ex_data(SSL_get_SSL_CTX(ssl), sslContextIndex())); ServerContextImpl* server_context_impl = dynamic_cast(context_impl); - RELEASE_ASSERT(server_context_impl != nullptr); // for Coverity + RELEASE_ASSERT(server_context_impl != nullptr, ""); // for Coverity return server_context_impl->sessionTicketProcess(ssl, key_name, iv, ctx, hmac_ctx, encrypt); }); @@ -584,25 +566,25 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s unsigned session_context_len = 0; EVP_MD_CTX md; int rc = EVP_DigestInit(&md, EVP_sha256()); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); // Hash the CommonName/SANs of the server certificate. This makes sure that // sessions can only be resumed to a certificate for the same name, but allows // resuming to unique certs in the case that different Envoy instances each have // their own certs. X509* cert = SSL_CTX_get0_certificate(ctx_.get()); - RELEASE_ASSERT(cert != nullptr); + RELEASE_ASSERT(cert != nullptr, ""); X509_NAME* cert_subject = X509_get_subject_name(cert); - RELEASE_ASSERT(cert_subject != nullptr); + RELEASE_ASSERT(cert_subject != nullptr, ""); int cn_index = X509_NAME_get_index_by_NID(cert_subject, NID_commonName, -1); // It's possible that the certificate doesn't have CommonName, but has SANs. if (cn_index >= 0) { X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); - RELEASE_ASSERT(cn_entry != nullptr); + RELEASE_ASSERT(cn_entry != nullptr, ""); ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); - RELEASE_ASSERT(ASN1_STRING_length(cn_asn1) > 0); + RELEASE_ASSERT(ASN1_STRING_length(cn_asn1) > 0, ""); rc = EVP_DigestUpdate(&md, ASN1_STRING_data(cn_asn1), ASN1_STRING_length(cn_asn1)); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } bssl::UniquePtr san_names( @@ -611,19 +593,19 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s for (const GENERAL_NAME* san : san_names.get()) { if (san->type == GEN_DNS || san->type == GEN_URI) { rc = EVP_DigestUpdate(&md, ASN1_STRING_data(san->d.ia5), ASN1_STRING_length(san->d.ia5)); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } } } else { // Make sure that we have either CommonName or SANs. - RELEASE_ASSERT(cn_index >= 0); + RELEASE_ASSERT(cn_index >= 0, ""); } X509_NAME* cert_issuer_name = X509_get_issuer_name(cert); rc = X509_NAME_digest(cert_issuer_name, EVP_sha256(), session_context_buf, &session_context_len); - RELEASE_ASSERT(rc == 1 && session_context_len == SHA256_DIGEST_LENGTH); + RELEASE_ASSERT(rc == 1 && session_context_len == SHA256_DIGEST_LENGTH, ""); rc = EVP_DigestUpdate(&md, session_context_buf, session_context_len); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); // Hash all the settings that affect whether the server will allow/accept // the client connection. This ensures that the client is always validated against @@ -631,14 +613,14 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s // is enabled. if (ca_cert_ != nullptr) { rc = X509_digest(ca_cert_.get(), EVP_sha256(), session_context_buf, &session_context_len); - RELEASE_ASSERT(rc == 1 && session_context_len == SHA256_DIGEST_LENGTH); + RELEASE_ASSERT(rc == 1 && session_context_len == SHA256_DIGEST_LENGTH, ""); rc = EVP_DigestUpdate(&md, session_context_buf, session_context_len); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); // verify_subject_alt_name_list_ can only be set with a ca_cert for (const std::string& name : verify_subject_alt_name_list_) { rc = EVP_DigestUpdate(&md, name.data(), name.size()); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } } @@ -646,27 +628,27 @@ ServerContextImpl::ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& s rc = EVP_DigestUpdate(&md, hash.data(), hash.size() * sizeof(std::remove_reference::type::value_type)); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } for (const auto& hash : verify_certificate_spki_list_) { rc = EVP_DigestUpdate(&md, hash.data(), hash.size() * sizeof(std::remove_reference::type::value_type)); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } // Hash configured SNIs for this context, so that sessions cannot be resumed across different // filter chains, even when using the same server certificate. for (const auto& name : server_names) { rc = EVP_DigestUpdate(&md, name.data(), name.size()); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } rc = EVP_DigestFinal(&md, session_context_buf, &session_context_len); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); rc = SSL_CTX_set_session_id_context(ctx_.get(), session_context_buf, session_context_len); - RELEASE_ASSERT(rc == 1); + RELEASE_ASSERT(rc == 1, ""); } int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv, @@ -676,7 +658,7 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv if (encrypt == 1) { // Encrypt - RELEASE_ASSERT(session_ticket_keys_.size() >= 1); + RELEASE_ASSERT(session_ticket_keys_.size() >= 1, ""); // TODO(ggreenway): validate in SDS that session_ticket_keys_ cannot be empty, // or if we allow it to be emptied, reconfigure the context so this callback // isn't set. @@ -692,7 +674,7 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv // This RELEASE_ASSERT is logically a static_assert, but we can't actually get // EVP_CIPHER_key_length(cipher) at compile-time - RELEASE_ASSERT(key.aes_key_.size() == EVP_CIPHER_key_length(cipher)); + RELEASE_ASSERT(key.aes_key_.size() == EVP_CIPHER_key_length(cipher), ""); if (!EVP_EncryptInit_ex(ctx, cipher, nullptr, key.aes_key_.data(), iv)) { return -1; } @@ -713,7 +695,7 @@ int ServerContextImpl::sessionTicketProcess(SSL*, uint8_t* key_name, uint8_t* iv return -1; } - RELEASE_ASSERT(key.aes_key_.size() == EVP_CIPHER_key_length(cipher)); + RELEASE_ASSERT(key.aes_key_.size() == EVP_CIPHER_key_length(cipher), ""); if (!EVP_DecryptInit_ex(ctx, cipher, nullptr, key.aes_key_.data(), iv)) { return -1; } diff --git a/source/common/ssl/context_impl.h b/source/common/ssl/context_impl.h index 4001ff23d9719..698769d5b9201 100644 --- a/source/common/ssl/context_impl.h +++ b/source/common/ssl/context_impl.h @@ -75,8 +75,7 @@ class ContextImpl : public virtual Context { std::string getCertChainInformation() const override; protected: - ContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, const ContextConfig& config); - ~ContextImpl() { parent_.releaseContext(this); } + ContextImpl(Stats::Scope& scope, const ContextConfig& config); /** * The global SSL-library index used for storing a pointer to the context @@ -116,12 +115,13 @@ class ContextImpl : public virtual Context { std::vector parseAlpnProtocols(const std::string& alpn_protocols); static SslStats generateStats(Stats::Scope& scope); + + // TODO: Move helper function to the `Ssl::Utility` namespace. int32_t getDaysUntilExpiration(const X509* cert) const; - static std::string getSerialNumber(const X509* cert); + std::string getCaFileName() const { return ca_file_path_; }; std::string getCertChainFileName() const { return cert_chain_file_path_; }; - ContextManagerImpl& parent_; bssl::UniquePtr ctx_; bool verify_trusted_ca_{false}; std::vector verify_subject_alt_name_list_; @@ -136,10 +136,11 @@ class ContextImpl : public virtual Context { std::string cert_chain_file_path_; }; +typedef std::shared_ptr ContextImplSharedPtr; + class ClientContextImpl : public ContextImpl, public ClientContext { public: - ClientContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ClientContextConfig& config); + ClientContextImpl(Stats::Scope& scope, const ClientContextConfig& config); bssl::UniquePtr newSsl() const override; @@ -150,9 +151,8 @@ class ClientContextImpl : public ContextImpl, public ClientContext { class ServerContextImpl : public ContextImpl, public ServerContext { public: - ServerContextImpl(ContextManagerImpl& parent, Stats::Scope& scope, - const ServerContextConfig& config, const std::vector& server_names, - Runtime::Loader& runtime); + ServerContextImpl(Stats::Scope& scope, const ServerContextConfig& config, + const std::vector& server_names, Runtime::Loader& runtime); private: int alpnSelectCallback(const unsigned char** out, unsigned char* outlen, const unsigned char* in, diff --git a/source/common/ssl/context_manager_impl.cc b/source/common/ssl/context_manager_impl.cc index 1c54c2f018656..b82fd96151af6 100644 --- a/source/common/ssl/context_manager_impl.cc +++ b/source/common/ssl/context_manager_impl.cc @@ -9,47 +9,50 @@ namespace Envoy { namespace Ssl { -ContextManagerImpl::~ContextManagerImpl() { ASSERT(contexts_.empty()); } - -void ContextManagerImpl::releaseContext(Context* context) { - std::unique_lock lock(contexts_lock_); +ContextManagerImpl::~ContextManagerImpl() { + removeEmptyContexts(); + ASSERT(contexts_.empty()); +} - // context may not be found, in the case that a subclass of Context throws - // in it's constructor. In that case the context did not get added, but - // the destructor of Context will run and call releaseContext(). - contexts_.remove(context); +void ContextManagerImpl::removeEmptyContexts() { + contexts_.remove_if([](const std::weak_ptr& n) { return n.expired(); }); } -ClientContextPtr ContextManagerImpl::createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) { - ClientContextPtr context(new ClientContextImpl(*this, scope, config)); - std::unique_lock lock(contexts_lock_); - contexts_.emplace_back(context.get()); +ClientContextSharedPtr +ContextManagerImpl::createSslClientContext(Stats::Scope& scope, const ClientContextConfig& config) { + ClientContextSharedPtr context = std::make_shared(scope, config); + removeEmptyContexts(); + contexts_.emplace_back(context); return context; } -ServerContextPtr +ServerContextSharedPtr ContextManagerImpl::createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) { - ServerContextPtr context(new ServerContextImpl(*this, scope, config, server_names, runtime_)); - std::unique_lock lock(contexts_lock_); - contexts_.emplace_back(context.get()); + ServerContextSharedPtr context = + std::make_shared(scope, config, server_names, runtime_); + removeEmptyContexts(); + contexts_.emplace_back(context); return context; } size_t ContextManagerImpl::daysUntilFirstCertExpires() const { - std::shared_lock lock(contexts_lock_); size_t ret = std::numeric_limits::max(); - for (Context* context : contexts_) { - ret = std::min(context->daysUntilFirstCertExpires(), ret); + for (const auto& ctx_weak_ptr : contexts_) { + ContextSharedPtr context = ctx_weak_ptr.lock(); + if (context) { + ret = std::min(context->daysUntilFirstCertExpires(), ret); + } } return ret; } void ContextManagerImpl::iterateContexts(std::function callback) { - std::shared_lock lock(contexts_lock_); - for (Context* context : contexts_) { - callback(*context); + for (const auto& ctx_weak_ptr : contexts_) { + ContextSharedPtr context = ctx_weak_ptr.lock(); + if (context) { + callback(*context); + } } } diff --git a/source/common/ssl/context_manager_impl.h b/source/common/ssl/context_manager_impl.h index bd31db4f22008..d330e0dabed4f 100644 --- a/source/common/ssl/context_manager_impl.h +++ b/source/common/ssl/context_manager_impl.h @@ -22,26 +22,19 @@ class ContextManagerImpl final : public ContextManager { ContextManagerImpl(Runtime::Loader& runtime) : runtime_(runtime) {} ~ContextManagerImpl(); - /** - * Allocated contexts are owned by the caller. However, we need to be able to iterate them for - * admin purposes. When a caller frees a context it will tell us to release it also from the list - * of contexts. - */ - void releaseContext(Context* context); - // Ssl::ContextManager - Ssl::ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) override; - Ssl::ServerContextPtr + Ssl::ClientContextSharedPtr createSslClientContext(Stats::Scope& scope, + const ClientContextConfig& config) override; + Ssl::ServerContextSharedPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, const std::vector& server_names) override; size_t daysUntilFirstCertExpires() const override; void iterateContexts(std::function callback) override; private: + void removeEmptyContexts(); Runtime::Loader& runtime_; - std::list contexts_; - mutable std::shared_timed_mutex contexts_lock_; + std::list> contexts_; }; } // namespace Ssl diff --git a/source/common/ssl/ssl_socket.cc b/source/common/ssl/ssl_socket.cc index b433358f16532..ec6d9967c6aa8 100644 --- a/source/common/ssl/ssl_socket.cc +++ b/source/common/ssl/ssl_socket.cc @@ -4,6 +4,7 @@ #include "common/common/empty_string.h" #include "common/common/hex.h" #include "common/http/headers.h" +#include "common/ssl/utility.h" #include "absl/strings/str_replace.h" #include "openssl/err.h" @@ -14,8 +15,8 @@ using Envoy::Network::PostIoAction; namespace Envoy { namespace Ssl { -SslSocket::SslSocket(Context& ctx, InitialState state) - : ctx_(dynamic_cast(ctx)), ssl_(ctx_.newSsl()) { +SslSocket::SslSocket(ContextSharedPtr ctx, InitialState state) + : ctx_(std::dynamic_pointer_cast(ctx)), ssl_(ctx_->newSsl()) { if (state == InitialState::Client) { SSL_set_connect_state(ssl_.get()); } else { @@ -98,7 +99,7 @@ PostIoAction SslSocket::doHandshake() { if (rc == 1) { ENVOY_CONN_LOG(debug, "handshake complete", callbacks_->connection()); handshake_complete_ = true; - ctx_.logHandshake(ssl_.get()); + ctx_->logHandshake(ssl_.get()); callbacks_->raiseEvent(Network::ConnectionEvent::Connected); // It's possible that we closed during the handshake callback. @@ -125,7 +126,7 @@ void SslSocket::drainErrorQueue() { while (uint64_t err = ERR_get_error()) { if (ERR_GET_LIB(err) == ERR_LIB_SSL) { if (ERR_GET_REASON(err) == SSL_R_PEER_DID_NOT_RETURN_A_CERTIFICATE) { - ctx_.stats().fail_verify_no_cert_.inc(); + ctx_->stats().fail_verify_no_cert_.inc(); saw_counted_error = true; } else if (ERR_GET_REASON(err) == SSL_R_CERTIFICATE_VERIFY_FAILED) { saw_counted_error = true; @@ -138,7 +139,7 @@ void SslSocket::drainErrorQueue() { ERR_reason_error_string(err)); } if (saw_error && !saw_counted_error) { - ctx_.stats().connection_error_.inc(); + ctx_->stats().connection_error_.inc(); } } @@ -246,7 +247,7 @@ const std::string& SslSocket::sha256PeerCertificateDigest() const { std::vector computed_hash(SHA256_DIGEST_LENGTH); unsigned int n; X509_digest(cert.get(), EVP_sha256(), computed_hash.data(), &n); - RELEASE_ASSERT(n == computed_hash.size()); + RELEASE_ASSERT(n == computed_hash.size(), ""); cached_sha_256_peer_certificate_digest_ = Hex::encode(computed_hash); return cached_sha_256_peer_certificate_digest_; } @@ -262,11 +263,11 @@ const std::string& SslSocket::urlEncodedPemEncodedPeerCertificate() const { } bssl::UniquePtr buf(BIO_new(BIO_s_mem())); - RELEASE_ASSERT(buf != nullptr); - RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert.get()) == 1); + RELEASE_ASSERT(buf != nullptr, ""); + RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert.get()) == 1, ""); const uint8_t* output; size_t length; - RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1); + RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1, ""); absl::string_view pem(reinterpret_cast(output), length); cached_url_encoded_pem_encoded_peer_certificate_ = absl::StrReplaceAll( pem, {{"\n", "%0A"}, {" ", "%20"}, {"+", "%2B"}, {"/", "%2F"}, {"=", "%3D"}}); @@ -339,9 +340,17 @@ std::string SslSocket::protocol() const { return std::string(reinterpret_cast(proto), proto_len); } +std::string SslSocket::serialNumberPeerCertificate() const { + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl_.get())); + if (!cert) { + return ""; + } + return Utility::getSerialNumberFromCertificate(*cert.get()); +} + std::string SslSocket::getSubjectFromCertificate(X509* cert) const { bssl::UniquePtr buf(BIO_new(BIO_s_mem())); - RELEASE_ASSERT(buf != nullptr); + RELEASE_ASSERT(buf != nullptr, ""); // flags=XN_FLAG_RFC2253 is the documented parameter for single-line output in RFC 2253 format. // Example from the RFC: @@ -379,7 +388,7 @@ ClientSslSocketFactory::ClientSslSocketFactory(const ClientContextConfig& config : ssl_ctx_(manager.createSslClientContext(stats_scope, config)) {} Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket() const { - return std::make_unique(*ssl_ctx_, Ssl::InitialState::Client); + return std::make_unique(ssl_ctx_, Ssl::InitialState::Client); } bool ClientSslSocketFactory::implementsSecureTransport() const { return true; } @@ -391,7 +400,7 @@ ServerSslSocketFactory::ServerSslSocketFactory(const ServerContextConfig& config : ssl_ctx_(manager.createSslServerContext(stats_scope, config, server_names)) {} Network::TransportSocketPtr ServerSslSocketFactory::createTransportSocket() const { - return std::make_unique(*ssl_ctx_, Ssl::InitialState::Server); + return std::make_unique(ssl_ctx_, Ssl::InitialState::Server); } bool ServerSslSocketFactory::implementsSecureTransport() const { return true; } diff --git a/source/common/ssl/ssl_socket.h b/source/common/ssl/ssl_socket.h index 6bb040edcd7ef..68fec106eb916 100644 --- a/source/common/ssl/ssl_socket.h +++ b/source/common/ssl/ssl_socket.h @@ -20,12 +20,13 @@ class SslSocket : public Network::TransportSocket, public Connection, protected Logger::Loggable { public: - SslSocket(Context& ctx, InitialState state); + SslSocket(ContextSharedPtr ctx, InitialState state); // Ssl::Connection bool peerCertificatePresented() const override; std::string uriSanLocalCertificate() override; const std::string& sha256PeerCertificateDigest() const override; + std::string serialNumberPeerCertificate() const override; std::string subjectPeerCertificate() const override; std::string subjectLocalCertificate() const override; std::string uriSanPeerCertificate() const override; @@ -50,12 +51,14 @@ class SslSocket : public Network::TransportSocket, Network::PostIoAction doHandshake(); void drainErrorQueue(); void shutdownSsl(); + + // TODO: Move helper functions to the `Ssl::Utility` namespace. std::string getUriSanFromCertificate(X509* cert) const; std::string getSubjectFromCertificate(X509* cert) const; std::vector getDnsSansFromCertificate(X509* cert); Network::TransportSocketCallbacks* callbacks_{}; - ContextImpl& ctx_; + ContextImplSharedPtr ctx_; bssl::UniquePtr ssl_; bool handshake_complete_{}; bool shutdown_sent_{}; @@ -68,22 +71,24 @@ class ClientSslSocketFactory : public Network::TransportSocketFactory { public: ClientSslSocketFactory(const ClientContextConfig& config, Ssl::ContextManager& manager, Stats::Scope& stats_scope); + Network::TransportSocketPtr createTransportSocket() const override; bool implementsSecureTransport() const override; private: - const ClientContextPtr ssl_ctx_; + ClientContextSharedPtr ssl_ctx_; }; class ServerSslSocketFactory : public Network::TransportSocketFactory { public: ServerSslSocketFactory(const ServerContextConfig& config, Ssl::ContextManager& manager, Stats::Scope& stats_scope, const std::vector& server_names); + Network::TransportSocketPtr createTransportSocket() const override; bool implementsSecureTransport() const override; private: - const ServerContextPtr ssl_ctx_; + ServerContextSharedPtr ssl_ctx_; }; } // namespace Ssl diff --git a/source/common/ssl/utility.cc b/source/common/ssl/utility.cc new file mode 100644 index 0000000000000..6f98864c97ac6 --- /dev/null +++ b/source/common/ssl/utility.cc @@ -0,0 +1,22 @@ +#include "common/ssl/utility.h" + +namespace Envoy { +namespace Ssl { + +std::string Utility::getSerialNumberFromCertificate(X509& cert) { + ASN1_INTEGER* serial_number = X509_get_serialNumber(&cert); + BIGNUM num_bn; + BN_init(&num_bn); + ASN1_INTEGER_to_BN(serial_number, &num_bn); + char* char_serial_number = BN_bn2hex(&num_bn); + BN_free(&num_bn); + if (char_serial_number != nullptr) { + std::string serial_number(char_serial_number); + OPENSSL_free(char_serial_number); + return serial_number; + } + return ""; +} + +} // namespace Ssl +} // namespace Envoy diff --git a/source/common/ssl/utility.h b/source/common/ssl/utility.h new file mode 100644 index 0000000000000..cb41056e0f0fb --- /dev/null +++ b/source/common/ssl/utility.h @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "openssl/ssl.h" + +namespace Envoy { +namespace Ssl { +namespace Utility { + +/** + * Retrieve the serial number of a certificate. + * @param ssl the certificate + * @return std::string the serial number field of the certificate. Returns "" if + * there is no serial number. + */ +std::string getSerialNumberFromCertificate(X509& cert); + +} // namespace Utility +} // namespace Ssl +} // namespace Envoy diff --git a/source/common/stats/stats_impl.cc b/source/common/stats/stats_impl.cc index 67ddf34ba08a7..681bc77d640f4 100644 --- a/source/common/stats/stats_impl.cc +++ b/source/common/stats/stats_impl.cc @@ -37,32 +37,15 @@ bool regexStartsWithDot(absl::string_view regex) { } // namespace -uint64_t RawStatData::size() { - // Normally the compiler would do this, but because name_ is a flexible-array-length - // element, the compiler can't. RawStatData is put into an array in HotRestartImpl, so - // it's important that each element starts on the required alignment for the type. - return roundUpMultipleNaturalAlignment(sizeof(RawStatData) + nameSize()); +// Normally the compiler would do this, but because name_ is a flexible-array-length +// element, the compiler can't. RawStatData is put into an array in HotRestartImpl, so +// it's important that each element starts on the required alignment for the type. +uint64_t RawStatData::structSize(uint64_t name_size) { + return roundUpMultipleNaturalAlignment(sizeof(RawStatData) + name_size + 1); } -uint64_t& RawStatData::initializeAndGetMutableMaxObjNameLength(uint64_t configured_size) { - // Like CONSTRUCT_ON_FIRST_USE, but non-const so that the value can be changed by tests - static uint64_t size = configured_size; - return size; -} - -void RawStatData::configure(Server::Options& options) { - const uint64_t configured = options.maxObjNameLength(); - RELEASE_ASSERT(configured > 0); - uint64_t max_obj_name_length = initializeAndGetMutableMaxObjNameLength(configured); - - // If this fails, it means that this function was called too late during - // startup because things were already using this size before it was set. - RELEASE_ASSERT(max_obj_name_length == configured); -} - -void RawStatData::configureForTestsOnly(Server::Options& options) { - const uint64_t configured = options.maxObjNameLength(); - initializeAndGetMutableMaxObjNameLength(configured) = configured; +uint64_t RawStatData::structSizeWithOptions(const StatsOptions& stats_options) { + return structSize(stats_options.maxNameLength()); } std::string Utility::sanitizeStatsName(const std::string& name) { @@ -152,36 +135,35 @@ bool TagExtractorImpl::extractTag(const std::string& stat_name, std::vector return false; } -RawStatData* HeapRawStatDataAllocator::alloc(const std::string& name) { - RawStatData* data = static_cast(::calloc(RawStatData::size(), 1)); - data->initialize(name); +HeapStatData::HeapStatData(absl::string_view key) : name_(key.data(), key.size()) {} - // Because the RawStatData object is initialized with and contains a truncated - // version of the std::string name, storing the stats in a map would require - // storing the name twice. Performing a lookup on the set is similarly - // expensive to performing a map lookup, since both require copying a truncated version of the - // string before doing the hash lookup. +HeapStatData* HeapStatDataAllocator::alloc(absl::string_view name) { + // Any expected truncation of name is done at the callsite. No truncation is + // required to use this allocator. + auto data = std::make_unique(name); Thread::ReleasableLockGuard lock(mutex_); - auto ret = stats_.insert(data); - RawStatData* existing_data = *ret.first; + auto ret = stats_.insert(data.get()); + HeapStatData* existing_data = *ret.first; lock.release(); - if (!ret.second) { - ::free(data); - ++existing_data->ref_count_; - return existing_data; - } else { - return data; + if (ret.second) { + return data.release(); } + ++existing_data->ref_count_; + return existing_data; } /** - * Counter implementation that wraps a RawStatData. + * Counter implementation that wraps a StatData. StatData must have data members: + * std::atomic value_; + * std::atomic pending_increment_; + * std::atomic flags_; + * std::atomic ref_count_; */ -class CounterImpl : public Counter, public MetricImpl { +template class CounterImpl : public Counter, public MetricImpl { public: - CounterImpl(RawStatData& data, RawStatDataAllocator& alloc, std::string&& tag_extracted_name, - std::vector&& tags) + CounterImpl(StatData& data, StatDataAllocatorImpl& alloc, + std::string&& tag_extracted_name, std::vector&& tags) : MetricImpl(data.name_, std::move(tag_extracted_name), std::move(tags)), data_(data), alloc_(alloc) {} ~CounterImpl() { alloc_.free(data_); } @@ -200,17 +182,17 @@ class CounterImpl : public Counter, public MetricImpl { uint64_t value() const override { return data_.value_; } private: - RawStatData& data_; - RawStatDataAllocator& alloc_; + StatData& data_; + StatDataAllocatorImpl& alloc_; }; /** - * Gauge implementation that wraps a RawStatData. + * Gauge implementation that wraps a StatData. */ -class GaugeImpl : public Gauge, public MetricImpl { +template class GaugeImpl : public Gauge, public MetricImpl { public: - GaugeImpl(RawStatData& data, RawStatDataAllocator& alloc, std::string&& tag_extracted_name, - std::vector&& tags) + GaugeImpl(StatData& data, StatDataAllocatorImpl& alloc, + std::string&& tag_extracted_name, std::vector&& tags) : MetricImpl(data.name_, std::move(tag_extracted_name), std::move(tags)), data_(data), alloc_(alloc) {} ~GaugeImpl() { alloc_.free(data_); } @@ -235,8 +217,8 @@ class GaugeImpl : public Gauge, public MetricImpl { bool used() const override { return data_.flags_ & Flags::Used; } private: - RawStatData& data_; - RawStatDataAllocator& alloc_; + StatData& data_; + StatDataAllocatorImpl& alloc_; }; TagProducerImpl::TagProducerImpl(const envoy::config::metrics::v2::StatsConfig& config) { @@ -337,36 +319,28 @@ TagProducerImpl::addDefaultExtractors(const envoy::config::metrics::v2::StatsCon return names; } -void HeapRawStatDataAllocator::free(RawStatData& data) { +// TODO(jmarantz): move this below HeapStatDataAllocator::alloc. +void HeapStatDataAllocator::free(HeapStatData& data) { ASSERT(data.ref_count_ > 0); if (--data.ref_count_ > 0) { return; } - size_t key_removed; { Thread::LockGuard lock(mutex_); - key_removed = stats_.erase(&data); + size_t key_removed = stats_.erase(&data); + ASSERT(key_removed == 1); } - ASSERT(key_removed == 1); - ::free(&data); + delete &data; } -void RawStatData::initialize(absl::string_view key) { +void RawStatData::initialize(absl::string_view key, const StatsOptions& stats_options) { ASSERT(!initialized()); - if (key.size() > Stats::RawStatData::maxNameLength()) { - ENVOY_LOG_MISC( - warn, - "Statistic '{}' is too long with {} characters, it will be truncated to {} characters", key, - key.size(), Stats::RawStatData::maxNameLength()); - } + ASSERT(key.size() <= stats_options.maxNameLength()); ref_count_ = 1; - - // key is not necessarily nul-terminated, but we want to make sure name_ is. - uint64_t xfer_size = std::min(nameSize() - 1, key.size()); - memcpy(name_, key.data(), xfer_size); - name_[xfer_size] = '\0'; + memcpy(name_, key.data(), key.size()); + name_[key.size()] = '\0'; } HistogramStatisticsImpl::HistogramStatisticsImpl(const histogram_t* histogram_ptr) @@ -427,26 +401,32 @@ void SourceImpl::clearCache() { histograms_.reset(); } -CounterSharedPtr RawStatDataAllocator::makeCounter(const std::string& name, - std::string&& tag_extracted_name, - std::vector&& tags) { - RawStatData* data = alloc(name); +template +CounterSharedPtr StatDataAllocatorImpl::makeCounter(absl::string_view name, + std::string&& tag_extracted_name, + std::vector&& tags) { + StatData* data = alloc(name); if (data == nullptr) { return nullptr; } - return std::make_shared(*data, *this, std::move(tag_extracted_name), - std::move(tags)); + return std::make_shared>(*data, *this, std::move(tag_extracted_name), + std::move(tags)); } -GaugeSharedPtr RawStatDataAllocator::makeGauge(const std::string& name, - std::string&& tag_extracted_name, - std::vector&& tags) { - RawStatData* data = alloc(name); +template +GaugeSharedPtr StatDataAllocatorImpl::makeGauge(absl::string_view name, + std::string&& tag_extracted_name, + std::vector&& tags) { + StatData* data = alloc(name); if (data == nullptr) { return nullptr; } - return std::make_shared(*data, *this, std::move(tag_extracted_name), std::move(tags)); + return std::make_shared>(*data, *this, std::move(tag_extracted_name), + std::move(tags)); } +template class StatDataAllocatorImpl; +template class StatDataAllocatorImpl; + } // namespace Stats } // namespace Envoy diff --git a/source/common/stats/stats_impl.h b/source/common/stats/stats_impl.h index b9616e357246a..a007d4a2a21d4 100644 --- a/source/common/stats/stats_impl.h +++ b/source/common/stats/stats_impl.h @@ -32,6 +32,27 @@ namespace Envoy { namespace Stats { +// The max name length is based on current set of stats. +// As of now, the longest stat is +// cluster..outlier_detection.ejections_consecutive_5xx +// which is 52 characters long without the cluster name. +// The max stat name length is 127 (default). So, in order to give room +// for growth to both the envoy generated stat characters +// (e.g., outlier_detection...) and user supplied names (e.g., cluster name), +// we set the max user supplied name length to 60, and the max internally +// generated stat suffixes to 67 (15 more characters to grow). +// If you want to increase the max user supplied name length, use the compiler +// option ENVOY_DEFAULT_MAX_OBJ_NAME_LENGTH or the CLI option +// max-obj-name-len +struct StatsOptionsImpl : public StatsOptions { + size_t maxNameLength() const override { return max_obj_name_length_ + max_stat_suffix_length_; } + size_t maxObjNameLength() const override { return max_obj_name_length_; } + size_t maxStatSuffixLength() const override { return max_stat_suffix_length_; } + + size_t max_obj_name_length_ = 60; + size_t max_stat_suffix_length_ = 67; +}; + class TagExtractorImpl : public TagExtractor { public: /** @@ -170,7 +191,7 @@ class Utility { * it can be allocated from shared memory if needed. * * @note Due to name_ being variable size, sizeof(RawStatData) probably isn't useful. Use - * RawStatData::size() instead. + * RawStatData::structSize() or RawStatData::structSizeWithOptions() instead. */ struct RawStatData { @@ -182,54 +203,24 @@ struct RawStatData { ~RawStatData() = delete; /** - * Configure static settings. This MUST be called - * before any other static or instance methods. - */ - static void configure(Server::Options& options); - - /** - * Allow tests to re-configure this value after it has been set. - * This is unsafe in a non-test context. - */ - static void configureForTestsOnly(Server::Options& options); - - /** - * Returns the maximum length of the name of a stat. This length - * does not include a trailing NULL-terminator. - */ - static size_t maxNameLength() { return maxObjNameLength() + MAX_STAT_SUFFIX_LENGTH; } - - /** - * Returns the maximum length of a user supplied object (route/cluster/listener) - * name field in a stat. This length does not include a trailing NULL-terminator. - */ - static size_t maxObjNameLength() { - return initializeAndGetMutableMaxObjNameLength(DEFAULT_MAX_OBJ_NAME_LENGTH); - } - - /** - * Returns the maximum length of a stat suffix that Envoy generates (over the user supplied name). - * This length does not include a trailing NULL-terminator. - */ - static size_t maxStatSuffixLength() { return MAX_STAT_SUFFIX_LENGTH; } - - /** - * size in bytes of name_ + * Returns the size of this struct, accounting for the length of name_ + * and padding for alignment. */ - static size_t nameSize() { return maxNameLength() + 1; } + static uint64_t structSize(uint64_t name_size); /** - * Returns the size of this struct, accounting for the length of name_ - * and padding for alignment. This is required by BlockMemoryHashSet. + * Wrapper for structSize, taking a StatsOptions struct. + * Required by BlockMemoryHashSet, which has the context to supply StatsOptions. */ - static uint64_t size(); + static uint64_t structSizeWithOptions(const StatsOptions& stats_options); /** * Initializes this object to have the specified key, - * a refcount of 1, and all other values zero. This is required by - * BlockMemoryHashSet. + * a refcount of 1, and all other values zero. Required for the HeapRawStatDataAllocator, which + * does not expect stat name truncation. We pass in the number of bytes allocated in order to + * assert the copy is safe inline. */ - void initialize(absl::string_view key); + void initialize(absl::string_view key, const StatsOptions& stats_options); /** * Returns a hash of the key. This is required by BlockMemoryHashSet. @@ -242,11 +233,9 @@ struct RawStatData { bool initialized() { return name_[0] != '\0'; } /** - * Returns the name as a string_view. This is required by BlockMemoryHashSet. + * Returns the name as a string_view with no truncation. */ - absl::string_view key() const { - return absl::string_view(name_, strnlen(name_, maxNameLength())); - } + absl::string_view key() const { return absl::string_view(name_); } std::atomic value_; std::atomic pending_increment_; @@ -254,28 +243,6 @@ struct RawStatData { std::atomic ref_count_; std::atomic unused_; char name_[]; - -private: - // The max name length is based on current set of stats. - // As of now, the longest stat is - // cluster..outlier_detection.ejections_consecutive_5xx - // which is 52 characters long without the cluster name. - // The max stat name length is 127 (default). So, in order to give room - // for growth to both the envoy generated stat characters - // (e.g., outlier_detection...) and user supplied names (e.g., cluster name), - // we set the max user supplied name length to 60, and the max internally - // generated stat suffixes to 67 (15 more characters to grow). - // If you want to increase the max user supplied name length, use the compiler - // option ENVOY_DEFAULT_MAX_OBJ_NAME_LENGTH or the CLI option - // max-obj-name-len - static const size_t DEFAULT_MAX_OBJ_NAME_LENGTH = 60; - static const size_t MAX_STAT_SUFFIX_LENGTH = 67; - - /** - * @return uint64_t& a reference to the configured size, which can then be changed - * by callers. - */ - static uint64_t& initializeAndGetMutableMaxObjNameLength(uint64_t configured_size); }; /** @@ -305,33 +272,48 @@ class MetricImpl : public virtual Metric { const std::vector tags_; }; -/** - * Implements a StatDataAllocator that uses RawStatData -- capable of deploying - * in a shared memory block without internal pointers. - */ -class RawStatDataAllocator : public StatDataAllocator { +// Partially implements a StatDataAllocator, leaving alloc & free for subclasses. +// We templatize on StatData rather than defining a virtual base StatData class +// for performance reasons; stat increment is on the hot path. +// +// The two production derivations cover using a fixed block of shared-memory for +// hot restart stat continuity, and heap allocation for more efficient RAM usage +// for when hot-restart is not required. +// +// Also note that RawStatData needs to live in a shared memory block, and it's +// possible, but not obvious, that a vptr would be usable across processes. In +// any case, RawStatData is allocated from a shared-memory block rather than via +// new, so the usual C++ compiler assistance for setting up vptrs will not be +// available. This could be resolved with placed new, or another nesting level. +template class StatDataAllocatorImpl : public StatDataAllocator { public: // StatDataAllocator - CounterSharedPtr makeCounter(const std::string& name, std::string&& tag_extracted_name, + CounterSharedPtr makeCounter(absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) override; - GaugeSharedPtr makeGauge(const std::string& name, std::string&& tag_extracted_name, + GaugeSharedPtr makeGauge(absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) override; /** * @param name the full name of the stat. - * @return RawStatData* a raw stat data block for a given stat name or nullptr if there is no - * more memory available for stats. The allocator should return a reference counted - * data location by name if one already exists with the same name. This is used for - * intra-process scope swapping as well as inter-process hot restart. + * @return StatData* a data block for a given stat name or nullptr if there is no more memory + * available for stats. The allocator should return a reference counted data location + * by name if one already exists with the same name. This is used for intra-process + * scope swapping as well as inter-process hot restart. */ - virtual RawStatData* alloc(const std::string& name) PURE; + virtual StatData* alloc(absl::string_view name) PURE; /** * Free a raw stat data block. The allocator should handle reference counting and only truly * free the block if it is no longer needed. * @param data the data returned by alloc(). */ - virtual void free(RawStatData& data) PURE; + virtual void free(StatData& data) PURE; +}; + +class RawStatDataAllocator : public StatDataAllocatorImpl { +public: + // StatDataAllocator + bool requiresBoundedStatNameSize() const override { return true; } }; /** @@ -395,30 +377,58 @@ class SourceImpl : public Source { }; /** - * Implementation of RawStatDataAllocator that uses an unordered set to store - * RawStatData pointers. + * This structure is an alternate backing store for both CounterImpl and GaugeImpl. It is designed + * so that it can be allocated efficiently from the heap on demand. + */ +struct HeapStatData { + explicit HeapStatData(absl::string_view key); + + /** + * @returns absl::string_view the name as a string_view. + */ + absl::string_view key() const { return name_; } + + std::atomic value_{0}; + std::atomic pending_increment_{0}; + std::atomic flags_{0}; + std::atomic ref_count_{1}; + std::string name_; +}; + +/** + * Implementation of StatDataAllocator using a pure heap-based strategy, so that + * Envoy implementations that do not require hot-restart can use less memory. */ -class HeapRawStatDataAllocator : public RawStatDataAllocator { +class HeapStatDataAllocator : public StatDataAllocatorImpl { public: - // RawStatDataAllocator - ~HeapRawStatDataAllocator() { ASSERT(stats_.empty()); } - RawStatData* alloc(const std::string& name) override; - void free(RawStatData& data) override; + HeapStatDataAllocator() {} + ~HeapStatDataAllocator() { ASSERT(stats_.empty()); } + + // StatDataAllocatorImpl + HeapStatData* alloc(absl::string_view name) override; + void free(HeapStatData& data) override; + + // StatDataAllocator + bool requiresBoundedStatNameSize() const override { return false; } private: - struct RawStatDataHash_ { - size_t operator()(const RawStatData* a) const { return HashUtil::xxHash64(a->key()); } + struct HeapStatHash_ { + size_t operator()(const HeapStatData* a) const { return HashUtil::xxHash64(a->key()); } }; - struct RawStatDataCompare_ { - bool operator()(const RawStatData* a, const RawStatData* b) const { + struct HeapStatCompare_ { + bool operator()(const HeapStatData* a, const HeapStatData* b) const { return (a->key() == b->key()); } }; - typedef std::unordered_set StringRawDataSet; - // An unordered set of RawStatData pointers which keys off the key() + // TODO(jmarantz): See https://github.com/envoyproxy/envoy/pull/3927 and + // https://github.com/envoyproxy/envoy/issues/3585, which can help reorganize + // the heap stats using a ref-counted symbol table to compress the stat strings. + typedef std::unordered_set StatSet; + + // An unordered set of HeapStatData pointers which keys off the key() // field in each object. This necessitates a custom comparator and hasher. - StringRawDataSet stats_ GUARDED_BY(mutex_); + StatSet stats_ GUARDED_BY(mutex_); // A mutex is needed here to protect the stats_ object from both alloc() and free() operations. // Although alloc() operations are called under existing locking, free() operations are made from // the destructors of the individual stat objects, which are not protected by locks. @@ -492,6 +502,7 @@ class IsolatedStoreImpl : public Store { Histogram& histogram = histograms_.get(name); return histogram; } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } // Stats::Store std::vector counters() const override { return counters_.toVector(); } @@ -515,15 +526,17 @@ class IsolatedStoreImpl : public Store { Histogram& histogram(const std::string& name) override { return parent_.histogram(prefix_ + name); } + const Stats::StatsOptions& statsOptions() const override { return parent_.statsOptions(); } IsolatedStoreImpl& parent_; const std::string prefix_; }; - HeapRawStatDataAllocator alloc_; + HeapStatDataAllocator alloc_; IsolatedStatsCache counters_; IsolatedStatsCache gauges_; IsolatedStatsCache histograms_; + const StatsOptionsImpl stats_options_; }; } // namespace Stats diff --git a/source/common/stats/thread_local_store.cc b/source/common/stats/thread_local_store.cc index 681c9c9b7be3c..0467e47f587a7 100644 --- a/source/common/stats/thread_local_store.cc +++ b/source/common/stats/thread_local_store.cc @@ -12,8 +12,9 @@ namespace Envoy { namespace Stats { -ThreadLocalStoreImpl::ThreadLocalStoreImpl(StatDataAllocator& alloc) - : alloc_(alloc), default_scope_(createScope("")), +ThreadLocalStoreImpl::ThreadLocalStoreImpl(const Stats::StatsOptions& stats_options, + StatDataAllocator& alloc) + : stats_options_(stats_options), alloc_(alloc), default_scope_(createScope("")), tag_producer_(std::make_unique()), num_last_resort_stats_(default_scope_->counter("stats.overflow")), source_(*this) {} @@ -155,6 +156,26 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id) { } } +absl::string_view ThreadLocalStoreImpl::truncateStatNameIfNeeded(absl::string_view name) { + // If the main allocator requires stat name truncation, warn and truncate, before + // attempting to allocate. + if (alloc_.requiresBoundedStatNameSize()) { + const uint64_t max_length = stats_options_.maxNameLength(); + + // Note that the heap-allocator does not truncate itself; we have to + // truncate here if we are using heap-allocation as a fallback due to an + // exahusted shared-memory block + if (name.size() > max_length) { + ENVOY_LOG_MISC( + warn, + "Statistic '{}' is too long with {} characters, it will be truncated to {} characters", + name, name.size(), max_length); + name = absl::string_view(name.data(), max_length); + } + } + return name; +} + std::atomic ThreadLocalStoreImpl::ScopeImpl::next_scope_id_; ThreadLocalStoreImpl::ScopeImpl::~ScopeImpl() { parent_.releaseScopeCrossThread(this); } @@ -176,13 +197,17 @@ StatType& ThreadLocalStoreImpl::ScopeImpl::safeMakeStat( std::shared_ptr& central_ref = central_cache_map[name]; if (!central_ref) { std::vector tags; + + // Tag extraction occurs on the original, untruncated name so the extraction + // can complete properly, even if the tag values are partially truncated. std::string tag_extracted_name = parent_.getTagsForName(name, tags); + absl::string_view truncated_name = parent_.truncateStatNameIfNeeded(name); std::shared_ptr stat = - make_stat(parent_.alloc_, name, std::move(tag_extracted_name), std::move(tags)); + make_stat(parent_.alloc_, truncated_name, std::move(tag_extracted_name), std::move(tags)); if (stat == nullptr) { parent_.num_last_resort_stats_.inc(); - stat = - make_stat(parent_.heap_allocator_, name, std::move(tag_extracted_name), std::move(tags)); + stat = make_stat(parent_.heap_allocator_, truncated_name, std::move(tag_extracted_name), + std::move(tags)); ASSERT(stat != nullptr); } central_ref = stat; @@ -212,7 +237,7 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counter(const std::string& name) { return safeMakeStat( final_name, central_cache_.counters_, - [](StatDataAllocator& allocator, const std::string& name, std::string&& tag_extracted_name, + [](StatDataAllocator& allocator, absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) -> CounterSharedPtr { return allocator.makeCounter(name, std::move(tag_extracted_name), std::move(tags)); }, @@ -246,7 +271,7 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gauge(const std::string& name) { return safeMakeStat( final_name, central_cache_.gauges_, - [](StatDataAllocator& allocator, const std::string& name, std::string&& tag_extracted_name, + [](StatDataAllocator& allocator, absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags) -> GaugeSharedPtr { return allocator.makeGauge(name, std::move(tag_extracted_name), std::move(tags)); }, diff --git a/source/common/stats/thread_local_store.h b/source/common/stats/thread_local_store.h index 05f8f16fbec11..cc280613ce66c 100644 --- a/source/common/stats/thread_local_store.h +++ b/source/common/stats/thread_local_store.h @@ -163,7 +163,7 @@ class TlsScope : public Scope { */ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRoot { public: - ThreadLocalStoreImpl(StatDataAllocator& alloc); + ThreadLocalStoreImpl(const Stats::StatsOptions& stats_options, StatDataAllocator& alloc); ~ThreadLocalStoreImpl(); // Stats::Scope @@ -195,6 +195,8 @@ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRo Source& source() override { return source_; } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } + private: struct TlsCacheEntry { std::unordered_map counters_; @@ -224,10 +226,11 @@ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRo Gauge& gauge(const std::string& name) override; Histogram& histogram(const std::string& name) override; Histogram& tlsHistogram(const std::string& name, ParentHistogramImpl& parent) override; + const Stats::StatsOptions& statsOptions() const override { return parent_.statsOptions(); } template using MakeStatFn = - std::function(StatDataAllocator&, const std::string& name, + std::function(StatDataAllocator&, absl::string_view name, std::string&& tag_extracted_name, std::vector&& tags)>; @@ -271,7 +274,9 @@ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRo void clearScopeFromCaches(uint64_t scope_id); void releaseScopeCrossThread(ScopeImpl* scope); void mergeInternal(PostMergeCb mergeCb); + absl::string_view truncateStatNameIfNeeded(absl::string_view name); + const Stats::StatsOptions& stats_options_; StatDataAllocator& alloc_; Event::Dispatcher* main_thread_dispatcher_{}; ThreadLocal::SlotPtr tls_; @@ -283,7 +288,7 @@ class ThreadLocalStoreImpl : Logger::Loggable, public StoreRo std::atomic shutting_down_{}; std::atomic merge_in_progress_{}; Counter& num_last_resort_stats_; - HeapRawStatDataAllocator heap_allocator_; + HeapStatDataAllocator heap_allocator_; SourceImpl source_; }; diff --git a/source/common/tcp/BUILD b/source/common/tcp/BUILD new file mode 100644 index 0000000000000..a201fb01ebec7 --- /dev/null +++ b/source/common/tcp/BUILD @@ -0,0 +1,31 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "conn_pool_lib", + srcs = ["conn_pool.cc"], + hdrs = ["conn_pool.h"], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/event:deferred_deletable", + "//include/envoy/event:dispatcher_interface", + "//include/envoy/event:timer_interface", + "//include/envoy/network:connection_interface", + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:timespan", + "//include/envoy/tcp:conn_pool_interface", + "//include/envoy/upstream:upstream_interface", + "//source/common/common:linked_object", + "//source/common/common:utility_lib", + "//source/common/network:filter_lib", + "//source/common/network:utility_lib", + "//source/common/upstream:upstream_lib", + ], +) diff --git a/source/common/tcp/conn_pool.cc b/source/common/tcp/conn_pool.cc new file mode 100644 index 0000000000000..2ee50a6fcdf9f --- /dev/null +++ b/source/common/tcp/conn_pool.cc @@ -0,0 +1,394 @@ +#include "common/tcp/conn_pool.h" + +#include "envoy/event/dispatcher.h" +#include "envoy/event/timer.h" +#include "envoy/stats/stats.h" +#include "envoy/upstream/upstream.h" + +namespace Envoy { +namespace Tcp { + +ConnPoolImpl::ConnPoolImpl(Event::Dispatcher& dispatcher, Upstream::HostConstSharedPtr host, + Upstream::ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options) + : dispatcher_(dispatcher), host_(host), priority_(priority), socket_options_(options), + upstream_ready_timer_(dispatcher_.createTimer([this]() { onUpstreamReady(); })) {} + +ConnPoolImpl::~ConnPoolImpl() { + while (!ready_conns_.empty()) { + ready_conns_.front()->conn_->close(Network::ConnectionCloseType::NoFlush); + } + + while (!busy_conns_.empty()) { + busy_conns_.front()->conn_->close(Network::ConnectionCloseType::NoFlush); + } + + // Make sure all connections are destroyed before we are destroyed. + dispatcher_.clearDeferredDeleteList(); +} + +void ConnPoolImpl::drainConnections() { + while (!ready_conns_.empty()) { + ready_conns_.front()->conn_->close(Network::ConnectionCloseType::NoFlush); + } + + // We drain busy connections by manually setting remaining requests to 1. Thus, when the next + // response completes the connection will be destroyed. + for (const auto& conn : busy_conns_) { + conn->remaining_requests_ = 1; + } +} + +void ConnPoolImpl::addDrainedCallback(DrainedCb cb) { + drained_callbacks_.push_back(cb); + checkForDrained(); +} + +void ConnPoolImpl::assignConnection(ActiveConn& conn, ConnectionPool::Callbacks& callbacks) { + ASSERT(conn.wrapper_ == nullptr); + conn.wrapper_ = std::make_shared(conn); + + callbacks.onPoolReady(std::make_unique(conn.wrapper_), + conn.real_host_description_); +} + +void ConnPoolImpl::checkForDrained() { + if (!drained_callbacks_.empty() && pending_requests_.empty() && busy_conns_.empty()) { + while (!ready_conns_.empty()) { + ready_conns_.front()->conn_->close(Network::ConnectionCloseType::NoFlush); + } + + for (const DrainedCb& cb : drained_callbacks_) { + cb(); + } + } +} + +void ConnPoolImpl::createNewConnection() { + ENVOY_LOG(debug, "creating a new connection"); + ActiveConnPtr conn(new ActiveConn(*this)); + conn->moveIntoList(std::move(conn), busy_conns_); +} + +ConnectionPool::Cancellable* ConnPoolImpl::newConnection(ConnectionPool::Callbacks& callbacks) { + if (!ready_conns_.empty()) { + ready_conns_.front()->moveBetweenLists(ready_conns_, busy_conns_); + ENVOY_CONN_LOG(debug, "using existing connection", *busy_conns_.front()->conn_); + assignConnection(*busy_conns_.front(), callbacks); + return nullptr; + } + + if (host_->cluster().resourceManager(priority_).pendingRequests().canCreate()) { + bool can_create_connection = + host_->cluster().resourceManager(priority_).connections().canCreate(); + if (!can_create_connection) { + host_->cluster().stats().upstream_cx_overflow_.inc(); + } + + // If we have no connections at all, make one no matter what so we don't starve. + if ((ready_conns_.size() == 0 && busy_conns_.size() == 0) || can_create_connection) { + createNewConnection(); + } + + ENVOY_LOG(debug, "queueing request due to no available connections"); + PendingRequestPtr pending_request(new PendingRequest(*this, callbacks)); + pending_request->moveIntoList(std::move(pending_request), pending_requests_); + return pending_requests_.front().get(); + } else { + ENVOY_LOG(debug, "max pending requests overflow"); + callbacks.onPoolFailure(ConnectionPool::PoolFailureReason::Overflow, nullptr); + host_->cluster().stats().upstream_rq_pending_overflow_.inc(); + return nullptr; + } +} + +void ConnPoolImpl::onConnectionEvent(ActiveConn& conn, Network::ConnectionEvent event) { + if (event == Network::ConnectionEvent::RemoteClose || + event == Network::ConnectionEvent::LocalClose) { + ENVOY_CONN_LOG(debug, "client disconnected", *conn.conn_); + + if (event == Network::ConnectionEvent::LocalClose) { + host_->cluster().stats().upstream_cx_destroy_local_.inc(); + } + if (event == Network::ConnectionEvent::RemoteClose) { + host_->cluster().stats().upstream_cx_destroy_remote_.inc(); + } + host_->cluster().stats().upstream_cx_destroy_.inc(); + + ActiveConnPtr removed; + bool check_for_drained = true; + if (conn.wrapper_ != nullptr) { + if (!conn.wrapper_->released_) { + if (event == Network::ConnectionEvent::LocalClose) { + host_->cluster().stats().upstream_cx_destroy_local_with_active_rq_.inc(); + } + if (event == Network::ConnectionEvent::RemoteClose) { + host_->cluster().stats().upstream_cx_destroy_remote_with_active_rq_.inc(); + } + host_->cluster().stats().upstream_cx_destroy_with_active_rq_.inc(); + + conn.wrapper_->release(true); + } + + removed = conn.removeFromList(busy_conns_); + } else if (!conn.connect_timer_) { + // The connect timer is destroyed on connect. The lack of a connect timer means that this + // connection is idle and in the ready pool. + removed = conn.removeFromList(ready_conns_); + check_for_drained = false; + } else { + // The only time this happens is if we actually saw a connect failure. + host_->cluster().stats().upstream_cx_connect_fail_.inc(); + host_->stats().cx_connect_fail_.inc(); + removed = conn.removeFromList(busy_conns_); + + // Raw connect failures should never happen under normal circumstances. If we have an upstream + // that is behaving badly, requests can get stuck here in the pending state. If we see a + // connect failure, we purge all pending requests so that calling code can determine what to + // do with the request. + // NOTE: We move the existing pending requests to a temporary list. This is done so that + // if retry logic submits a new request to the pool, we don't fail it inline. + ConnectionPool::PoolFailureReason reason; + if (conn.timed_out_) { + reason = ConnectionPool::PoolFailureReason::Timeout; + } else if (event == Network::ConnectionEvent::RemoteClose) { + reason = ConnectionPool::PoolFailureReason::RemoteConnectionFailure; + } else { + reason = ConnectionPool::PoolFailureReason::LocalConnectionFailure; + } + + std::list pending_requests_to_purge(std::move(pending_requests_)); + while (!pending_requests_to_purge.empty()) { + PendingRequestPtr request = + pending_requests_to_purge.front()->removeFromList(pending_requests_to_purge); + host_->cluster().stats().upstream_rq_pending_failure_eject_.inc(); + request->callbacks_.onPoolFailure(reason, conn.real_host_description_); + } + } + + dispatcher_.deferredDelete(std::move(removed)); + + // If we have pending requests and we just lost a connection we should make a new one. + if (pending_requests_.size() > (ready_conns_.size() + busy_conns_.size())) { + createNewConnection(); + } + + if (check_for_drained) { + checkForDrained(); + } + } + + if (conn.connect_timer_) { + conn.connect_timer_->disableTimer(); + conn.connect_timer_.reset(); + } + + // Note that the order in this function is important. Concretely, we must destroy the connect + // timer before we process an idle connection, because if this results in an immediate + // drain/destruction event, we key off of the existence of the connect timer above to determine + // whether the connection is in the ready list (connected) or the busy list (failed to connect). + if (event == Network::ConnectionEvent::Connected) { + conn_connect_ms_->complete(); + processIdleConnection(conn, false); + } +} + +void ConnPoolImpl::onPendingRequestCancel(PendingRequest& request) { + ENVOY_LOG(debug, "canceling pending request"); + request.removeFromList(pending_requests_); + host_->cluster().stats().upstream_rq_cancelled_.inc(); + checkForDrained(); +} + +void ConnPoolImpl::onConnReleased(ActiveConn& conn) { + ENVOY_CONN_LOG(debug, "connection released", *conn.conn_); + + if (conn.remaining_requests_ > 0 && --conn.remaining_requests_ == 0) { + ENVOY_CONN_LOG(debug, "maximum requests per connection", *conn.conn_); + host_->cluster().stats().upstream_cx_max_requests_.inc(); + + conn.conn_->close(Network::ConnectionCloseType::NoFlush); + } else { + // Upstream connection might be closed right after response is complete. Setting delay=true + // here to assign pending requests in next dispatcher loop to handle that case. + // https://github.com/envoyproxy/envoy/issues/2715 + processIdleConnection(conn, true); + } +} + +void ConnPoolImpl::onConnDestroyed(ActiveConn& conn) { + ENVOY_CONN_LOG(debug, "connection destroyed", *conn.conn_); +} + +void ConnPoolImpl::onUpstreamReady() { + upstream_ready_enabled_ = false; + while (!pending_requests_.empty() && !ready_conns_.empty()) { + ActiveConn& conn = *ready_conns_.front(); + ENVOY_CONN_LOG(debug, "assigning connection", *conn.conn_); + // There is work to do so bind a connection to the caller and move it to the busy list. Pending + // requests are pushed onto the front, so pull from the back. + assignConnection(conn, pending_requests_.back()->callbacks_); + pending_requests_.pop_back(); + conn.moveBetweenLists(ready_conns_, busy_conns_); + } +} + +void ConnPoolImpl::processIdleConnection(ActiveConn& conn, bool delay) { + conn.wrapper_.reset(); + if (pending_requests_.empty() || delay) { + // There is nothing to service or delayed processing is requested, so just move the connection + // into the ready list. + ENVOY_CONN_LOG(debug, "moving to ready", *conn.conn_); + conn.moveBetweenLists(busy_conns_, ready_conns_); + } else { + // There is work to do immediately so bind a request to the caller and move it to the busy list. + // Pending requests are pushed onto the front, so pull from the back. + ENVOY_CONN_LOG(debug, "assigning connection", *conn.conn_); + assignConnection(conn, pending_requests_.back()->callbacks_); + pending_requests_.pop_back(); + } + + if (delay && !pending_requests_.empty() && !upstream_ready_enabled_) { + upstream_ready_enabled_ = true; + upstream_ready_timer_->enableTimer(std::chrono::milliseconds(0)); + } + + checkForDrained(); +} + +ConnPoolImpl::ConnectionWrapper::ConnectionWrapper(ActiveConn& parent) : parent_(parent) { + parent_.parent_.host_->cluster().stats().upstream_rq_total_.inc(); + parent_.parent_.host_->cluster().stats().upstream_rq_active_.inc(); + parent_.parent_.host_->stats().rq_total_.inc(); + parent_.parent_.host_->stats().rq_active_.inc(); +} + +Network::ClientConnection& ConnPoolImpl::ConnectionWrapper::connection() { + ASSERT(!released_); + return *parent_.conn_; +} + +void ConnPoolImpl::ConnectionWrapper::addUpstreamCallbacks(ConnectionPool::UpstreamCallbacks& cb) { + ASSERT(!released_); + callbacks_ = &cb; +} + +void ConnPoolImpl::ConnectionWrapper::release(bool closed) { + // Allow multiple calls: connection close and destruction of ConnectionDataImplPtr will both + // result in this call. + if (!released_) { + released_ = true; + callbacks_ = nullptr; + if (!closed) { + parent_.parent_.onConnReleased(parent_); + } + + parent_.parent_.host_->cluster().stats().upstream_rq_active_.dec(); + parent_.parent_.host_->stats().rq_active_.dec(); + } +} + +ConnPoolImpl::PendingRequest::PendingRequest(ConnPoolImpl& parent, + ConnectionPool::Callbacks& callbacks) + : parent_(parent), callbacks_(callbacks) { + parent_.host_->cluster().stats().upstream_rq_pending_total_.inc(); + parent_.host_->cluster().stats().upstream_rq_pending_active_.inc(); + parent_.host_->cluster().resourceManager(parent_.priority_).pendingRequests().inc(); +} + +ConnPoolImpl::PendingRequest::~PendingRequest() { + parent_.host_->cluster().stats().upstream_rq_pending_active_.dec(); + parent_.host_->cluster().resourceManager(parent_.priority_).pendingRequests().dec(); +} + +ConnPoolImpl::ActiveConn::ActiveConn(ConnPoolImpl& parent) + : parent_(parent), + connect_timer_(parent_.dispatcher_.createTimer([this]() -> void { onConnectTimeout(); })), + remaining_requests_(parent_.host_->cluster().maxRequestsPerConnection()), timed_out_(false) { + + parent_.conn_connect_ms_.reset( + new Stats::Timespan(parent_.host_->cluster().stats().upstream_cx_connect_ms_)); + + Upstream::Host::CreateConnectionData data = + parent_.host_->createConnection(parent_.dispatcher_, parent_.socket_options_); + real_host_description_ = data.host_description_; + + conn_ = std::move(data.connection_); + + conn_->detectEarlyCloseWhenReadDisabled(false); + conn_->addConnectionCallbacks(*this); + conn_->addReadFilter(Network::ReadFilterSharedPtr{new ConnReadFilter(*this)}); + + ENVOY_CONN_LOG(debug, "connecting", *conn_); + conn_->connect(); + + parent_.host_->cluster().stats().upstream_cx_total_.inc(); + parent_.host_->cluster().stats().upstream_cx_active_.inc(); + parent_.host_->stats().cx_total_.inc(); + parent_.host_->stats().cx_active_.inc(); + conn_length_.reset(new Stats::Timespan(parent_.host_->cluster().stats().upstream_cx_length_ms_)); + connect_timer_->enableTimer(parent_.host_->cluster().connectTimeout()); + parent_.host_->cluster().resourceManager(parent_.priority_).connections().inc(); + + conn_->setConnectionStats({parent_.host_->cluster().stats().upstream_cx_rx_bytes_total_, + parent_.host_->cluster().stats().upstream_cx_rx_bytes_buffered_, + parent_.host_->cluster().stats().upstream_cx_tx_bytes_total_, + parent_.host_->cluster().stats().upstream_cx_tx_bytes_buffered_, + &parent_.host_->cluster().stats().bind_errors_}); + + // We just universally set no delay on connections. Theoretically we might at some point want + // to make this configurable. + conn_->noDelay(true); +} + +ConnPoolImpl::ActiveConn::~ActiveConn() { + parent_.host_->cluster().stats().upstream_cx_active_.dec(); + parent_.host_->stats().cx_active_.dec(); + conn_length_->complete(); + parent_.host_->cluster().resourceManager(parent_.priority_).connections().dec(); + + parent_.onConnDestroyed(*this); +} + +void ConnPoolImpl::ActiveConn::onConnectTimeout() { + // We just close the connection at this point. This will result in both a timeout and a connect + // failure and will fold into all the normal connect failure logic. + ENVOY_CONN_LOG(debug, "connect timeout", *conn_); + timed_out_ = true; + parent_.host_->cluster().stats().upstream_cx_connect_timeout_.inc(); + conn_->close(Network::ConnectionCloseType::NoFlush); +} + +void ConnPoolImpl::ActiveConn::onUpstreamData(Buffer::Instance& data, bool end_stream) { + if (wrapper_ != nullptr && wrapper_->callbacks_ != nullptr) { + // Delegate to the connection owner. + wrapper_->callbacks_->onUpstreamData(data, end_stream); + } else { + // Unexpected data from upstream, close down the connection. + ENVOY_CONN_LOG(debug, "unexpected data from upstream, closing connection", *conn_); + conn_->close(Network::ConnectionCloseType::NoFlush); + } +} + +void ConnPoolImpl::ActiveConn::onEvent(Network::ConnectionEvent event) { + if (wrapper_ != nullptr && wrapper_->callbacks_ != nullptr) { + wrapper_->callbacks_->onEvent(event); + } + + parent_.onConnectionEvent(*this, event); +} + +void ConnPoolImpl::ActiveConn::onAboveWriteBufferHighWatermark() { + if (wrapper_ != nullptr && wrapper_->callbacks_ != nullptr) { + wrapper_->callbacks_->onAboveWriteBufferHighWatermark(); + } +} + +void ConnPoolImpl::ActiveConn::onBelowWriteBufferLowWatermark() { + if (wrapper_ != nullptr && wrapper_->callbacks_ != nullptr) { + wrapper_->callbacks_->onBelowWriteBufferLowWatermark(); + } +} + +} // namespace Tcp +} // namespace Envoy diff --git a/source/common/tcp/conn_pool.h b/source/common/tcp/conn_pool.h new file mode 100644 index 0000000000000..6e1846fe43c31 --- /dev/null +++ b/source/common/tcp/conn_pool.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include + +#include "envoy/event/deferred_deletable.h" +#include "envoy/event/timer.h" +#include "envoy/network/connection.h" +#include "envoy/network/filter.h" +#include "envoy/stats/timespan.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/upstream.h" + +#include "common/common/linked_object.h" +#include "common/common/logger.h" +#include "common/network/filter_impl.h" + +namespace Envoy { +namespace Tcp { + +class ConnPoolImpl : Logger::Loggable, public ConnectionPool::Instance { +public: + ConnPoolImpl(Event::Dispatcher& dispatcher, Upstream::HostConstSharedPtr host, + Upstream::ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options); + + ~ConnPoolImpl(); + + // ConnectionPool::Instance + void addDrainedCallback(DrainedCb cb) override; + void drainConnections() override; + ConnectionPool::Cancellable* newConnection(ConnectionPool::Callbacks& callbacks) override; + +protected: + struct ActiveConn; + + struct ConnectionWrapper { + ConnectionWrapper(ActiveConn& parent); + + Network::ClientConnection& connection(); + void addUpstreamCallbacks(ConnectionPool::UpstreamCallbacks& callbacks); + void release(bool closed); + + ActiveConn& parent_; + ConnectionPool::UpstreamCallbacks* callbacks_{}; + bool released_{false}; + }; + + typedef std::shared_ptr ConnectionWrapperSharedPtr; + + struct ConnectionDataImpl : public ConnectionPool::ConnectionData { + ConnectionDataImpl(ConnectionWrapperSharedPtr wrapper) : wrapper_(wrapper) {} + ~ConnectionDataImpl() { wrapper_->release(false); } + + // ConnectionPool::ConnectionData + Network::ClientConnection& connection() override { return wrapper_->connection(); } + void addUpstreamCallbacks(ConnectionPool::UpstreamCallbacks& callbacks) override { + wrapper_->addUpstreamCallbacks(callbacks); + }; + + ConnectionWrapperSharedPtr wrapper_; + }; + + struct ConnReadFilter : public Network::ReadFilterBaseImpl { + ConnReadFilter(ActiveConn& parent) : parent_(parent) {} + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) { + parent_.onUpstreamData(data, end_stream); + return Network::FilterStatus::StopIteration; + } + + ActiveConn& parent_; + }; + + struct ActiveConn : LinkedObject, + public Network::ConnectionCallbacks, + public Event::DeferredDeletable { + ActiveConn(ConnPoolImpl& parent); + ~ActiveConn(); + + void onConnectTimeout(); + void onUpstreamData(Buffer::Instance& data, bool end_stream); + + // Network::ConnectionCallbacks + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override; + void onBelowWriteBufferLowWatermark() override; + + ConnPoolImpl& parent_; + Upstream::HostDescriptionConstSharedPtr real_host_description_; + ConnectionWrapperSharedPtr wrapper_; + Network::ClientConnectionPtr conn_; + Event::TimerPtr connect_timer_; + Stats::TimespanPtr conn_length_; + uint64_t remaining_requests_; + bool timed_out_; + }; + + typedef std::unique_ptr ActiveConnPtr; + + struct PendingRequest : LinkedObject, public ConnectionPool::Cancellable { + PendingRequest(ConnPoolImpl& parent, ConnectionPool::Callbacks& callbacks); + ~PendingRequest(); + + // ConnectionPool::Cancellable + void cancel() override { parent_.onPendingRequestCancel(*this); } + + ConnPoolImpl& parent_; + ConnectionPool::Callbacks& callbacks_; + }; + + typedef std::unique_ptr PendingRequestPtr; + + void assignConnection(ActiveConn& conn, ConnectionPool::Callbacks& callbacks); + void createNewConnection(); + void onConnectionEvent(ActiveConn& conn, Network::ConnectionEvent event); + void onPendingRequestCancel(PendingRequest& request); + virtual void onConnReleased(ActiveConn& conn); + virtual void onConnDestroyed(ActiveConn& conn); + void onUpstreamReady(); + void processIdleConnection(ActiveConn& conn, bool delay); + void checkForDrained(); + + Event::Dispatcher& dispatcher_; + Upstream::HostConstSharedPtr host_; + Upstream::ResourcePriority priority_; + const Network::ConnectionSocket::OptionsSharedPtr socket_options_; + + std::list ready_conns_; + std::list busy_conns_; + std::list pending_requests_; + std::list drained_callbacks_; + Stats::TimespanPtr conn_connect_ms_; + Event::TimerPtr upstream_ready_timer_; + bool upstream_ready_enabled_{false}; +}; + +} // namespace Tcp +} // namespace Envoy diff --git a/source/common/tcp_proxy/BUILD b/source/common/tcp_proxy/BUILD index fcd3000620fb8..61de5015db426 100644 --- a/source/common/tcp_proxy/BUILD +++ b/source/common/tcp_proxy/BUILD @@ -24,6 +24,7 @@ envoy_cc_library( "//include/envoy/stats:stats_interface", "//include/envoy/stats:stats_macros", "//include/envoy/stats:timespan", + "//include/envoy/tcp:conn_pool_interface", "//include/envoy/upstream:cluster_manager_interface", "//include/envoy/upstream:upstream_interface", "//source/common/access_log:access_log_lib", diff --git a/source/common/tcp_proxy/tcp_proxy.cc b/source/common/tcp_proxy/tcp_proxy.cc index 68d4880200be7..906bdeeb696ea 100644 --- a/source/common/tcp_proxy/tcp_proxy.cc +++ b/source/common/tcp_proxy/tcp_proxy.cc @@ -135,8 +135,12 @@ Filter::~Filter() { access_log->log(nullptr, nullptr, nullptr, getRequestInfo()); } - if (upstream_connection_) { - finalizeUpstreamConnectionStats(); + if (upstream_handle_) { + upstream_handle_->cancel(); + } + + if (upstream_conn_data_) { + upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); } } @@ -144,21 +148,6 @@ TcpProxyStats Config::SharedConfig::generateStats(Stats::Scope& scope) { return {ALL_TCP_PROXY_STATS(POOL_COUNTER(scope), POOL_GAUGE(scope))}; } -namespace { -void finalizeConnectionStats(const Upstream::HostDescription& host, - Stats::Timespan connected_timespan) { - host.cluster().stats().upstream_cx_destroy_.inc(); - host.cluster().stats().upstream_cx_active_.dec(); - host.stats().cx_active_.dec(); - host.cluster().resourceManager(Upstream::ResourcePriority::Default).connections().dec(); - connected_timespan.complete(); -} -} // namespace - -void Filter::finalizeUpstreamConnectionStats() { - finalizeConnectionStats(*read_callbacks_->upstreamHost(), *connected_timespan_); -} - void Filter::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) { initialize(callbacks, true); } @@ -188,15 +177,15 @@ void Filter::initialize(Network::ReadFilterCallbacks& callbacks, bool set_connec } void Filter::readDisableUpstream(bool disable) { - if (upstream_connection_ == nullptr || - upstream_connection_->state() != Network::Connection::State::Open) { + if (upstream_conn_data_ == nullptr || + upstream_conn_data_->connection().state() != Network::Connection::State::Open) { // Because we flush write downstream, we can have a case where upstream has already disconnected // and we are waiting to flush. If we had a watermark event during this time we should no // longer touch the upstream connection. return; } - upstream_connection_->readDisable(disable); + upstream_conn_data_->connection().readDisable(disable); if (disable) { read_callbacks_->upstreamHost() ->cluster() @@ -262,13 +251,12 @@ void Filter::UpstreamCallbacks::onBelowWriteBufferLowWatermark() { } } -Network::FilterStatus Filter::UpstreamCallbacks::onData(Buffer::Instance& data, bool end_stream) { +void Filter::UpstreamCallbacks::onUpstreamData(Buffer::Instance& data, bool end_stream) { if (parent_) { parent_->onUpstreamData(data, end_stream); } else { drainer_->onData(data, end_stream); } - return Network::FilterStatus::StopIteration; } void Filter::UpstreamCallbacks::onBytesSent() { @@ -294,7 +282,7 @@ void Filter::UpstreamCallbacks::drain(Drainer& drainer) { } Network::FilterStatus Filter::initializeUpstreamConnection() { - ASSERT(upstream_connection_ == nullptr); + ASSERT(upstream_conn_data_ == nullptr); const std::string& cluster_name = getUpstreamCluster(); @@ -311,6 +299,9 @@ Network::FilterStatus Filter::initializeUpstreamConnection() { } Upstream::ClusterInfoConstSharedPtr cluster = thread_local_cluster->info(); + + // Check this here because the TCP conn pool will queue our request waiting for a connection that + // will never be released. if (!cluster->resourceManager(Upstream::ResourcePriority::Default).connections().canCreate()) { getRequestInfo().setResponseFlag(RequestInfo::ResponseFlag::UpstreamOverflow); cluster->stats().upstream_cx_overflow_.inc(); @@ -325,86 +316,105 @@ Network::FilterStatus Filter::initializeUpstreamConnection() { return Network::FilterStatus::StopIteration; } - Upstream::Host::CreateConnectionData conn_info = - cluster_manager_.tcpConnForCluster(cluster_name, this); - - upstream_connection_ = std::move(conn_info.connection_); - read_callbacks_->upstreamHost(conn_info.host_description_); - if (!upstream_connection_) { - // tcpConnForCluster() increments cluster->stats().upstream_cx_none_healthy. + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( + cluster_name, Upstream::ResourcePriority::Default, this); + if (!conn_pool) { + // Either cluster is unknown or there are no healthy hosts. tcpConnPoolForCluster() increments + // cluster->stats().upstream_cx_none_healthy in the latter case. getRequestInfo().setResponseFlag(RequestInfo::ResponseFlag::NoHealthyUpstream); onInitFailure(UpstreamFailureReason::NO_HEALTHY_UPSTREAM); return Network::FilterStatus::StopIteration; } + connecting_ = true; connect_attempts_++; - cluster->resourceManager(Upstream::ResourcePriority::Default).connections().inc(); - upstream_connection_->addReadFilter(upstream_callbacks_); - upstream_connection_->addConnectionCallbacks(*upstream_callbacks_); - upstream_connection_->enableHalfClose(true); - upstream_connection_->setConnectionStats( - {read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_rx_bytes_total_, - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_rx_bytes_buffered_, - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_tx_bytes_total_, - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_tx_bytes_buffered_, - &read_callbacks_->upstreamHost()->cluster().stats().bind_errors_}); - upstream_connection_->connect(); - upstream_connection_->noDelay(true); - getRequestInfo().onUpstreamHostSelected(conn_info.host_description_); - getRequestInfo().setUpstreamLocalAddress(upstream_connection_->localAddress()); - - ASSERT(connect_timeout_timer_ == nullptr); - connect_timeout_timer_ = read_callbacks_->connection().dispatcher().createTimer( - [this]() -> void { onConnectTimeout(); }); - connect_timeout_timer_->enableTimer(cluster->connectTimeout()); - - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_total_.inc(); - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_active_.inc(); - read_callbacks_->upstreamHost()->stats().cx_total_.inc(); - read_callbacks_->upstreamHost()->stats().cx_active_.inc(); - connect_timespan_.reset(new Stats::Timespan( - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_connect_ms_)); - connected_timespan_.reset(new Stats::Timespan( - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_length_ms_)); - - return Network::FilterStatus::Continue; + + // Because we never return open connections to the pool, this should either return a handle while + // a connection completes or it invokes onPoolFailure inline. Either way, stop iteration. + upstream_handle_ = conn_pool->newConnection(*this); + return Network::FilterStatus::StopIteration; +} + +void Filter::onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) { + upstream_handle_ = nullptr; + + read_callbacks_->upstreamHost(host); + getRequestInfo().onUpstreamHostSelected(host); + + switch (reason) { + case Tcp::ConnectionPool::PoolFailureReason::Overflow: + case Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure: + upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose); + break; + + case Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure: + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); + break; + + case Tcp::ConnectionPool::PoolFailureReason::Timeout: + onConnectTimeout(); + break; + + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +void Filter::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, + Upstream::HostDescriptionConstSharedPtr host) { + upstream_handle_ = nullptr; + upstream_conn_data_ = std::move(conn_data); + read_callbacks_->upstreamHost(host); + + upstream_conn_data_->addUpstreamCallbacks(*upstream_callbacks_); + + Network::ClientConnection& connection = upstream_conn_data_->connection(); + + connection.enableHalfClose(true); + + getRequestInfo().onUpstreamHostSelected(host); + getRequestInfo().setUpstreamLocalAddress(connection.localAddress()); + + // Simulate the event that onPoolReady represents. + upstream_callbacks_->onEvent(Network::ConnectionEvent::Connected); + + read_callbacks_->continueReading(); } void Filter::onConnectTimeout() { ENVOY_CONN_LOG(debug, "connect timeout", read_callbacks_->connection()); read_callbacks_->upstreamHost()->outlierDetector().putResult(Upstream::Outlier::Result::TIMEOUT); - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_connect_timeout_.inc(); getRequestInfo().setResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionFailure); - // This will cause a LocalClose event to be raised, which will trigger a reconnect if - // needed/configured. - upstream_connection_->close(Network::ConnectionCloseType::NoFlush); + // Raise LocalClose, which will trigger a reconnect if needed/configured. + upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose); } Network::FilterStatus Filter::onData(Buffer::Instance& data, bool end_stream) { ENVOY_CONN_LOG(trace, "downstream connection received {} bytes, end_stream={}", read_callbacks_->connection(), data.length(), end_stream); getRequestInfo().addBytesReceived(data.length()); - upstream_connection_->write(data, end_stream); + upstream_conn_data_->connection().write(data, end_stream); ASSERT(0 == data.length()); resetIdleTimer(); // TODO(ggreenway) PERF: do we need to reset timer on both send and receive? return Network::FilterStatus::StopIteration; } void Filter::onDownstreamEvent(Network::ConnectionEvent event) { - if (upstream_connection_) { + if (upstream_conn_data_) { if (event == Network::ConnectionEvent::RemoteClose) { - upstream_connection_->close(Network::ConnectionCloseType::FlushWrite); + upstream_conn_data_->connection().close(Network::ConnectionCloseType::FlushWrite); - if (upstream_connection_ != nullptr && - upstream_connection_->state() != Network::Connection::State::Closed) { - config_->drainManager().add(config_->sharedConfig(), std::move(upstream_connection_), + if (upstream_conn_data_ != nullptr && + upstream_conn_data_->connection().state() != Network::Connection::State::Closed) { + config_->drainManager().add(config_->sharedConfig(), std::move(upstream_conn_data_), std::move(upstream_callbacks_), std::move(idle_timer_), - read_callbacks_->upstreamHost(), - std::move(connected_timespan_)); + read_callbacks_->upstreamHost()); } } else if (event == Network::ConnectionEvent::LocalClose) { - upstream_connection_->close(Network::ConnectionCloseType::NoFlush); + upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); + upstream_conn_data_.reset(); disableIdleTimer(); } } @@ -420,36 +430,21 @@ void Filter::onUpstreamData(Buffer::Instance& data, bool end_stream) { } void Filter::onUpstreamEvent(Network::ConnectionEvent event) { - bool connecting = false; - - // The timer must be cleared before, not after, processing the event because - // if initializeUpstreamConnection() is called it will reset the timer, so - // clearing after that call will leave the timer unset. - if (connect_timeout_timer_) { - connecting = true; - connect_timeout_timer_->disableTimer(); - connect_timeout_timer_.reset(); - } + // Update the connecting flag before processing the event because we may start a new connection + // attempt in initializeUpstreamConnection. + bool connecting = connecting_; + connecting_ = false; if (event == Network::ConnectionEvent::RemoteClose || event == Network::ConnectionEvent::LocalClose) { - finalizeUpstreamConnectionStats(); - read_callbacks_->connection().dispatcher().deferredDelete(std::move(upstream_connection_)); + upstream_conn_data_.reset(); disableIdleTimer(); - auto& destroy_ctx_stat = - (event == Network::ConnectionEvent::RemoteClose) - ? read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_destroy_remote_ - : read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_destroy_local_; - destroy_ctx_stat.inc(); - if (connecting) { if (event == Network::ConnectionEvent::RemoteClose) { getRequestInfo().setResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionFailure); read_callbacks_->upstreamHost()->outlierDetector().putResult( Upstream::Outlier::Result::CONNECT_FAILED); - read_callbacks_->upstreamHost()->cluster().stats().upstream_cx_connect_fail_.inc(); - read_callbacks_->upstreamHost()->stats().cx_connect_fail_.inc(); } initializeUpstreamConnection(); @@ -459,8 +454,6 @@ void Filter::onUpstreamEvent(Network::ConnectionEvent event) { } } } else if (event == Network::ConnectionEvent::Connected) { - connect_timespan_->complete(); - // Re-enable downstream reads now that the upstream connection is established // so we have a place to send downstream data to. read_callbacks_->connection().readDisable(false); @@ -480,8 +473,10 @@ void Filter::onUpstreamEvent(Network::ConnectionEvent event) { }); resetIdleTimer(); read_callbacks_->connection().addBytesSentCallback([this](uint64_t) { resetIdleTimer(); }); - upstream_connection_->addBytesSentCallback([upstream_callbacks = upstream_callbacks_]( - uint64_t) { upstream_callbacks->onBytesSent(); }); + upstream_conn_data_->connection().addBytesSentCallback([upstream_callbacks = + upstream_callbacks_](uint64_t) { + upstream_callbacks->onBytesSent(); + }); } } } @@ -523,14 +518,12 @@ UpstreamDrainManager::~UpstreamDrainManager() { } void UpstreamDrainManager::add(const Config::SharedConfigSharedPtr& config, - Network::ClientConnectionPtr&& upstream_connection, + Tcp::ConnectionPool::ConnectionDataPtr&& upstream_conn_data, const std::shared_ptr& callbacks, Event::TimerPtr&& idle_timer, - const Upstream::HostDescriptionConstSharedPtr& upstream_host, - Stats::TimespanPtr&& connected_timespan) { - DrainerPtr drainer(new Drainer(*this, config, callbacks, std::move(upstream_connection), - std::move(idle_timer), upstream_host, - std::move(connected_timespan))); + const Upstream::HostDescriptionConstSharedPtr& upstream_host) { + DrainerPtr drainer(new Drainer(*this, config, callbacks, std::move(upstream_conn_data), + std::move(idle_timer), upstream_host)); callbacks->drain(*drainer); // Use temporary to ensure we get the pointer before we move it out of drainer @@ -547,12 +540,10 @@ void UpstreamDrainManager::remove(Drainer& drainer, Event::Dispatcher& dispatche Drainer::Drainer(UpstreamDrainManager& parent, const Config::SharedConfigSharedPtr& config, const std::shared_ptr& callbacks, - Network::ClientConnectionPtr&& connection, Event::TimerPtr&& idle_timer, - const Upstream::HostDescriptionConstSharedPtr& upstream_host, - Stats::TimespanPtr&& connected_timespan) - : parent_(parent), callbacks_(callbacks), upstream_connection_(std::move(connection)), - timer_(std::move(idle_timer)), connected_timespan_(std::move(connected_timespan)), - upstream_host_(upstream_host), config_(config) { + Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, Event::TimerPtr&& idle_timer, + const Upstream::HostDescriptionConstSharedPtr& upstream_host) + : parent_(parent), callbacks_(callbacks), upstream_conn_data_(std::move(conn_data)), + timer_(std::move(idle_timer)), upstream_host_(upstream_host), config_(config) { config_->stats().upstream_flush_total_.inc(); config_->stats().upstream_flush_active_.inc(); } @@ -564,8 +555,7 @@ void Drainer::onEvent(Network::ConnectionEvent event) { timer_->disableTimer(); } config_->stats().upstream_flush_active_.dec(); - finalizeConnectionStats(*upstream_host_, *connected_timespan_); - parent_.remove(*this, upstream_connection_->dispatcher()); + parent_.remove(*this, upstream_conn_data_->connection().dispatcher()); } } @@ -592,7 +582,7 @@ void Drainer::onBytesSent() { void Drainer::cancelDrain() { // This sends onEvent(LocalClose). - upstream_connection_->close(Network::ConnectionCloseType::NoFlush); + upstream_conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); } } // namespace TcpProxy diff --git a/source/common/tcp_proxy/tcp_proxy.h b/source/common/tcp_proxy/tcp_proxy.h index 3c9173edc3328..61a803db85cca 100644 --- a/source/common/tcp_proxy/tcp_proxy.h +++ b/source/common/tcp_proxy/tcp_proxy.h @@ -14,6 +14,7 @@ #include "envoy/server/filter_config.h" #include "envoy/stats/stats_macros.h" #include "envoy/stats/timespan.h" +#include "envoy/tcp/conn_pool.h" #include "envoy/upstream/cluster_manager.h" #include "envoy/upstream/upstream.h" @@ -138,6 +139,7 @@ typedef std::shared_ptr ConfigSharedPtr; * be proxied back and forth between the two connections. */ class Filter : public Network::ReadFilter, + Tcp::ConnectionPool::Callbacks, Upstream::LoadBalancerContext, protected Logger::Loggable { public: @@ -149,6 +151,12 @@ class Filter : public Network::ReadFilter, Network::FilterStatus onNewConnection() override { return initializeUpstreamConnection(); } void initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) override; + // Tcp::ConnectionPool::Callbacks + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) override; + void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, + Upstream::HostDescriptionConstSharedPtr host) override; + // Upstream::LoadBalancerContext absl::optional computeHashKey() override { return {}; } const Router::MetadataMatchCriteria* metadataMatchCriteria() override { @@ -166,18 +174,15 @@ class Filter : public Network::ReadFilter, void readDisableUpstream(bool disable); void readDisableDownstream(bool disable); - struct UpstreamCallbacks : public Network::ConnectionCallbacks, - public Network::ReadFilterBaseImpl { + struct UpstreamCallbacks : public Tcp::ConnectionPool::UpstreamCallbacks { UpstreamCallbacks(Filter* parent) : parent_(parent) {} - // Network::ConnectionCallbacks + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override; void onEvent(Network::ConnectionEvent event) override; void onAboveWriteBufferHighWatermark() override; void onBelowWriteBufferLowWatermark() override; - // Network::ReadFilter - Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; - void onBytesSent(); void onIdleTimeout(); void drain(Drainer& drainer); @@ -234,7 +239,6 @@ class Filter : public Network::ReadFilter, void onDownstreamEvent(Network::ConnectionEvent event); void onUpstreamData(Buffer::Instance& data, bool end_stream); void onUpstreamEvent(Network::ConnectionEvent event); - void finalizeUpstreamConnectionStats(); void onIdleTimeout(); void resetIdleTimer(); void disableIdleTimer(); @@ -242,29 +246,26 @@ class Filter : public Network::ReadFilter, const ConfigSharedPtr config_; Upstream::ClusterManager& cluster_manager_; Network::ReadFilterCallbacks* read_callbacks_{}; - Network::ClientConnectionPtr upstream_connection_; + Tcp::ConnectionPool::Cancellable* upstream_handle_{}; + Tcp::ConnectionPool::ConnectionDataPtr upstream_conn_data_; DownstreamCallbacks downstream_callbacks_; - Event::TimerPtr connect_timeout_timer_; Event::TimerPtr idle_timer_; - Stats::TimespanPtr connect_timespan_; - Stats::TimespanPtr connected_timespan_; std::shared_ptr upstream_callbacks_; // shared_ptr required for passing as a // read filter. RequestInfo::RequestInfoImpl request_info_; uint32_t connect_attempts_{}; + bool connecting_{}; }; -// This class holds ownership of an upstream connection that needs to finish -// flushing, when the downstream connection has been closed. The TcpProxy is -// destroyed when the downstream connection is closed, so moving the upstream -// connection here allows it to finish draining or timeout. +// This class deals with an upstream connection that needs to finish flushing, when the downstream +// connection has been closed. The TcpProxy is destroyed when the downstream connection is closed, +// so handling the upstream connection here allows it to finish draining or timeout. class Drainer : public Event::DeferredDeletable { public: Drainer(UpstreamDrainManager& parent, const Config::SharedConfigSharedPtr& config, const std::shared_ptr& callbacks, - Network::ClientConnectionPtr&& connection, Event::TimerPtr&& idle_timer, - const Upstream::HostDescriptionConstSharedPtr& upstream_host, - Stats::TimespanPtr&& connected_timespan); + Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, Event::TimerPtr&& idle_timer, + const Upstream::HostDescriptionConstSharedPtr& upstream_host); void onEvent(Network::ConnectionEvent event); void onData(Buffer::Instance& data, bool end_stream); @@ -275,9 +276,8 @@ class Drainer : public Event::DeferredDeletable { private: UpstreamDrainManager& parent_; std::shared_ptr callbacks_; - Network::ClientConnectionPtr upstream_connection_; + Tcp::ConnectionPool::ConnectionDataPtr upstream_conn_data_; Event::TimerPtr timer_; - Stats::TimespanPtr connected_timespan_; Upstream::HostDescriptionConstSharedPtr upstream_host_; Config::SharedConfigSharedPtr config_; }; @@ -288,11 +288,10 @@ class UpstreamDrainManager : public ThreadLocal::ThreadLocalObject { public: ~UpstreamDrainManager(); void add(const Config::SharedConfigSharedPtr& config, - Network::ClientConnectionPtr&& upstream_connection, + Tcp::ConnectionPool::ConnectionDataPtr&& upstream_conn_data, const std::shared_ptr& callbacks, Event::TimerPtr&& idle_timer, - const Upstream::HostDescriptionConstSharedPtr& upstream_host, - Stats::TimespanPtr&& connected_timespan); + const Upstream::HostDescriptionConstSharedPtr& upstream_host); void remove(Drainer& drainer, Event::Dispatcher& dispatcher); private: diff --git a/source/common/tracing/http_tracer_impl.cc b/source/common/tracing/http_tracer_impl.cc index 121e0144f07a9..9f41c870289b2 100644 --- a/source/common/tracing/http_tracer_impl.cc +++ b/source/common/tracing/http_tracer_impl.cc @@ -50,7 +50,7 @@ const std::string& HttpTracerUtility::toString(OperationName operation_name) { return EGRESS_OPERATION; } - NOT_REACHED + NOT_REACHED_GCOVR_EXCL_LINE; } Decision HttpTracerUtility::isTracing(const RequestInfo::RequestInfo& request_info, @@ -79,7 +79,7 @@ Decision HttpTracerUtility::isTracing(const RequestInfo::RequestInfo& request_in return {Reason::NotTraceableRequestId, false}; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void HttpTracerUtility::finalizeSpan(Span& span, const Http::HeaderMap* request_headers, diff --git a/source/common/upstream/BUILD b/source/common/upstream/BUILD index 98c6cf013810d..d329feb0fa0e8 100644 --- a/source/common/upstream/BUILD +++ b/source/common/upstream/BUILD @@ -81,6 +81,7 @@ envoy_cc_library( "//source/common/network:utility_lib", "//source/common/protobuf:utility_lib", "//source/common/router:shadow_writer_lib", + "//source/common/tcp:conn_pool_lib", "//source/common/upstream:upstream_lib", "@envoy_api//envoy/admin/v2alpha:config_dump_cc", "@envoy_api//envoy/api/v2/core:base_cc", @@ -101,6 +102,7 @@ envoy_cc_library( "//include/envoy/upstream:health_checker_interface", "//source/common/router:router_lib", "@envoy_api//envoy/api/v2/core:health_check_cc", + "@envoy_api//envoy/data/core/v2alpha:health_check_event_cc", ], ) @@ -164,10 +166,17 @@ envoy_cc_library( srcs = ["health_discovery_service.cc"], hdrs = ["health_discovery_service.h"], deps = [ + ":health_checker_lib", + ":upstream_includes", "//include/envoy/event:dispatcher_interface", + "//include/envoy/runtime:runtime_interface", + "//include/envoy/ssl:context_manager_interface", "//include/envoy/stats:stats_macros", + "//include/envoy/upstream:cluster_manager_interface", + "//include/envoy/upstream:upstream_interface", "//source/common/common:minimal_logger_lib", "//source/common/grpc:async_client_lib", + "//source/common/network:resolver_lib", "@envoy_api//envoy/service/discovery/v2:hds_cc", ], ) diff --git a/source/common/upstream/cds_api_impl.cc b/source/common/upstream/cds_api_impl.cc index cc652c16ea35e..aab3a313cbb26 100644 --- a/source/common/upstream/cds_api_impl.cc +++ b/source/common/upstream/cds_api_impl.cc @@ -35,10 +35,11 @@ CdsApiImpl::CdsApiImpl(const envoy::api::v2::core::ConfigSource& cds_config, subscription_ = Config::SubscriptionFactory::subscriptionFromConfigSource( cds_config, local_info.node(), dispatcher, cm, random, *scope_, - [this, &cds_config, &eds_config, &cm, &dispatcher, &random, - &local_info]() -> Config::Subscription* { + [this, &cds_config, &eds_config, &cm, &dispatcher, &random, &local_info, + &scope]() -> Config::Subscription* { return new CdsSubscription(Config::Utility::generateStats(*scope_), cds_config, - eds_config, cm, dispatcher, random, local_info); + eds_config, cm, dispatcher, random, local_info, + scope.statsOptions()); }, "envoy.api.v2.ClusterDiscoveryService.FetchClusters", "envoy.api.v2.ClusterDiscoveryService.StreamClusters"); diff --git a/source/common/upstream/cds_subscription.cc b/source/common/upstream/cds_subscription.cc index d3287a88e5d2f..c571b855d313c 100644 --- a/source/common/upstream/cds_subscription.cc +++ b/source/common/upstream/cds_subscription.cc @@ -16,10 +16,11 @@ CdsSubscription::CdsSubscription( Config::SubscriptionStats stats, const envoy::api::v2::core::ConfigSource& cds_config, const absl::optional& eds_config, ClusterManager& cm, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, - const LocalInfo::LocalInfo& local_info) + const LocalInfo::LocalInfo& local_info, const Stats::StatsOptions& stats_options) : RestApiFetcher(cm, cds_config.api_config_source().cluster_names()[0], dispatcher, random, Config::Utility::apiConfigSourceRefreshDelay(cds_config.api_config_source())), - local_info_(local_info), stats_(stats), eds_config_(eds_config) { + local_info_(local_info), stats_(stats), eds_config_(eds_config), + stats_options_(stats_options) { const auto& api_config_source = cds_config.api_config_source(); UNREFERENCED_PARAMETER(api_config_source); // If we are building an CdsSubscription, the ConfigSource should be REST_LEGACY. @@ -46,7 +47,7 @@ void CdsSubscription::parseResponse(const Http::Message& response) { Protobuf::RepeatedPtrField resources; for (const Json::ObjectSharedPtr& cluster : clusters) { - Config::CdsJson::translateCluster(*cluster, eds_config_, *resources.Add()); + Config::CdsJson::translateCluster(*cluster, eds_config_, *resources.Add(), stats_options_); } std::pair hash = diff --git a/source/common/upstream/cds_subscription.h b/source/common/upstream/cds_subscription.h index 8e4f1454612db..f38ffe6ebbd9d 100644 --- a/source/common/upstream/cds_subscription.h +++ b/source/common/upstream/cds_subscription.h @@ -26,7 +26,8 @@ class CdsSubscription : public Http::RestApiFetcher, const envoy::api::v2::core::ConfigSource& cds_config, const absl::optional& eds_config, ClusterManager& cm, Event::Dispatcher& dispatcher, - Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info); + Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info, + const Stats::StatsOptions& stats_options); private: // Config::Subscription @@ -42,7 +43,7 @@ class CdsSubscription : public Http::RestApiFetcher, // We should never hit this at runtime, since this legacy adapter is only used by CdsApiImpl // that doesn't do dynamic modification of resources. UNREFERENCED_PARAMETER(resources); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } // Http::RestApiFetcher @@ -55,6 +56,7 @@ class CdsSubscription : public Http::RestApiFetcher, Config::SubscriptionCallbacks* callbacks_ = nullptr; Config::SubscriptionStats stats_; const absl::optional& eds_config_; + const Stats::StatsOptions& stats_options_; }; } // namespace Upstream diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index 2c28735fee4b0..0b53d5119bf7e 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -27,6 +27,7 @@ #include "common/network/utility.h" #include "common/protobuf/utility.h" #include "common/router/shadow_writer_impl.h" +#include "common/tcp/conn_pool.h" #include "common/upstream/cds_api_impl.h" #include "common/upstream/load_balancer_impl.h" #include "common/upstream/maglev_lb.h" @@ -173,12 +174,13 @@ ClusterManagerImpl::ClusterManagerImpl(const envoy::config::bootstrap::v2::Boots Server::Admin& admin, SystemTimeSource& system_time_source, MonotonicTimeSource& monotonic_time_source) : factory_(factory), runtime_(runtime), stats_(stats), tls_(tls.allocateSlot()), - random_(random), bind_config_(bootstrap.cluster_manager().upstream_bind_config()), - local_info_(local_info), cm_stats_(generateStats(stats)), + random_(random), log_manager_(log_manager), + bind_config_(bootstrap.cluster_manager().upstream_bind_config()), local_info_(local_info), + cm_stats_(generateStats(stats)), init_helper_([this](Cluster& cluster) { onClusterInit(cluster); }), config_tracker_entry_( admin.getConfigTracker().add("clusters", [this] { return dumpClusterConfigs(); })), - system_time_source_(system_time_source) { + system_time_source_(system_time_source), dispatcher_(main_thread_dispatcher) { async_client_manager_ = std::make_unique(*this, tls); const auto& cm_config = bootstrap.cluster_manager(); if (cm_config.has_outlier_detection()) { @@ -214,7 +216,8 @@ ClusterManagerImpl::ClusterManagerImpl(const envoy::config::bootstrap::v2::Boots ->create(), main_thread_dispatcher, *Protobuf::DescriptorPool::generated_pool()->FindMethodByName( - "envoy.service.discovery.v2.AggregatedDiscoveryService.StreamAggregatedResources"))); + "envoy.service.discovery.v2.AggregatedDiscoveryService.StreamAggregatedResources"), + random_)); } else { ads_mux_.reset(new Config::NullGrpcMuxImpl()); } @@ -252,7 +255,7 @@ ClusterManagerImpl::ClusterManagerImpl(const envoy::config::bootstrap::v2::Boots } default: // Validated by schema. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -301,7 +304,7 @@ ClusterManagerImpl::ClusterManagerImpl(const envoy::config::bootstrap::v2::Boots Config::Utility::factoryForGrpcApiConfigSource( *async_client_manager_, load_stats_config, stats) ->create(), - main_thread_dispatcher)); + main_thread_dispatcher, ProdMonotonicTimeSource::instance_)); } } @@ -327,7 +330,35 @@ void ClusterManagerImpl::onClusterInit(Cluster& cluster) { const HostVector& hosts_removed) { // This fires when a cluster is about to have an updated member set. We need to send this // out to all of the thread local configurations. - postThreadLocalClusterUpdate(cluster, priority, hosts_added, hosts_removed); + + // Should we save this update and merge it with other updates? + // + // Note that we can only _safely_ merge updates that have no added/removed hosts. That is, + // only those updates that signal a change in host healthcheck state, weight or metadata. + // + // We've discussed merging updates related to hosts being added/removed, but it's really + // tricky to merge those given that downstream consumers of these updates expect to see the + // full list of updates, not a condensed one. This is because they use the broadcasted + // HostSharedPtrs within internal maps to track hosts. If we fail to broadcast the entire list + // of removals, these maps will leak those HostSharedPtrs. + // + // See https://github.com/envoyproxy/envoy/pull/3941 for more context. + bool scheduled = false; + const bool merging_enabled = cluster.info()->lbConfig().has_update_merge_window(); + // Remember: we only merge updates with no adds/removes — just hc/weight/metadata changes. + const bool is_mergeable = !hosts_added.size() && !hosts_removed.size(); + + if (merging_enabled) { + // If this is not mergeable, we should cancel any scheduled updates since + // we'll deliver it immediately. + scheduled = scheduleUpdate(cluster, priority, is_mergeable); + } + + // If an update was not scheduled for later, deliver it immediately. + if (!scheduled) { + cm_stats_.cluster_updated_.inc(); + postThreadLocalClusterUpdate(cluster, priority, hosts_added, hosts_removed); + } }); // Finally, if the cluster has any hosts, post updates cross-thread so the per-thread load @@ -340,6 +371,83 @@ void ClusterManagerImpl::onClusterInit(Cluster& cluster) { } } +bool ClusterManagerImpl::scheduleUpdate(const Cluster& cluster, uint32_t priority, bool mergeable) { + const auto& update_merge_window = cluster.info()->lbConfig().update_merge_window(); + const auto timeout = DurationUtil::durationToMilliseconds(update_merge_window); + + // Find pending updates for this cluster. + auto& updates_by_prio = updates_map_[cluster.info()->name()]; + if (!updates_by_prio) { + updates_by_prio.reset(new PendingUpdatesByPriorityMap()); + } + + // Find pending updates for this priority. + auto& updates = (*updates_by_prio)[priority]; + if (!updates) { + updates.reset(new PendingUpdates()); + } + + // Has an update_merge_window gone by since the last update? If so, don't schedule + // the update so it can be applied immediately. Ditto if this is not a mergeable update. + const auto delta = std::chrono::steady_clock::now() - updates->last_updated_; + const uint64_t delta_ms = std::chrono::duration_cast(delta).count(); + const bool out_of_merge_window = delta_ms > timeout; + if (out_of_merge_window || !mergeable) { + // If there was a pending update, we cancel the pending merged update. + // + // Note: it's possible that even though we are outside of a merge window (delta_ms > timeout), + // a timer is enabled. This race condition is fine, since we'll disable the timer here and + // deliver the update immediately. + + // Why wasn't the update scheduled for later delivery? We keep some stats that are helpful + // to understand why merging did not happen. There's 2 things we are tracking here: + + // 1) Was this update out of a merge window? + if (mergeable && out_of_merge_window) { + cm_stats_.update_out_of_merge_window_.inc(); + } + + // 2) Were there previous updates that we are cancelling (and delivering immediately)? + if (updates->disableTimer()) { + cm_stats_.update_merge_cancelled_.inc(); + } + + updates->last_updated_ = std::chrono::steady_clock::now(); + return false; + } + + // If there's no timer, create one. + if (updates->timer_ == nullptr) { + updates->timer_ = dispatcher_.createTimer([this, &cluster, priority, &updates]() -> void { + applyUpdates(cluster, priority, *updates); + }); + } + + // Ensure there's a timer set to deliver these updates. + if (!updates->timer_enabled_) { + updates->enableTimer(timeout); + } + + return true; +} + +void ClusterManagerImpl::applyUpdates(const Cluster& cluster, uint32_t priority, + PendingUpdates& updates) { + // Deliver pending updates. + + // Remember that these merged updates are _only_ for updates related to + // HC/weight/metadata changes. That's why added/removed are empty. All + // adds/removals were already immediately broadcasted. + static const HostVector hosts_added; + static const HostVector hosts_removed; + + postThreadLocalClusterUpdate(cluster, priority, hosts_added, hosts_removed); + + cm_stats_.cluster_updated_via_merge_.inc(); + updates.timer_enabled_ = false; + updates.last_updated_ = std::chrono::steady_clock::now(); +} + bool ClusterManagerImpl::addOrUpdateCluster(const envoy::api::v2::Cluster& cluster, const std::string& version_info) { // First we need to see if this new config is new or an update to an existing dynamic cluster. @@ -465,6 +573,9 @@ bool ClusterManagerImpl::removeCluster(const std::string& cluster_name) { if (removed) { cm_stats_.cluster_removed_.inc(); updateGauges(); + // Did we ever deliver merged updates for this cluster? + // No need to manually disable timers, this should take care of it. + updates_map_.erase(cluster_name); } return removed; @@ -474,7 +585,7 @@ void ClusterManagerImpl::loadCluster(const envoy::api::v2::Cluster& cluster, const std::string& version_info, bool added_via_api, ClusterMap& cluster_map) { ClusterSharedPtr new_cluster = - factory_.clusterFromProto(cluster, *this, outlier_event_logger_, added_via_api); + factory_.clusterFromProto(cluster, *this, outlier_event_logger_, log_manager_, added_via_api); if (!added_via_api) { if (cluster_map.find(new_cluster->info()->name()) != cluster_map.end()) { @@ -551,6 +662,20 @@ ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster, ResourceP return entry->second->connPool(priority, protocol, context); } +Tcp::ConnectionPool::Instance* +ClusterManagerImpl::tcpConnPoolForCluster(const std::string& cluster, ResourcePriority priority, + LoadBalancerContext* context) { + ThreadLocalClusterManagerImpl& cluster_manager = tls_->getTyped(); + + auto entry = cluster_manager.thread_local_clusters_.find(cluster); + if (entry == cluster_manager.thread_local_clusters_.end()) { + return nullptr; + } + + // Select a host and create a connection pool for it if it does not already exist. + return entry->second->tcpConnPool(priority, context); +} + void ClusterManagerImpl::postThreadLocalClusterUpdate(const Cluster& cluster, uint32_t priority, const HostVector& hosts_added, const HostVector& hosts_removed) { @@ -704,6 +829,7 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::~ThreadLocalClusterManagerImp ENVOY_LOG(debug, "shutting down thread local cluster manager"); destroying_ = true; host_http_conn_pool_map_.clear(); + host_tcp_conn_pool_map_.clear(); ASSERT(host_tcp_conn_map_.empty()); for (auto& cluster : thread_local_clusters_) { if (&cluster.second->priority_set_ != local_priority_set_) { @@ -715,9 +841,17 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::~ThreadLocalClusterManagerImp void ClusterManagerImpl::ThreadLocalClusterManagerImpl::drainConnPools(const HostVector& hosts) { for (const HostSharedPtr& host : hosts) { - auto container = host_http_conn_pool_map_.find(host); - if (container != host_http_conn_pool_map_.end()) { - drainConnPools(host, container->second); + { + auto container = host_http_conn_pool_map_.find(host); + if (container != host_http_conn_pool_map_.end()) { + drainConnPools(host, container->second); + } + } + { + auto container = host_tcp_conn_pool_map_.find(host); + if (container != host_tcp_conn_pool_map_.end()) { + drainTcpConnPools(host, container->second); + } } } } @@ -756,6 +890,40 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::drainConnPools( } } +void ClusterManagerImpl::ThreadLocalClusterManagerImpl::drainTcpConnPools( + HostSharedPtr old_host, TcpConnPoolsContainer& container) { + container.drains_remaining_ += container.pools_.size(); + + for (const auto& pair : container.pools_) { + pair.second->addDrainedCallback([this, old_host]() -> void { + if (destroying_) { + // It is possible for a connection pool to fire drain callbacks during destruction. Instead + // of checking if old_host actually exists in the map, it's clearer and cleaner to keep + // track of destruction as a separate state and check for it here. This also allows us to + // do this check here versus inside every different connection pool implementation. + return; + } + + TcpConnPoolsContainer& container = host_tcp_conn_pool_map_[old_host]; + ASSERT(container.drains_remaining_ > 0); + container.drains_remaining_--; + if (container.drains_remaining_ == 0) { + for (auto& pair : container.pools_) { + thread_local_dispatcher_.deferredDelete(std::move(pair.second)); + } + host_tcp_conn_pool_map_.erase(old_host); + } + }); + + // The above addDrainedCallback() drain completion callback might execute immediately. This can + // then effectively nuke 'container', which means we can't continue to loop on its contents + // (we're done here). + if (host_tcp_conn_pool_map_.count(old_host) == 0) { + break; + } + } +} + void ClusterManagerImpl::ThreadLocalClusterManagerImpl::removeTcpConn( const HostConstSharedPtr& host, Network::ClientConnection& connection) { auto host_tcp_conn_map_it = host_tcp_conn_map_.find(host); @@ -813,6 +981,15 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::onHostHealthFailure( } } } + { + const auto& container = config.host_tcp_conn_pool_map_.find(host); + if (container != config.host_tcp_conn_pool_map_.end()) { + for (const auto& pair : container->second.pools_) { + const Tcp::ConnectionPool::InstancePtr& pool = pair.second; + pool->drainConnections(); + } + } + } if (host->cluster().features() & ClusterInfo::Features::CLOSE_CONNECTIONS_ON_HOST_HEALTH_FAILURE) { @@ -950,6 +1127,44 @@ ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::connPool( return container.pools_[hash_key].get(); } +Tcp::ConnectionPool::Instance* +ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::tcpConnPool( + ResourcePriority priority, LoadBalancerContext* context) { + HostConstSharedPtr host = lb_->chooseHost(context); + if (!host) { + ENVOY_LOG(debug, "no healthy host for TCP connection pool"); + cluster_info_->stats().upstream_cx_none_healthy_.inc(); + return nullptr; + } + + // Inherit socket options from downstream connection, if set. + std::vector hash_key = {uint8_t(priority)}; + + // Use downstream connection socket options for computing connection pool hash key, if any. + // This allows socket options to control connection pooling so that connections with + // different options are not pooled together. + bool have_options = false; + if (context && context->downstreamConnection()) { + const Network::ConnectionSocket::OptionsSharedPtr& options = + context->downstreamConnection()->socketOptions(); + if (options) { + for (const auto& option : *options) { + have_options = true; + option->hashKey(hash_key); + } + } + } + + TcpConnPoolsContainer& container = parent_.host_tcp_conn_pool_map_[host]; + if (!container.pools_[hash_key]) { + container.pools_[hash_key] = parent_.parent_.factory_.allocateTcpConnPool( + parent_.thread_local_dispatcher_, host, priority, + have_options ? context->downstreamConnection()->socketOptions() : nullptr); + } + + return container.pools_[hash_key].get(); +} + ClusterManagerPtr ProdClusterManagerFactory::clusterManagerFromProto( const envoy::config::bootstrap::v2::Bootstrap& bootstrap, Stats::Store& stats, ThreadLocal::Instance& tls, Runtime::Loader& runtime, Runtime::RandomGenerator& random, @@ -974,12 +1189,20 @@ Http::ConnectionPool::InstancePtr ProdClusterManagerFactory::allocateConnPool( } } +Tcp::ConnectionPool::InstancePtr ProdClusterManagerFactory::allocateTcpConnPool( + Event::Dispatcher& dispatcher, HostConstSharedPtr host, ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options) { + return Tcp::ConnectionPool::InstancePtr{ + new Tcp::ConnPoolImpl(dispatcher, host, priority, options)}; +} + ClusterSharedPtr ProdClusterManagerFactory::clusterFromProto( const envoy::api::v2::Cluster& cluster, ClusterManager& cm, - Outlier::EventLoggerSharedPtr outlier_event_logger, bool added_via_api) { + Outlier::EventLoggerSharedPtr outlier_event_logger, AccessLog::AccessLogManager& log_manager, + bool added_via_api) { return ClusterImplBase::create(cluster, cm, stats_, tls_, dns_resolver_, ssl_context_manager_, - runtime_, random_, main_thread_dispatcher_, local_info_, - outlier_event_logger, added_via_api); + runtime_, random_, main_thread_dispatcher_, log_manager, + local_info_, outlier_event_logger, added_via_api); } CdsApiPtr ProdClusterManagerFactory::createCds( diff --git a/source/common/upstream/cluster_manager_impl.h b/source/common/upstream/cluster_manager_impl.h index 4bbffe3adf225..9acf6927c3f0a 100644 --- a/source/common/upstream/cluster_manager_impl.h +++ b/source/common/upstream/cluster_manager_impl.h @@ -53,8 +53,13 @@ class ProdClusterManagerFactory : public ClusterManagerFactory { allocateConnPool(Event::Dispatcher& dispatcher, HostConstSharedPtr host, ResourcePriority priority, Http::Protocol protocol, const Network::ConnectionSocket::OptionsSharedPtr& options) override; + Tcp::ConnectionPool::InstancePtr + allocateTcpConnPool(Event::Dispatcher& dispatcher, HostConstSharedPtr host, + ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options) override; ClusterSharedPtr clusterFromProto(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, + AccessLog::AccessLogManager& log_manager, bool added_via_api) override; CdsApiPtr createCds(const envoy::api::v2::core::ConfigSource& cds_config, const absl::optional& eds_config, @@ -133,6 +138,10 @@ class ClusterManagerInitHelper : Logger::Loggable { COUNTER(cluster_added) \ COUNTER(cluster_modified) \ COUNTER(cluster_removed) \ + COUNTER(cluster_updated) \ + COUNTER(cluster_updated_via_merge) \ + COUNTER(update_merge_cancelled) \ + COUNTER(update_out_of_merge_window) \ GAUGE (active_clusters) \ GAUGE (warming_clusters) // clang-format on @@ -179,6 +188,9 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggable, Tcp::ConnectionPool::InstancePtr> ConnPools; + + ConnPools pools_; + uint64_t drains_remaining_{}; + }; + // Holds an unowned reference to a connection, and watches for Closed events. If the connection // is closed, this container removes itself from the container that owns it. struct TcpConnContainer : public Network::ConnectionCallbacks, public Event::DeferredDeletable { @@ -250,6 +274,9 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggable host_http_conn_pool_map_; + std::unordered_map host_tcp_conn_pool_map_; std::unordered_map host_tcp_conn_map_; std::list update_callbacks_; @@ -341,14 +370,47 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggable ClusterMap; + struct PendingUpdates { + void enableTimer(const uint64_t timeout) { + ASSERT(!timer_enabled_); + if (timer_ != nullptr) { + timer_->enableTimer(std::chrono::milliseconds(timeout)); + timer_enabled_ = true; + } + } + bool disableTimer() { + const bool was_enabled = timer_enabled_; + if (timer_ != nullptr) { + timer_->disableTimer(); + timer_enabled_ = false; + } + return was_enabled; + } + + Event::TimerPtr timer_; + // TODO(rgs1): this should be part of Event::Timer's interface. + bool timer_enabled_{}; + // This is default constructed to the clock's epoch: + // https://en.cppreference.com/w/cpp/chrono/time_point/time_point + // + // This will usually be the computer's boot time, which means that given a not very large + // `Cluster.CommonLbConfig.update_merge_window`, the first update will trigger immediately + // (the expected behavior). + MonotonicTime last_updated_; + }; + using PendingUpdatesPtr = std::unique_ptr; + using PendingUpdatesByPriorityMap = std::unordered_map; + using PendingUpdatesByPriorityMapPtr = std::unique_ptr; + using ClusterUpdatesMap = std::unordered_map; + + void applyUpdates(const Cluster& cluster, uint32_t priority, PendingUpdates& updates); + bool scheduleUpdate(const Cluster& cluster, uint32_t priority, bool mergeable); void createOrUpdateThreadLocalCluster(ClusterData& cluster); ProtobufTypes::MessagePtr dumpClusterConfigs(); static ClusterManagerStats generateStats(Stats::Scope& scope); void loadCluster(const envoy::api::v2::Cluster& cluster, const std::string& version_info, bool added_via_api, ClusterMap& cluster_map); void onClusterInit(Cluster& cluster); - void postThreadLocalClusterUpdate(const Cluster& cluster, uint32_t priority, - const HostVector& hosts_added, const HostVector& hosts_removed); void postThreadLocalHealthFailure(const HostSharedPtr& host); void updateGauges(); @@ -357,6 +419,7 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggable eds_config_; @@ -373,6 +436,8 @@ class ClusterManagerImpl : public ClusterManager, Logger::Loggablestart({cluster_name_}, *this); } void EdsClusterImpl::onConfigUpdate(const ResourceVector& resources, const std::string&) { - typedef std::unique_ptr HostListPtr; - std::vector> priority_state; if (resources.empty()) { ENVOY_LOG(debug, "Missing ClusterLoadAssignment for {} in onConfigUpdate()", cluster_name_); info_->stats().update_empty_.inc(); @@ -63,33 +61,19 @@ void EdsClusterImpl::onConfigUpdate(const ResourceVector& resources, const std:: throw EnvoyException(fmt::format("Unexpected EDS cluster (expecting {}): {}", cluster_name_, cluster_load_assignment.cluster_name())); } + PriorityStateManager priority_state_manager(*this, local_info_); for (const auto& locality_lb_endpoint : cluster_load_assignment.endpoints()) { const uint32_t priority = locality_lb_endpoint.priority(); if (priority > 0 && !cluster_name_.empty() && cluster_name_ == cm_.localClusterName()) { throw EnvoyException( fmt::format("Unexpected non-zero priority for local cluster '{}'.", cluster_name_)); } - if (priority_state.size() <= priority + 1) { - priority_state.resize(priority + 1); - } - if (priority_state[priority].first == nullptr) { - priority_state[priority].first.reset(new HostVector()); - } - if (locality_lb_endpoint.has_locality() && locality_lb_endpoint.has_load_balancing_weight()) { - priority_state[priority].second[locality_lb_endpoint.locality()] = - locality_lb_endpoint.load_balancing_weight().value(); - } + priority_state_manager.initializePriorityFor(locality_lb_endpoint); + for (const auto& lb_endpoint : locality_lb_endpoint.lb_endpoints()) { - priority_state[priority].first->emplace_back(new HostImpl( - info_, "", resolveProtoAddress(lb_endpoint.endpoint().address()), lb_endpoint.metadata(), - lb_endpoint.load_balancing_weight().value(), locality_lb_endpoint.locality(), - lb_endpoint.endpoint().health_check_config())); - const auto& health_status = lb_endpoint.health_status(); - if (health_status == envoy::api::v2::core::HealthStatus::UNHEALTHY || - health_status == envoy::api::v2::core::HealthStatus::DRAINING || - health_status == envoy::api::v2::core::HealthStatus::TIMEOUT) { - priority_state[priority].first->back()->healthFlagSet(Host::HealthFlag::FAILED_EDS_HEALTH); - } + priority_state_manager.registerHostForPriority( + "", resolveProtoAddress(lb_endpoint.endpoint().address()), locality_lb_endpoint, + lb_endpoint, Host::HealthFlag::FAILED_EDS_HEALTH); } } @@ -98,14 +82,15 @@ void EdsClusterImpl::onConfigUpdate(const ResourceVector& resources, const std:: // Loop over existing priorities not present in the config. This will empty out any priorities // the config update did not refer to + auto& priority_state = priority_state_manager.priorityState(); for (size_t i = 0; i < priority_state.size(); ++i) { if (priority_state[i].first != nullptr) { if (locality_weights_map_.size() <= i) { locality_weights_map_.resize(i + 1); } cluster_rebuilt |= - updateHostsPerLocality(priority_set_.getOrCreateHostSet(i), *priority_state[i].first, - locality_weights_map_[i], priority_state[i].second); + updateHostsPerLocality(i, *priority_state[i].first, locality_weights_map_[i], + priority_state[i].second, priority_state_manager); } } @@ -118,8 +103,8 @@ void EdsClusterImpl::onConfigUpdate(const ResourceVector& resources, const std:: if (locality_weights_map_.size() <= i) { locality_weights_map_.resize(i + 1); } - cluster_rebuilt |= updateHostsPerLocality(priority_set_.getOrCreateHostSet(i), empty_hosts, - locality_weights_map_[i], empty_locality_map); + cluster_rebuilt |= updateHostsPerLocality(i, empty_hosts, locality_weights_map_[i], + empty_locality_map, priority_state_manager); } if (!cluster_rebuilt) { @@ -131,9 +116,11 @@ void EdsClusterImpl::onConfigUpdate(const ResourceVector& resources, const std:: onPreInitComplete(); } -bool EdsClusterImpl::updateHostsPerLocality(HostSet& host_set, const HostVector& new_hosts, +bool EdsClusterImpl::updateHostsPerLocality(const uint32_t priority, const HostVector& new_hosts, LocalityWeightsMap& locality_weights_map, - LocalityWeightsMap& new_locality_weights_map) { + LocalityWeightsMap& new_locality_weights_map, + PriorityStateManager& priority_state_manager) { + const auto& host_set = priority_set_.getOrCreateHostSet(priority); HostVectorSharedPtr current_hosts_copy(new HostVector(host_set.hosts())); HostVector hosts_added; @@ -149,60 +136,11 @@ bool EdsClusterImpl::updateHostsPerLocality(HostSet& host_set, const HostVector& if (updateDynamicHostList(new_hosts, *current_hosts_copy, hosts_added, hosts_removed) || locality_weights_map != new_locality_weights_map) { locality_weights_map = new_locality_weights_map; - LocalityWeightsSharedPtr locality_weights; ENVOY_LOG(debug, "EDS hosts or locality weights changed for cluster: {} ({}) priority {}", info_->name(), host_set.hosts().size(), host_set.priority()); - std::vector per_locality; - - // If we are configured for locality weighted LB we populate the locality - // weights. - const bool locality_weighted_lb = info()->lbConfig().has_locality_weighted_lb_config(); - if (locality_weighted_lb) { - locality_weights = std::make_shared(); - } - // If local locality is not defined then skip populating per locality hosts. - const auto& local_locality = local_info_.node().locality(); - ENVOY_LOG(trace, "Local locality: {}", local_info_.node().locality().DebugString()); - - // We use std::map to guarantee a stable ordering for zone aware routing. - std::map hosts_per_locality; - - for (const HostSharedPtr& host : *current_hosts_copy) { - hosts_per_locality[host->locality()].push_back(host); - } - - // Do we have hosts for the local locality? - const bool non_empty_local_locality = - local_info_.node().has_locality() && - hosts_per_locality.find(local_locality) != hosts_per_locality.end(); - - // As per HostsPerLocality::get(), the per_locality vector must have the - // local locality hosts first if non_empty_local_locality. - if (non_empty_local_locality) { - per_locality.emplace_back(hosts_per_locality[local_locality]); - if (locality_weighted_lb) { - locality_weights->emplace_back(new_locality_weights_map[local_locality]); - } - } - - // After the local locality hosts (if any), we place the remaining locality - // host groups in lexicographic order. This provides a stable ordering for - // zone aware routing. - for (auto& entry : hosts_per_locality) { - if (!non_empty_local_locality || !LocalityEqualTo()(local_locality, entry.first)) { - per_locality.emplace_back(entry.second); - if (locality_weighted_lb) { - locality_weights->emplace_back(new_locality_weights_map[entry.first]); - } - } - } - - auto per_locality_shared = - std::make_shared(std::move(per_locality), non_empty_local_locality); - host_set.updateHosts(current_hosts_copy, createHealthyHostList(*current_hosts_copy), - per_locality_shared, createHealthyHostLists(*per_locality_shared), - std::move(locality_weights), hosts_added, hosts_removed); + priority_state_manager.updateClusterPrioritySet(priority, std::move(current_hosts_copy), + hosts_added, hosts_removed, absl::nullopt); return true; } return false; diff --git a/source/common/upstream/eds.h b/source/common/upstream/eds.h index b47cc59034df2..d84f02091799e 100644 --- a/source/common/upstream/eds.h +++ b/source/common/upstream/eds.h @@ -37,9 +37,10 @@ class EdsClusterImpl : public BaseDynamicClusterImpl, private: using LocalityWeightsMap = std::unordered_map; - bool updateHostsPerLocality(HostSet& host_set, const HostVector& new_hosts, + bool updateHostsPerLocality(const uint32_t priority, const HostVector& new_hosts, LocalityWeightsMap& locality_weights_map, - LocalityWeightsMap& new_locality_weights_map); + LocalityWeightsMap& new_locality_weights_map, + PriorityStateManager& priority_state_manager); // ClusterImplBase void startPreInit() override; diff --git a/source/common/upstream/health_checker_base_impl.cc b/source/common/upstream/health_checker_base_impl.cc index b07cdb9e6e887..a87f8579bc7df 100644 --- a/source/common/upstream/health_checker_base_impl.cc +++ b/source/common/upstream/health_checker_base_impl.cc @@ -1,5 +1,7 @@ #include "common/upstream/health_checker_base_impl.h" +#include "envoy/data/core/v2alpha/health_check_event.pb.h" + #include "common/router/router.h" namespace Envoy { @@ -9,16 +11,18 @@ HealthCheckerImplBase::HealthCheckerImplBase(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random) + Runtime::RandomGenerator& random, + HealthCheckEventLoggerPtr&& event_logger) : cluster_(cluster), dispatcher_(dispatcher), timeout_(PROTOBUF_GET_MS_REQUIRED(config, timeout)), unhealthy_threshold_(PROTOBUF_GET_WRAPPED_REQUIRED(config, unhealthy_threshold)), healthy_threshold_(PROTOBUF_GET_WRAPPED_REQUIRED(config, healthy_threshold)), stats_(generateStats(cluster.info()->statsScope())), runtime_(runtime), random_(random), reuse_connection_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, reuse_connection, true)), - interval_(PROTOBUF_GET_MS_REQUIRED(config, interval)), + event_logger_(std::move(event_logger)), interval_(PROTOBUF_GET_MS_REQUIRED(config, interval)), no_traffic_interval_(PROTOBUF_GET_MS_OR_DEFAULT(config, no_traffic_interval, 60000)), interval_jitter_(PROTOBUF_GET_MS_OR_DEFAULT(config, interval_jitter, 0)), + interval_jitter_percent_(config.interval_jitter_percent()), unhealthy_interval_( PROTOBUF_GET_MS_OR_DEFAULT(config, unhealthy_interval, interval_.count())), unhealthy_edge_interval_( @@ -82,6 +86,10 @@ std::chrono::milliseconds HealthCheckerImplBase::interval(HealthState state, base_time_ms = no_traffic_interval_.count(); } + if (interval_jitter_percent_ > 0) { + base_time_ms += random_.random() % (interval_jitter_percent_ * base_time_ms / 100); + } + if (interval_jitter_.count() > 0) { base_time_ms += (random_.random() % interval_jitter_.count()); } @@ -160,7 +168,7 @@ void HealthCheckerImplBase::setUnhealthyCrossThread(const HostSharedPtr& host) { return; } - session->second->setUnhealthy(ActiveHealthCheckSession::FailureType::Passive); + session->second->setUnhealthy(envoy::data::core::v2alpha::HealthCheckFailureType::PASSIVE); }); } @@ -200,6 +208,9 @@ void HealthCheckerImplBase::ActiveHealthCheckSession::handleSuccess() { host_->healthFlagClear(Host::HealthFlag::FAILED_ACTIVE_HC); parent_.incHealthy(); changed_state = HealthTransition::Changed; + if (parent_.event_logger_) { + parent_.event_logger_->logAddHealthy(parent_.healthCheckerType(), host_, first_check_); + } } else { changed_state = HealthTransition::ChangePending; } @@ -213,25 +224,30 @@ void HealthCheckerImplBase::ActiveHealthCheckSession::handleSuccess() { interval_timer_->enableTimer(parent_.interval(HealthState::Healthy, changed_state)); } -HealthTransition HealthCheckerImplBase::ActiveHealthCheckSession::setUnhealthy(FailureType type) { +HealthTransition HealthCheckerImplBase::ActiveHealthCheckSession::setUnhealthy( + envoy::data::core::v2alpha::HealthCheckFailureType type) { // If we are unhealthy, reset the # of healthy to zero. num_healthy_ = 0; HealthTransition changed_state = HealthTransition::Unchanged; if (!host_->healthFlagGet(Host::HealthFlag::FAILED_ACTIVE_HC)) { - if (type != FailureType::Network || ++num_unhealthy_ == parent_.unhealthy_threshold_) { + if (type != envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK || + ++num_unhealthy_ == parent_.unhealthy_threshold_) { host_->healthFlagSet(Host::HealthFlag::FAILED_ACTIVE_HC); parent_.decHealthy(); changed_state = HealthTransition::Changed; + if (parent_.event_logger_) { + parent_.event_logger_->logEjectUnhealthy(parent_.healthCheckerType(), host_, type); + } } else { changed_state = HealthTransition::ChangePending; } } parent_.stats_.failure_.inc(); - if (type == FailureType::Network) { + if (type == envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK) { parent_.stats_.network_failure_.inc(); - } else if (type == FailureType::Passive) { + } else if (type == envoy::data::core::v2alpha::HealthCheckFailureType::PASSIVE) { parent_.stats_.passive_failure_.inc(); } @@ -240,7 +256,8 @@ HealthTransition HealthCheckerImplBase::ActiveHealthCheckSession::setUnhealthy(F return changed_state; } -void HealthCheckerImplBase::ActiveHealthCheckSession::handleFailure(FailureType type) { +void HealthCheckerImplBase::ActiveHealthCheckSession::handleFailure( + envoy::data::core::v2alpha::HealthCheckFailureType type) { HealthTransition changed_state = setUnhealthy(type); timeout_timer_->disableTimer(); interval_timer_->enableTimer(parent_.interval(HealthState::Unhealthy, changed_state)); @@ -254,7 +271,40 @@ void HealthCheckerImplBase::ActiveHealthCheckSession::onIntervalBase() { void HealthCheckerImplBase::ActiveHealthCheckSession::onTimeoutBase() { onTimeout(); - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); +} + +void HealthCheckEventLoggerImpl::logEjectUnhealthy( + envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, + envoy::data::core::v2alpha::HealthCheckFailureType failure_type) { + envoy::data::core::v2alpha::HealthCheckEvent event; + event.set_health_checker_type(health_checker_type); + envoy::api::v2::core::Address address; + Network::Utility::addressToProtobufAddress(*host->address(), address); + *event.mutable_host() = std::move(address); + event.set_cluster_name(host->cluster().name()); + event.mutable_eject_unhealthy_event()->set_failure_type(failure_type); + // Make sure the type enums make it into the JSON + const auto json = MessageUtil::getJsonStringFromMessage(event, /* pretty_print */ false, + /* always_print_primitive_fields */ true); + file_->write(fmt::format("{}\n", json)); +} + +void HealthCheckEventLoggerImpl::logAddHealthy( + envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, bool first_check) { + envoy::data::core::v2alpha::HealthCheckEvent event; + event.set_health_checker_type(health_checker_type); + envoy::api::v2::core::Address address; + Network::Utility::addressToProtobufAddress(*host->address(), address); + *event.mutable_host() = std::move(address); + event.set_cluster_name(host->cluster().name()); + event.mutable_add_healthy_event()->set_first_check(first_check); + // Make sure the type enums make it into the JSON + const auto json = MessageUtil::getJsonStringFromMessage(event, /* pretty_print */ false, + /* always_print_primitive_fields */ true); + file_->write(fmt::format("{}\n", json)); } } // namespace Upstream diff --git a/source/common/upstream/health_checker_base_impl.h b/source/common/upstream/health_checker_base_impl.h index a504ad8222bbd..644c1d174c579 100644 --- a/source/common/upstream/health_checker_base_impl.h +++ b/source/common/upstream/health_checker_base_impl.h @@ -1,5 +1,6 @@ #pragma once +#include "envoy/access_log/access_log.h" #include "envoy/api/v2/core/health_check.pb.h" #include "envoy/event/timer.h" #include "envoy/runtime/runtime.h" @@ -45,17 +46,15 @@ class HealthCheckerImplBase : public HealthChecker, protected: class ActiveHealthCheckSession { public: - enum class FailureType { Active, Passive, Network }; - virtual ~ActiveHealthCheckSession(); - HealthTransition setUnhealthy(FailureType type); + HealthTransition setUnhealthy(envoy::data::core::v2alpha::HealthCheckFailureType type); void start() { onIntervalBase(); } protected: ActiveHealthCheckSession(HealthCheckerImplBase& parent, HostSharedPtr host); void handleSuccess(); - void handleFailure(FailureType type); + void handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType type); HostSharedPtr host_; @@ -77,9 +76,10 @@ class HealthCheckerImplBase : public HealthChecker, HealthCheckerImplBase(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + Runtime::RandomGenerator& random, HealthCheckEventLoggerPtr&& event_logger); virtual ActiveHealthCheckSessionPtr makeSession(HostSharedPtr host) PURE; + virtual envoy::data::core::v2alpha::HealthCheckerType healthCheckerType() const PURE; const Cluster& cluster_; Event::Dispatcher& dispatcher_; @@ -90,6 +90,7 @@ class HealthCheckerImplBase : public HealthChecker, Runtime::Loader& runtime_; Runtime::RandomGenerator& random_; const bool reuse_connection_; + HealthCheckEventLoggerPtr event_logger_; private: struct HealthCheckHostMonitorImpl : public HealthCheckHostMonitor { @@ -120,6 +121,7 @@ class HealthCheckerImplBase : public HealthChecker, const std::chrono::milliseconds interval_; const std::chrono::milliseconds no_traffic_interval_; const std::chrono::milliseconds interval_jitter_; + const uint32_t interval_jitter_percent_; const std::chrono::milliseconds unhealthy_interval_; const std::chrono::milliseconds unhealthy_edge_interval_; const std::chrono::milliseconds healthy_edge_interval_; @@ -127,5 +129,20 @@ class HealthCheckerImplBase : public HealthChecker, uint64_t local_process_healthy_{}; }; +class HealthCheckEventLoggerImpl : public HealthCheckEventLogger { +public: + HealthCheckEventLoggerImpl(AccessLog::AccessLogManager& log_manager, const std::string& file_name) + : file_(log_manager.createAccessLog(file_name)) {} + + void logEjectUnhealthy(envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, + envoy::data::core::v2alpha::HealthCheckFailureType failure_type) override; + void logAddHealthy(envoy::data::core::v2alpha::HealthCheckerType health_checker_type, + const HostDescriptionConstSharedPtr& host, bool first_check) override; + +private: + Filesystem::FileSharedPtr file_; +}; + } // namespace Upstream } // namespace Envoy diff --git a/source/common/upstream/health_checker_impl.cc b/source/common/upstream/health_checker_impl.cc index 572316db4cd7f..0e135b26eae43 100644 --- a/source/common/upstream/health_checker_impl.cc +++ b/source/common/upstream/health_checker_impl.cc @@ -10,6 +10,7 @@ #include "common/config/well_known_names.h" #include "common/grpc/common.h" #include "common/http/header_map_impl.h" +#include "common/network/address_impl.h" #include "common/router/router.h" #include "common/upstream/host_utility.h" @@ -23,54 +24,60 @@ class HealthCheckerFactoryContextImpl : public Server::Configuration::HealthChec public: HealthCheckerFactoryContextImpl(Upstream::Cluster& cluster, Envoy::Runtime::Loader& runtime, Envoy::Runtime::RandomGenerator& random, - Event::Dispatcher& dispatcher) - : cluster_(cluster), runtime_(runtime), random_(random), dispatcher_(dispatcher) {} + Event::Dispatcher& dispatcher, + HealthCheckEventLoggerPtr&& event_logger) + : cluster_(cluster), runtime_(runtime), random_(random), dispatcher_(dispatcher), + event_logger_(std::move(event_logger)) {} Upstream::Cluster& cluster() override { return cluster_; } Envoy::Runtime::Loader& runtime() override { return runtime_; } Envoy::Runtime::RandomGenerator& random() override { return random_; } Event::Dispatcher& dispatcher() override { return dispatcher_; } + HealthCheckEventLoggerPtr eventLogger() override { return std::move(event_logger_); } private: Upstream::Cluster& cluster_; Envoy::Runtime::Loader& runtime_; Envoy::Runtime::RandomGenerator& random_; Event::Dispatcher& dispatcher_; + HealthCheckEventLoggerPtr event_logger_; }; HealthCheckerSharedPtr HealthCheckerFactory::create(const envoy::api::v2::core::HealthCheck& hc_config, Upstream::Cluster& cluster, Runtime::Loader& runtime, - Runtime::RandomGenerator& random, Event::Dispatcher& dispatcher) { + Runtime::RandomGenerator& random, Event::Dispatcher& dispatcher, + AccessLog::AccessLogManager& log_manager) { + HealthCheckEventLoggerPtr event_logger; + if (!hc_config.event_log_path().empty()) { + event_logger = + std::make_unique(log_manager, hc_config.event_log_path()); + } switch (hc_config.health_checker_case()) { case envoy::api::v2::core::HealthCheck::HealthCheckerCase::kHttpHealthCheck: return std::make_shared(cluster, hc_config, dispatcher, runtime, - random); + random, std::move(event_logger)); case envoy::api::v2::core::HealthCheck::HealthCheckerCase::kTcpHealthCheck: - return std::make_shared(cluster, hc_config, dispatcher, runtime, random); + return std::make_shared(cluster, hc_config, dispatcher, runtime, random, + std::move(event_logger)); case envoy::api::v2::core::HealthCheck::HealthCheckerCase::kGrpcHealthCheck: if (!(cluster.info()->features() & Upstream::ClusterInfo::Features::HTTP2)) { throw EnvoyException(fmt::format("{} cluster must support HTTP/2 for gRPC healthchecking", cluster.info()->name())); } return std::make_shared(cluster, hc_config, dispatcher, runtime, - random); - // Deprecated redis_health_check, preserving using old config until it is removed. - case envoy::api::v2::core::HealthCheck::HealthCheckerCase::kRedisHealthCheck: - ENVOY_LOG(warn, "redis_health_check is deprecated, use custom_health_check instead"); - FALLTHRU; + random, std::move(event_logger)); case envoy::api::v2::core::HealthCheck::HealthCheckerCase::kCustomHealthCheck: { auto& factory = Config::Utility::getAndCheckFactory( - hc_config.has_redis_health_check() - ? Extensions::HealthCheckers::HealthCheckerNames::get().REDIS_HEALTH_CHECKER - : std::string(hc_config.custom_health_check().name())); + std::string(hc_config.custom_health_check().name())); std::unique_ptr context( - new HealthCheckerFactoryContextImpl(cluster, runtime, random, dispatcher)); + new HealthCheckerFactoryContextImpl(cluster, runtime, random, dispatcher, + std::move(event_logger))); return factory.createCustomHealthChecker(hc_config, *context); } default: // Checked by schema. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -78,8 +85,9 @@ HttpHealthCheckerImpl::HttpHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random) - : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random), + Runtime::RandomGenerator& random, + HealthCheckEventLoggerPtr&& event_logger) + : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random, std::move(event_logger)), path_(config.http_health_check().path()), host_value_(config.http_health_check().host()), request_headers_parser_( Router::HeaderParser::configure(config.http_health_check().request_headers_to_add())), @@ -91,7 +99,13 @@ HttpHealthCheckerImpl::HttpHealthCheckerImpl(const Cluster& cluster, HttpHealthCheckerImpl::HttpActiveHealthCheckSession::HttpActiveHealthCheckSession( HttpHealthCheckerImpl& parent, const HostSharedPtr& host) - : ActiveHealthCheckSession(parent, host), parent_(parent) {} + : ActiveHealthCheckSession(parent, host), parent_(parent), + hostname_(parent_.host_value_.empty() ? parent_.cluster_.info()->name() + : parent_.host_value_), + protocol_(parent_.codec_client_type_ == Http::CodecClient::Type::HTTP1 + ? Http::Protocol::Http11 + : Http::Protocol::Http2), + local_address_(std::make_shared("127.0.0.1")) {} HttpHealthCheckerImpl::HttpActiveHealthCheckSession::~HttpActiveHealthCheckSession() { if (client_) { @@ -120,9 +134,6 @@ void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onEvent(Network::Conne } } -const RequestInfo::RequestInfoImpl - HttpHealthCheckerImpl::HttpActiveHealthCheckSession::REQUEST_INFO; - void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onInterval() { if (!client_) { Upstream::Host::CreateConnectionData conn = @@ -137,13 +148,15 @@ void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onInterval() { Http::HeaderMapImpl request_headers{ {Http::Headers::get().Method, "GET"}, - {Http::Headers::get().Host, - parent_.host_value_.empty() ? parent_.cluster_.info()->name() : parent_.host_value_}, + {Http::Headers::get().Host, hostname_}, {Http::Headers::get().Path, parent_.path_}, {Http::Headers::get().UserAgent, Http::Headers::get().UserAgentValues.EnvoyHealthChecker}}; Router::FilterUtility::setUpstreamScheme(request_headers, *parent_.cluster_.info()); - - parent_.request_headers_parser_->evaluateHeaders(request_headers, REQUEST_INFO); + RequestInfo::RequestInfoImpl request_info(protocol_); + request_info.setDownstreamLocalAddress(local_address_); + request_info.setDownstreamRemoteAddress(local_address_); + request_info.onUpstreamHostSelected(host_); + parent_.request_headers_parser_->evaluateHeaders(request_headers, request_info); request_encoder_->encodeHeaders(request_headers, true); request_encoder_ = nullptr; } @@ -155,7 +168,7 @@ void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onResetStream(Http::St ENVOY_CONN_LOG(debug, "connection/stream error health_flags={}", *client_, HostUtility::healthFlagsToString(*host_)); - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); } bool HttpHealthCheckerImpl::HttpActiveHealthCheckSession::isHealthCheckSucceeded() { @@ -185,7 +198,7 @@ void HttpHealthCheckerImpl::HttpActiveHealthCheckSession::onResponseComplete() { if (isHealthCheckSucceeded()) { handleSuccess(); } else { - handleFailure(FailureType::Active); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::ACTIVE); } if ((response_headers_->Connection() && @@ -254,8 +267,10 @@ bool TcpHealthCheckMatcher::match(const MatchSegments& expected, const Buffer::I TcpHealthCheckerImpl::TcpHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random) - : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random), send_bytes_([&config] { + Runtime::RandomGenerator& random, + HealthCheckEventLoggerPtr&& event_logger) + : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random, std::move(event_logger)), + send_bytes_([&config] { Protobuf::RepeatedPtrField send_repeated; if (!config.tcp_health_check().send().text().empty()) { send_repeated.Add()->CopyFrom(config.tcp_health_check().send()); @@ -284,7 +299,7 @@ void TcpHealthCheckerImpl::TcpActiveHealthCheckSession::onData(Buffer::Instance& void TcpHealthCheckerImpl::TcpActiveHealthCheckSession::onEvent(Network::ConnectionEvent event) { if (event == Network::ConnectionEvent::RemoteClose) { - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); } if (event == Network::ConnectionEvent::RemoteClose || @@ -341,8 +356,9 @@ GrpcHealthCheckerImpl::GrpcHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random) - : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random), + Runtime::RandomGenerator& random, + HealthCheckEventLoggerPtr&& event_logger) + : HealthCheckerImplBase(cluster, config, dispatcher, runtime, random, std::move(event_logger)), service_method_(*Protobuf::DescriptorPool::generated_pool()->FindMethodByName( "grpc.health.v1.Health.Check")) { if (!config.grpc_health_check().service_name().empty()) { @@ -496,7 +512,7 @@ void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::onResetStream(Http::St // Http::StreamResetReason::RemoteReset or Http::StreamResetReason::ConnectionTermination (both // mean connection close), check if connection is not fresh (was used for at least 1 request) // and silently retry request on the fresh connection. This is also true for HTTP/1.1 healthcheck. - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); } void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::onGoAway() { @@ -504,7 +520,7 @@ void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::onGoAway() { HostUtility::healthFlagsToString(*host_)); // Even if we have active health check probe, fail it on GOAWAY and schedule new one. if (request_encoder_) { - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); expect_reset_ = true; request_encoder_->getStream().resetStream(Http::StreamResetReason::LocalReset); } @@ -531,7 +547,7 @@ void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::onRpcComplete( if (isHealthCheckSucceeded(grpc_status)) { handleSuccess(); } else { - handleFailure(FailureType::Active); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::ACTIVE); } // |end_stream| will be false if we decided to stop healthcheck before HTTP stream has ended - @@ -581,7 +597,7 @@ void GrpcHealthCheckerImpl::GrpcActiveHealthCheckSession::logHealthCheckStatus( break; default: // Should not happen really, Protobuf should not parse undefined enums values. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; break; } } diff --git a/source/common/upstream/health_checker_impl.h b/source/common/upstream/health_checker_impl.h index 9d7ed96d1a31f..0db4bfd70a3fa 100644 --- a/source/common/upstream/health_checker_impl.h +++ b/source/common/upstream/health_checker_impl.h @@ -1,5 +1,6 @@ #pragma once +#include "envoy/access_log/access_log.h" #include "envoy/api/v2/core/health_check.pb.h" #include "envoy/grpc/status.h" @@ -27,12 +28,14 @@ class HealthCheckerFactory : public Logger::Loggable * @param runtime supplies the runtime loader. * @param random supplies the random generator. * @param dispatcher supplies the dispatcher. + * @param event_logger supplies the event_logger. * @return a health checker. */ static HealthCheckerSharedPtr create(const envoy::api::v2::core::HealthCheck& hc_config, Upstream::Cluster& cluster, Runtime::Loader& runtime, Runtime::RandomGenerator& random, - Event::Dispatcher& dispatcher); + Event::Dispatcher& dispatcher, + AccessLog::AccessLogManager& log_manager); }; /** @@ -42,7 +45,7 @@ class HttpHealthCheckerImpl : public HealthCheckerImplBase { public: HttpHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + Runtime::RandomGenerator& random, HealthCheckEventLoggerPtr&& event_logger); private: struct HttpActiveHealthCheckSession : public ActiveHealthCheckSession, @@ -87,13 +90,14 @@ class HttpHealthCheckerImpl : public HealthCheckerImplBase { HttpActiveHealthCheckSession& parent_; }; - static const RequestInfo::RequestInfoImpl REQUEST_INFO; - ConnectionCallbackImpl connection_callback_impl_{*this}; HttpHealthCheckerImpl& parent_; Http::CodecClientPtr client_; Http::StreamEncoder* request_encoder_{}; Http::HeaderMapPtr response_headers_; + const std::string& hostname_; + const Http::Protocol protocol_; + Network::Address::InstanceConstSharedPtr local_address_; bool expect_reset_{}; }; @@ -105,6 +109,9 @@ class HttpHealthCheckerImpl : public HealthCheckerImplBase { ActiveHealthCheckSessionPtr makeSession(HostSharedPtr host) override { return std::make_unique(*this, host); } + envoy::data::core::v2alpha::HealthCheckerType healthCheckerType() const override { + return envoy::data::core::v2alpha::HealthCheckerType::HTTP; + } Http::CodecClient::Type codecClientType(bool use_http2); @@ -187,7 +194,7 @@ class TcpHealthCheckerImpl : public HealthCheckerImplBase { public: TcpHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + Runtime::RandomGenerator& random, HealthCheckEventLoggerPtr&& event_logger); private: struct TcpActiveHealthCheckSession; @@ -233,6 +240,9 @@ class TcpHealthCheckerImpl : public HealthCheckerImplBase { ActiveHealthCheckSessionPtr makeSession(HostSharedPtr host) override { return std::make_unique(*this, host); } + envoy::data::core::v2alpha::HealthCheckerType healthCheckerType() const override { + return envoy::data::core::v2alpha::HealthCheckerType::TCP; + } const TcpHealthCheckMatcher::MatchSegments send_bytes_; const TcpHealthCheckMatcher::MatchSegments receive_bytes_; @@ -245,7 +255,7 @@ class GrpcHealthCheckerImpl : public HealthCheckerImplBase { public: GrpcHealthCheckerImpl(const Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + Runtime::RandomGenerator& random, HealthCheckEventLoggerPtr&& event_logger); private: struct GrpcActiveHealthCheckSession : public ActiveHealthCheckSession, @@ -320,6 +330,9 @@ class GrpcHealthCheckerImpl : public HealthCheckerImplBase { ActiveHealthCheckSessionPtr makeSession(HostSharedPtr host) override { return std::make_unique(*this, host); } + envoy::data::core::v2alpha::HealthCheckerType healthCheckerType() const override { + return envoy::data::core::v2alpha::HealthCheckerType::GRPC; + } const Protobuf::MethodDescriptor& service_method_; absl::optional service_name_; diff --git a/source/common/upstream/health_discovery_service.cc b/source/common/upstream/health_discovery_service.cc index 10d04488bf955..1a5fcd76d5dd5 100644 --- a/source/common/upstream/health_discovery_service.cc +++ b/source/common/upstream/health_discovery_service.cc @@ -6,19 +6,31 @@ namespace Envoy { namespace Upstream { HdsDelegate::HdsDelegate(const envoy::api::v2::core::Node& node, Stats::Scope& scope, - Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher) + Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher, + Runtime::Loader& runtime, Envoy::Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, + Secret::SecretManager& secret_manager, Runtime::RandomGenerator& random, + ClusterInfoFactory& info_factory, + AccessLog::AccessLogManager& access_log_manager) : stats_{ALL_HDS_STATS(POOL_COUNTER_PREFIX(scope, "hds_delegate."))}, - async_client_(std::move(async_client)), service_method_(*Protobuf::DescriptorPool::generated_pool()->FindMethodByName( - "envoy.service.discovery.v2.HealthDiscoveryService.StreamHealthCheck")) { + "envoy.service.discovery.v2.HealthDiscoveryService.StreamHealthCheck")), + async_client_(std::move(async_client)), dispatcher_(dispatcher), runtime_(runtime), + store_stats(stats), ssl_context_manager_(ssl_context_manager), + secret_manager_(secret_manager), random_(random), info_factory_(info_factory), + access_log_manager_(access_log_manager) { health_check_request_.mutable_node()->MergeFrom(node); - retry_timer_ = dispatcher.createTimer([this]() -> void { establishNewStream(); }); - response_timer_ = dispatcher.createTimer([this]() -> void { sendHealthCheckRequest(); }); + hds_retry_timer_ = dispatcher.createTimer([this]() -> void { establishNewStream(); }); + hds_stream_response_timer_ = dispatcher.createTimer([this]() -> void { sendResponse(); }); establishNewStream(); } -void HdsDelegate::setRetryTimer() { - retry_timer_->enableTimer(std::chrono::milliseconds(RETRY_DELAY_MS)); +void HdsDelegate::setHdsRetryTimer() { + hds_retry_timer_->enableTimer(std::chrono::milliseconds(RetryDelayMilliseconds)); +} + +void HdsDelegate::setHdsStreamResponseTimer() { + hds_stream_response_timer_->enableTimer(std::chrono::milliseconds(server_response_ms_)); } void HdsDelegate::establishNewStream() { @@ -30,20 +42,46 @@ void HdsDelegate::establishNewStream() { return; } - sendHealthCheckRequest(); -} - -void HdsDelegate::sendHealthCheckRequest() { - ENVOY_LOG(debug, "Sending HealthCheckRequest"); + // TODO(lilika): Add support for other types of healthchecks + health_check_request_.mutable_capability()->add_health_check_protocol( + envoy::service::discovery::v2::Capability::HTTP); + ENVOY_LOG(debug, "Sending HealthCheckRequest {} ", health_check_request_.DebugString()); stream_->sendMessage(health_check_request_, false); stats_.responses_.inc(); } +// TODO(lilika) : Use jittered backoff as in https://github.com/envoyproxy/envoy/pull/3791 void HdsDelegate::handleFailure() { - ENVOY_LOG(warn, "Load reporter stats stream/connection failure, will retry in {} ms.", - RETRY_DELAY_MS); + ENVOY_LOG(warn, "HdsDelegate stream/connection failure, will retry in {} ms.", + RetryDelayMilliseconds); stats_.errors_.inc(); - setRetryTimer(); + setHdsRetryTimer(); +} + +// TODO(lilika): Add support for the same endpoint in different clusters/ports +envoy::service::discovery::v2::HealthCheckRequestOrEndpointHealthResponse +HdsDelegate::sendResponse() { + envoy::service::discovery::v2::HealthCheckRequestOrEndpointHealthResponse response; + for (const auto& cluster : hds_clusters_) { + for (const auto& hosts : cluster->prioritySet().hostSetsPerPriority()) { + for (const auto& host : hosts->hosts()) { + auto* endpoint = response.mutable_endpoint_health_response()->add_endpoints_health(); + Network::Utility::addressToProtobufAddress( + *host->address(), *endpoint->mutable_endpoint()->mutable_address()); + // TODO(lilika): Add support for more granular options of envoy::api::v2::core::HealthStatus + if (host->healthy()) { + endpoint->set_health_status(envoy::api::v2::core::HealthStatus::HEALTHY); + } else { + endpoint->set_health_status(envoy::api::v2::core::HealthStatus::UNHEALTHY); + } + } + } + } + ENVOY_LOG(debug, "Sending EndpointHealthResponse to server {}", response.DebugString()); + stream_->sendMessage(response, false); + stats_.responses_.inc(); + setHdsStreamResponseTimer(); + return response; } void HdsDelegate::onCreateInitialMetadata(Http::HeaderMap& metadata) { @@ -54,12 +92,59 @@ void HdsDelegate::onReceiveInitialMetadata(Http::HeaderMapPtr&& metadata) { UNREFERENCED_PARAMETER(metadata); } +void HdsDelegate::processMessage( + std::unique_ptr&& message) { + ENVOY_LOG(debug, "New health check response message {} ", message->DebugString()); + ASSERT(message); + + for (const auto& cluster_health_check : message->health_check()) { + // Create HdsCluster config + static const envoy::api::v2::core::BindConfig bind_config; + envoy::api::v2::Cluster cluster_config; + + cluster_config.set_name(cluster_health_check.cluster_name()); + cluster_config.mutable_connect_timeout()->set_seconds(ClusterTimeoutSeconds); + cluster_config.mutable_per_connection_buffer_limit_bytes()->set_value( + ClusterConnectionBufferLimitBytes); + + // Add endpoints to cluster + for (const auto& locality_endpoints : cluster_health_check.endpoints()) { + for (const auto& endpoint : locality_endpoints.endpoints()) { + cluster_config.add_hosts()->MergeFrom(endpoint.address()); + } + } + + // TODO(lilika): Add support for optional per-endpoint health checks + + // Add healthchecks to cluster + for (auto& health_check : cluster_health_check.health_checks()) { + cluster_config.add_health_checks()->MergeFrom(health_check); + } + + ENVOY_LOG(debug, "New HdsCluster config {} ", cluster_config.DebugString()); + + // Create HdsCluster + hds_clusters_.emplace_back(new HdsCluster(runtime_, cluster_config, bind_config, store_stats, + ssl_context_manager_, secret_manager_, false, + info_factory_)); + + hds_clusters_.back()->startHealthchecks(access_log_manager_, runtime_, random_, dispatcher_); + } +} + +// TODO(lilika): Add support for subsequent HealthCheckSpecifier messages that +// might modify the HdsClusters void HdsDelegate::onReceiveMessage( std::unique_ptr&& message) { - ENVOY_LOG(debug, "New health check response ", message->DebugString()); stats_.requests_.inc(); - stream_->sendMessage(health_check_request_, false); - stats_.responses_.inc(); + ENVOY_LOG(debug, "New health check response message {} ", message->DebugString()); + + // Process the HealthCheckSpecifier message + processMessage(std::move(message)); + + // Set response + server_response_ms_ = PROTOBUF_GET_MS_REQUIRED(*message, interval); + setHdsStreamResponseTimer(); } void HdsDelegate::onReceiveTrailingMetadata(Http::HeaderMapPtr&& metadata) { @@ -68,10 +153,84 @@ void HdsDelegate::onReceiveTrailingMetadata(Http::HeaderMapPtr&& metadata) { void HdsDelegate::onRemoteClose(Grpc::Status::GrpcStatus status, const std::string& message) { ENVOY_LOG(warn, "gRPC config stream closed: {}, {}", status, message); - response_timer_->disableTimer(); + hds_stream_response_timer_->disableTimer(); stream_ = nullptr; handleFailure(); } +HdsCluster::HdsCluster(Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, + Secret::SecretManager& secret_manager, bool added_via_api, + ClusterInfoFactory& info_factory) + : runtime_(runtime), cluster_(cluster), bind_config_(bind_config), stats_(stats), + ssl_context_manager_(ssl_context_manager), secret_manager_(secret_manager), + added_via_api_(added_via_api), initial_hosts_(new HostVector()) { + ENVOY_LOG(debug, "Creating an HdsCluster"); + priority_set_.getOrCreateHostSet(0); + + info_ = info_factory.createClusterInfo(runtime_, cluster_, bind_config_, stats_, + ssl_context_manager_, secret_manager_, added_via_api_); + + for (const auto& host : cluster.hosts()) { + initial_hosts_->emplace_back( + new HostImpl(info_, "", Network::Address::resolveProtoAddress(host), + envoy::api::v2::core::Metadata::default_instance(), 1, + envoy::api::v2::core::Locality().default_instance(), + envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); + } + initialize([] {}); +} + +ClusterSharedPtr HdsCluster::create() { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + +HostVectorConstSharedPtr HdsCluster::createHealthyHostList(const HostVector& hosts) { + HostVectorSharedPtr healthy_list(new HostVector()); + for (const auto& host : hosts) { + if (host->healthy()) { + healthy_list->emplace_back(host); + } + } + return healthy_list; +} + +ClusterInfoConstSharedPtr ProdClusterInfoFactory::createClusterInfo( + Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, Secret::SecretManager& secret_manager, + bool added_via_api) { + + return std::make_unique(cluster, bind_config, runtime, stats, + ssl_context_manager, secret_manager, added_via_api); +} + +void HdsCluster::startHealthchecks(AccessLog::AccessLogManager& access_log_manager, + Runtime::Loader& runtime, Runtime::RandomGenerator& random, + Event::Dispatcher& dispatcher) { + + for (auto& health_check : cluster_.health_checks()) { + health_checkers_.push_back(Upstream::HealthCheckerFactory::create( + health_check, *this, runtime, random, dispatcher, access_log_manager)); + health_checkers_.back()->start(); + } +} + +void HdsCluster::initialize(std::function callback) { + initialization_complete_callback_ = callback; + for (const auto& host : *initial_hosts_) { + host->healthFlagSet(Host::HealthFlag::FAILED_ACTIVE_HC); + } + + auto& first_host_set = priority_set_.getOrCreateHostSet(0); + auto healthy = createHealthyHostList(*initial_hosts_); + + first_host_set.updateHosts(initial_hosts_, healthy, HostsPerLocalityImpl::empty(), + HostsPerLocalityImpl::empty(), {}, *initial_hosts_, {}); +} + +void HdsCluster::setOutlierDetector(const Outlier::DetectorSharedPtr&) { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; +} + } // namespace Upstream } // namespace Envoy diff --git a/source/common/upstream/health_discovery_service.h b/source/common/upstream/health_discovery_service.h index 980d8a450bae9..f167a20324fb8 100644 --- a/source/common/upstream/health_discovery_service.h +++ b/source/common/upstream/health_discovery_service.h @@ -2,16 +2,88 @@ #include "envoy/event/dispatcher.h" #include "envoy/service/discovery/v2/hds.pb.h" +#include "envoy/ssl/context_manager.h" #include "envoy/stats/stats_macros.h" +#include "envoy/upstream/upstream.h" #include "common/common/logger.h" #include "common/grpc/async_client_impl.h" +#include "common/network/resolver_impl.h" +#include "common/upstream/health_checker_impl.h" +#include "common/upstream/upstream_impl.h" namespace Envoy { namespace Upstream { +class ProdClusterInfoFactory : public ClusterInfoFactory, Logger::Loggable { +public: + ClusterInfoConstSharedPtr + createClusterInfo(Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, Secret::SecretManager& secret_manager, + bool added_via_api) override; +}; + +// TODO(lilika): Add HdsClusters to the /clusters endpoint to get detailed stats about each HC host. + +/** + * Implementation of Upstream::Cluster for hds clusters, clusters that are used + * by HdsDelegates + */ + +class HdsCluster : public Cluster, Logger::Loggable { +public: + static ClusterSharedPtr create(); + HdsCluster(Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, Secret::SecretManager& secret_manager, + bool added_via_api, ClusterInfoFactory& info_factory); + + // From Upstream::Cluster + InitializePhase initializePhase() const override { return InitializePhase::Primary; } + PrioritySet& prioritySet() override { return priority_set_; } + const PrioritySet& prioritySet() const override { return priority_set_; } + void setOutlierDetector(const Outlier::DetectorSharedPtr& outlier_detector); + HealthChecker* healthChecker() override { return health_checker_.get(); } + ClusterInfoConstSharedPtr info() const override { return info_; } + Outlier::Detector* outlierDetector() override { return outlier_detector_.get(); } + const Outlier::Detector* outlierDetector() const override { return outlier_detector_.get(); } + void initialize(std::function callback) override; + + // Creates and starts healthcheckers to its endpoints + void startHealthchecks(AccessLog::AccessLogManager& access_log_manager, Runtime::Loader& runtime, + Runtime::RandomGenerator& random, Event::Dispatcher& dispatcher); + + std::vector healthCheckers() { return health_checkers_; }; + +protected: + PrioritySetImpl priority_set_; + HealthCheckerSharedPtr health_checker_; + Outlier::DetectorSharedPtr outlier_detector_; + + // Creates a vector containing any healthy hosts + static HostVectorConstSharedPtr createHealthyHostList(const HostVector& hosts); + +private: + std::function initialization_complete_callback_; + + Runtime::Loader& runtime_; + const envoy::api::v2::Cluster& cluster_; + const envoy::api::v2::core::BindConfig& bind_config_; + Stats::Store& stats_; + Ssl::ContextManager& ssl_context_manager_; + Secret::SecretManager& secret_manager_; + bool added_via_api_; + + HostVectorSharedPtr initial_hosts_; + ClusterInfoConstSharedPtr info_; + std::vector health_checkers_; +}; + +typedef std::shared_ptr HdsClusterPtr; + /** - * All load reporter stats. @see stats_macros.h + * All hds stats. @see stats_macros.h */ // clang-format off #define ALL_HDS_STATS(COUNTER) \ @@ -21,18 +93,29 @@ namespace Upstream { // clang-format on /** - * Struct definition for all load reporter stats. @see stats_macros.h + * Struct definition for all hds stats. @see stats_macros.h */ struct HdsDelegateStats { ALL_HDS_STATS(GENERATE_COUNTER_STRUCT) }; +// TODO(lilika): Add /config_dump support for HdsDelegate + +/** + * The HdsDelegate class is responsible for receiving requests from a management + * server with a set of hosts to healthcheck, healthchecking them, and reporting + * back the results. + */ class HdsDelegate : Grpc::TypedAsyncStreamCallbacks, Logger::Loggable { public: HdsDelegate(const envoy::api::v2::core::Node& node, Stats::Scope& scope, - Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher); + Grpc::AsyncClientPtr async_client, Event::Dispatcher& dispatcher, + Runtime::Loader& runtime, Envoy::Stats::Store& stats, + Ssl::ContextManager& ssl_context_manager, Secret::SecretManager& secret_manager, + Runtime::RandomGenerator& random, ClusterInfoFactory& info_factory, + AccessLog::AccessLogManager& access_log_manager); // Grpc::TypedAsyncStreamCallbacks void onCreateInitialMetadata(Http::HeaderMap& metadata) override; @@ -41,25 +124,59 @@ class HdsDelegate std::unique_ptr&& message) override; void onReceiveTrailingMetadata(Http::HeaderMapPtr&& metadata) override; void onRemoteClose(Grpc::Status::GrpcStatus status, const std::string& message) override; + envoy::service::discovery::v2::HealthCheckRequestOrEndpointHealthResponse sendResponse(); - // TODO(htuch): Make this configurable or some static. - const uint32_t RETRY_DELAY_MS = 5000; + std::vector hdsClusters() { return hds_clusters_; }; private: - void setRetryTimer(); - void establishNewStream(); - void sendHealthCheckRequest(); + friend class HdsDelegateFriend; + + void setHdsRetryTimer(); + void setHdsStreamResponseTimer(); void handleFailure(); + // Establishes a connection with the management server + void establishNewStream(); + void + processMessage(std::unique_ptr&& message); HdsDelegateStats stats_; + const Protobuf::MethodDescriptor& service_method_; + Grpc::AsyncClientPtr async_client_; Grpc::AsyncStream* stream_{}; - const Protobuf::MethodDescriptor& service_method_; - Event::TimerPtr retry_timer_; - Event::TimerPtr response_timer_; + Event::Dispatcher& dispatcher_; + Runtime::Loader& runtime_; + Envoy::Stats::Store& store_stats; + Ssl::ContextManager& ssl_context_manager_; + Secret::SecretManager& secret_manager_; + Runtime::RandomGenerator& random_; + ClusterInfoFactory& info_factory_; + AccessLog::AccessLogManager& access_log_manager_; + envoy::service::discovery::v2::HealthCheckRequest health_check_request_; std::unique_ptr health_check_message_; + std::vector clusters_; + std::vector hds_clusters_; + + Event::TimerPtr hds_stream_response_timer_; + Event::TimerPtr hds_retry_timer_; + + // TODO(lilika): Add API knob for RetryDelayMilliseconds, instead of + // hardcoding it. + // How often we retry to establish a stream to the management server + const uint32_t RetryDelayMilliseconds = 5000; + + // Soft limit on size of the cluster’s connections read and write buffers. + static constexpr uint32_t ClusterConnectionBufferLimitBytes = 32768; + + // TODO(lilika): Add API knob for ClusterTimeoutSeconds, instead of + // hardcoding it. + // The timeout for new network connections to hosts in the cluster. + static constexpr uint32_t ClusterTimeoutSeconds = 1; + + // How often envoy reports the healthcheck results to the server + uint32_t server_response_ms_ = 0; }; typedef std::unique_ptr HdsDelegatePtr; diff --git a/source/common/upstream/load_balancer_impl.cc b/source/common/upstream/load_balancer_impl.cc index 3e6258517ab16..9d0eefee6849f 100644 --- a/source/common/upstream/load_balancer_impl.cc +++ b/source/common/upstream/load_balancer_impl.cc @@ -34,7 +34,7 @@ uint32_t LoadBalancerBase::choosePriority(uint64_t hash, } } // The percentages should always add up to 100 but we have to have a return for the compiler. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } LoadBalancerBase::LoadBalancerBase(const PrioritySet& priority_set, ClusterStats& stats, @@ -403,7 +403,7 @@ const HostVector& ZoneAwareLoadBalancerBase::hostSourceToHosts(HostsSource hosts case HostsSource::SourceType::LocalityHealthyHosts: return host_set.healthyHostsPerLocality().get()[hosts_source.locality_index_]; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/source/common/upstream/load_balancer_impl.h b/source/common/upstream/load_balancer_impl.h index f10c389e43abd..82946aa091cf5 100644 --- a/source/common/upstream/load_balancer_impl.h +++ b/source/common/upstream/load_balancer_impl.h @@ -362,7 +362,8 @@ class LoadBalancerSubsetInfoImpl : public LoadBalancerSubsetInfo { LoadBalancerSubsetInfoImpl(const envoy::api::v2::Cluster::LbSubsetConfig& subset_config) : enabled_(!subset_config.subset_selectors().empty()), fallback_policy_(subset_config.fallback_policy()), - default_subset_(subset_config.default_subset()) { + default_subset_(subset_config.default_subset()), + locality_weight_aware_(subset_config.locality_weight_aware()) { for (const auto& subset : subset_config.subset_selectors()) { if (!subset.keys().empty()) { subset_keys_.emplace_back( @@ -378,12 +379,14 @@ class LoadBalancerSubsetInfoImpl : public LoadBalancerSubsetInfo { } const ProtobufWkt::Struct& defaultSubset() const override { return default_subset_; } const std::vector>& subsetKeys() const override { return subset_keys_; } + bool localityWeightAware() const override { return locality_weight_aware_; } private: const bool enabled_; const envoy::api::v2::Cluster::LbSubsetConfig::LbSubsetFallbackPolicy fallback_policy_; const ProtobufWkt::Struct default_subset_; std::vector> subset_keys_; + const bool locality_weight_aware_; }; } // namespace Upstream diff --git a/source/common/upstream/load_stats_reporter.cc b/source/common/upstream/load_stats_reporter.cc index 43587b3003293..2bc986a3d5908 100644 --- a/source/common/upstream/load_stats_reporter.cc +++ b/source/common/upstream/load_stats_reporter.cc @@ -8,12 +8,14 @@ namespace Upstream { LoadStatsReporter::LoadStatsReporter(const envoy::api::v2::core::Node& node, ClusterManager& cluster_manager, Stats::Scope& scope, Grpc::AsyncClientPtr async_client, - Event::Dispatcher& dispatcher) + Event::Dispatcher& dispatcher, + MonotonicTimeSource& time_source) : cm_(cluster_manager), stats_{ALL_LOAD_REPORTER_STATS( POOL_COUNTER_PREFIX(scope, "load_reporter."))}, async_client_(std::move(async_client)), service_method_(*Protobuf::DescriptorPool::generated_pool()->FindMethodByName( - "envoy.service.load_stats.v2.LoadReportingService.StreamLoadStats")) { + "envoy.service.load_stats.v2.LoadReportingService.StreamLoadStats")), + time_source_(time_source) { request_.mutable_node()->MergeFrom(node); retry_timer_ = dispatcher.createTimer([this]() -> void { establishNewStream(); }); response_timer_ = dispatcher.createTimer([this]() -> void { sendLoadStatsRequest(); }); @@ -39,7 +41,8 @@ void LoadStatsReporter::establishNewStream() { void LoadStatsReporter::sendLoadStatsRequest() { request_.mutable_cluster_stats()->Clear(); - for (const std::string& cluster_name : clusters_) { + for (const auto& cluster_name_and_timestamp : clusters_) { + const std::string& cluster_name = cluster_name_and_timestamp.first; auto cluster_info_map = cm_.clusters(); auto it = cluster_info_map.find(cluster_name); if (it == cluster_info_map.end()) { @@ -73,6 +76,12 @@ void LoadStatsReporter::sendLoadStatsRequest() { } cluster_stats->set_total_dropped_requests( cluster.info()->loadReportStats().upstream_rq_dropped_.latch()); + const auto now = time_source_.currentTime().time_since_epoch(); + const auto measured_interval = now - cluster_name_and_timestamp.second; + cluster_stats->mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration( + std::chrono::duration_cast(measured_interval).count())); + clusters_[cluster_name] = now; } ENVOY_LOG(trace, "Sending LoadStatsRequest: {}", request_.DebugString()); @@ -109,15 +118,33 @@ void LoadStatsReporter::onReceiveMessage( } void LoadStatsReporter::startLoadReportPeriod() { + // Once a cluster is tracked, we don't want to reset its stats between reports + // to avoid racing between request/response. + // TODO(htuch): They key here could be absl::string_view, but this causes + // problems due to referencing of temporaries in the below loop with Google's + // internal string type. Consider this optimization when the string types + // converge. + std::unordered_map existing_clusters; + for (const std::string& cluster_name : message_->clusters()) { + if (clusters_.count(cluster_name) > 0) { + existing_clusters.emplace(cluster_name, clusters_[cluster_name]); + } + } clusters_.clear(); // Reset stats for all hosts in clusters we are tracking. for (const std::string& cluster_name : message_->clusters()) { - clusters_.emplace_back(cluster_name); + clusters_.emplace(cluster_name, existing_clusters.count(cluster_name) > 0 + ? existing_clusters[cluster_name] + : time_source_.currentTime().time_since_epoch()); auto cluster_info_map = cm_.clusters(); auto it = cluster_info_map.find(cluster_name); if (it == cluster_info_map.end()) { continue; } + // Don't reset stats for existing tracked clusters. + if (existing_clusters.count(cluster_name) > 0) { + continue; + } auto& cluster = it->second.get(); for (auto& host_set : cluster.prioritySet().hostSetsPerPriority()) { for (auto host : host_set->hosts()) { diff --git a/source/common/upstream/load_stats_reporter.h b/source/common/upstream/load_stats_reporter.h index e3d54241892f0..dc23bec2f8e5c 100644 --- a/source/common/upstream/load_stats_reporter.h +++ b/source/common/upstream/load_stats_reporter.h @@ -34,7 +34,7 @@ class LoadStatsReporter public: LoadStatsReporter(const envoy::api::v2::core::Node& node, ClusterManager& cluster_manager, Stats::Scope& scope, Grpc::AsyncClientPtr async_client, - Event::Dispatcher& dispatcher); + Event::Dispatcher& dispatcher, MonotonicTimeSource& time_source); // Grpc::TypedAsyncStreamCallbacks void onCreateInitialMetadata(Http::HeaderMap& metadata) override; @@ -63,7 +63,9 @@ class LoadStatsReporter Event::TimerPtr response_timer_; envoy::service::load_stats::v2::LoadStatsRequest request_; std::unique_ptr message_; - std::vector clusters_; + // Map from cluster name to start of measurement interval. + std::unordered_map clusters_; + MonotonicTimeSource& time_source_; }; typedef std::unique_ptr LoadStatsReporterPtr; diff --git a/source/common/upstream/logical_dns_cluster.cc b/source/common/upstream/logical_dns_cluster.cc index 7245af7e31565..5b3709041b021 100644 --- a/source/common/upstream/logical_dns_cluster.cc +++ b/source/common/upstream/logical_dns_cluster.cc @@ -18,6 +18,7 @@ namespace Upstream { LogicalDnsCluster::LogicalDnsCluster(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, + const LocalInfo::LocalInfo& local_info, Network::DnsResolverSharedPtr dns_resolver, ThreadLocal::SlotAllocator& tls, ClusterManager& cm, Event::Dispatcher& dispatcher, bool added_via_api) @@ -27,10 +28,19 @@ LogicalDnsCluster::LogicalDnsCluster(const envoy::api::v2::Cluster& cluster, dns_refresh_rate_ms_( std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(cluster, dns_refresh_rate, 5000))), tls_(tls.allocateSlot()), - resolve_timer_(dispatcher.createTimer([this]() -> void { startResolve(); })) { - const auto& hosts = cluster.hosts(); - if (hosts.size() != 1) { - throw EnvoyException("logical_dns clusters must have a single host"); + resolve_timer_(dispatcher.createTimer([this]() -> void { startResolve(); })), + local_info_(local_info), + load_assignment_(cluster.has_load_assignment() + ? cluster.load_assignment() + : Config::Utility::translateClusterHosts(cluster.hosts())) { + const auto& locality_lb_endpoints = load_assignment_.endpoints(); + if (locality_lb_endpoints.size() != 1 || locality_lb_endpoints[0].lb_endpoints().size() != 1) { + if (cluster.has_load_assignment()) { + throw EnvoyException( + "LOGICAL_DNS clusters must have a single locality_lb_endpoint and a single lb_endpoint"); + } else { + throw EnvoyException("LOGICAL_DNS clusters must have a single host"); + } } switch (cluster.dns_lookup_family()) { @@ -44,10 +54,11 @@ LogicalDnsCluster::LogicalDnsCluster(const envoy::api::v2::Cluster& cluster, dns_lookup_family_ = Network::DnsLookupFamily::Auto; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } - const auto& socket_address = hosts[0].socket_address(); + const envoy::api::v2::core::SocketAddress& socket_address = + lbEndpoint().endpoint().address().socket_address(); dns_url_ = fmt::format("tcp://{}:{}", socket_address.address(), socket_address.port_value()); hostname_ = Network::Utility::hostFromTcpUrl(dns_url_); Network::Utility::portFromTcpUrl(dns_url_); @@ -88,7 +99,8 @@ void LogicalDnsCluster::startResolve() { current_resolved_address_ = new_address; // Capture URL to avoid a race with another update. tls_->runOnAllThreads([this, new_address]() -> void { - tls_->getTyped().current_resolved_address_ = new_address; + PerThreadCurrentHostData& data = tls_->getTyped(); + data.current_resolved_address_ = new_address; }); } @@ -107,14 +119,16 @@ void LogicalDnsCluster::startResolve() { new LogicalHost(info_, hostname_, Network::Utility::getIpv6AnyAddress(), *this)); break; } - HostVectorSharedPtr new_hosts(new HostVector()); - new_hosts->emplace_back(logical_host_); - // Given the current config, only EDS clusters support multiple priorities. - ASSERT(priority_set_.hostSetsPerPriority().size() == 1); - auto& first_host_set = priority_set_.getOrCreateHostSet(0); - first_host_set.updateHosts(new_hosts, createHealthyHostList(*new_hosts), - HostsPerLocalityImpl::empty(), HostsPerLocalityImpl::empty(), - {}, *new_hosts, {}); + const auto& locality_lb_endpoint = localityLbEndpoint(); + PriorityStateManager priority_state_manager(*this, local_info_); + priority_state_manager.initializePriorityFor(locality_lb_endpoint); + priority_state_manager.registerHostForPriority(logical_host_, locality_lb_endpoint, + lbEndpoint(), absl::nullopt); + + const uint32_t priority = locality_lb_endpoint.priority(); + priority_state_manager.updateClusterPrioritySet( + priority, std::move(priority_state_manager.priorityState()[priority].first), + absl::nullopt, absl::nullopt, absl::nullopt); } } @@ -131,7 +145,8 @@ Upstream::Host::CreateConnectionData LogicalDnsCluster::LogicalHost::createConne return {HostImpl::createConnection(dispatcher, *parent_.info_, data.current_resolved_address_, options), HostDescriptionConstSharedPtr{ - new RealHostDescription(data.current_resolved_address_, shared_from_this())}}; + new RealHostDescription(data.current_resolved_address_, parent_.localityLbEndpoint(), + parent_.lbEndpoint(), shared_from_this())}}; } } // namespace Upstream diff --git a/source/common/upstream/logical_dns_cluster.h b/source/common/upstream/logical_dns_cluster.h index 138285fe24dc2..2feec15579b78 100644 --- a/source/common/upstream/logical_dns_cluster.h +++ b/source/common/upstream/logical_dns_cluster.h @@ -30,6 +30,7 @@ class LogicalDnsCluster : public ClusterImplBase { public: LogicalDnsCluster(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, + const LocalInfo::LocalInfo& local_info, Network::DnsResolverSharedPtr dns_resolver, ThreadLocal::SlotAllocator& tls, ClusterManager& cm, Event::Dispatcher& dispatcher, bool added_via_api); @@ -42,9 +43,10 @@ class LogicalDnsCluster : public ClusterImplBase { struct LogicalHost : public HostImpl { LogicalHost(ClusterInfoConstSharedPtr cluster, const std::string& hostname, Network::Address::InstanceConstSharedPtr address, LogicalDnsCluster& parent) - : HostImpl(cluster, hostname, address, envoy::api::v2::core::Metadata::default_instance(), - 1, envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance()), + : HostImpl(cluster, hostname, address, parent.lbEndpoint().metadata(), + parent.lbEndpoint().load_balancing_weight().value(), + parent.localityLbEndpoint().locality(), + parent.lbEndpoint().endpoint().health_check_config()), parent_(parent) {} // Upstream::Host @@ -57,14 +59,26 @@ class LogicalDnsCluster : public ClusterImplBase { struct RealHostDescription : public HostDescription { RealHostDescription(Network::Address::InstanceConstSharedPtr address, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint, HostConstSharedPtr logical_host) - : address_(address), logical_host_(logical_host) {} + : address_(address), logical_host_(logical_host), + metadata_(std::make_shared(lb_endpoint.metadata())), + health_check_address_( + lb_endpoint.endpoint().health_check_config().port_value() == 0 + ? address + : Network::Utility::getAddressWithPort( + *address, lb_endpoint.endpoint().health_check_config().port_value())), + locality_lb_endpoint_(locality_lb_endpoint), lb_endpoint_(lb_endpoint) {} // Upstream:HostDescription bool canary() const override { return false; } - const envoy::api::v2::core::Metadata& metadata() const override { - return envoy::api::v2::core::Metadata::default_instance(); + void canary(bool) override {} + const std::shared_ptr metadata() const override { + return metadata_; } + void metadata(const envoy::api::v2::core::Metadata&) override {} + const ClusterInfo& cluster() const override { return logical_host_->cluster(); } HealthCheckHostMonitor& healthChecker() const override { return logical_host_->healthChecker(); @@ -76,20 +90,36 @@ class LogicalDnsCluster : public ClusterImplBase { const std::string& hostname() const override { return logical_host_->hostname(); } Network::Address::InstanceConstSharedPtr address() const override { return address_; } const envoy::api::v2::core::Locality& locality() const override { - return envoy::api::v2::core::Locality().default_instance(); + return locality_lb_endpoint_.locality(); } - // TODO(dio): To support different address port. Network::Address::InstanceConstSharedPtr healthCheckAddress() const override { - return address_; + return health_check_address_; } + uint32_t priority() const { return locality_lb_endpoint_.priority(); } Network::Address::InstanceConstSharedPtr address_; HostConstSharedPtr logical_host_; + const std::shared_ptr metadata_; + Network::Address::InstanceConstSharedPtr health_check_address_; + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint_; + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint_; }; struct PerThreadCurrentHostData : public ThreadLocal::ThreadLocalObject { Network::Address::InstanceConstSharedPtr current_resolved_address_; }; + const envoy::api::v2::endpoint::LocalityLbEndpoints& localityLbEndpoint() const { + // This is checked in the constructor, i.e. at config load time. + ASSERT(load_assignment_.endpoints().size() == 1); + return load_assignment_.endpoints()[0]; + } + + const envoy::api::v2::endpoint::LbEndpoint& lbEndpoint() const { + // This is checked in the constructor, i.e. at config load time. + ASSERT(localityLbEndpoint().lb_endpoints().size() == 1); + return localityLbEndpoint().lb_endpoints()[0]; + } + void startResolve(); // ClusterImplBase @@ -105,6 +135,8 @@ class LogicalDnsCluster : public ClusterImplBase { Network::Address::InstanceConstSharedPtr current_resolved_address_; HostSharedPtr logical_host_; Network::ActiveDnsQuery* active_dns_query_{}; + const LocalInfo::LocalInfo& local_info_; + const envoy::api::v2::ClusterLoadAssignment load_assignment_; }; } // namespace Upstream diff --git a/source/common/upstream/outlier_detection_impl.cc b/source/common/upstream/outlier_detection_impl.cc index 7a5e31cf0a8a5..ac8c38d781b25 100644 --- a/source/common/upstream/outlier_detection_impl.cc +++ b/source/common/upstream/outlier_detection_impl.cc @@ -234,7 +234,7 @@ bool DetectorImpl::enforceEjection(EjectionType type) { config_.enforcingSuccessRate()); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void DetectorImpl::updateEnforcedEjectionStats(EjectionType type) { @@ -342,7 +342,7 @@ void DetectorImpl::onConsecutiveErrorWorker(HostSharedPtr host, EjectionType typ host_monitors_[host]->resetConsecutiveGatewayFailure(); break; case EjectionType::SuccessRate: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -539,7 +539,7 @@ std::string EventLoggerImpl::typeToString(EjectionType type) { return "SuccessRate"; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } int EventLoggerImpl::secsSinceLastAction(const absl::optional& lastActionTime, diff --git a/source/common/upstream/ring_hash_lb.cc b/source/common/upstream/ring_hash_lb.cc index 3161ec48cb6c1..e5b5c164e6e8c 100644 --- a/source/common/upstream/ring_hash_lb.cc +++ b/source/common/upstream/ring_hash_lb.cc @@ -103,8 +103,8 @@ RingHashLoadBalancer::Ring::Ring( // new address that is larger, or runs on a platform where UDS is larger. I don't think it's // worth the defensive coding to deal with the heap allocation case (e.g. via // absl::InlinedVector) at the current time. - RELEASE_ASSERT(address_string.size() + 1 + StringUtil::MIN_ITOA_OUT_LEN <= - sizeof(hash_key_buffer)); + RELEASE_ASSERT( + address_string.size() + 1 + StringUtil::MIN_ITOA_OUT_LEN <= sizeof(hash_key_buffer), ""); memcpy(hash_key_buffer, address_string.c_str(), offset_start); hash_key_buffer[offset_start++] = '_'; for (uint64_t i = 0; i < hashes_per_host; i++) { diff --git a/source/common/upstream/sds_subscription.h b/source/common/upstream/sds_subscription.h index 4a6d6c78ef57e..110dae2ebbb9b 100644 --- a/source/common/upstream/sds_subscription.h +++ b/source/common/upstream/sds_subscription.h @@ -42,7 +42,7 @@ class SdsSubscription : public Http::RestApiFetcher, // We should never hit this at runtime, since this legacy adapter is only used by EdsClusterImpl // that doesn't do dynamic modification of resources. UNREFERENCED_PARAMETER(resources); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } // Http::RestApiFetcher diff --git a/source/common/upstream/subset_lb.cc b/source/common/upstream/subset_lb.cc index 698bd472c7460..e15f6250ac3fc 100644 --- a/source/common/upstream/subset_lb.cc +++ b/source/common/upstream/subset_lb.cc @@ -27,22 +27,56 @@ SubsetLoadBalancer::SubsetLoadBalancer( default_subset_metadata_(subsets.defaultSubset().fields().begin(), subsets.defaultSubset().fields().end()), subset_keys_(subsets.subsetKeys()), original_priority_set_(priority_set), - original_local_priority_set_(local_priority_set) { + original_local_priority_set_(local_priority_set), + locality_weight_aware_(subsets.localityWeightAware()) { ASSERT(subsets.isEnabled()); // Create filtered default subset (if necessary) and other subsets based on current hosts. - for (auto& host_set : priority_set.hostSetsPerPriority()) { - update(host_set->priority(), host_set->hosts(), {}); - } + refreshSubsets(); // Configure future updates. original_priority_set_callback_handle_ = priority_set.addMemberUpdateCb( [this](uint32_t priority, const HostVector& hosts_added, const HostVector& hosts_removed) { - update(priority, hosts_added, hosts_removed); + if (!hosts_added.size() && !hosts_removed.size()) { + // It's possible that metadata changed, without hosts being added nor removed. + // If so we need to add any new subsets, remove unused ones, and regroup hosts into + // the right subsets. + // + // Note, note, note: if metadata for existing endpoints changed _and_ hosts were also + // added or removed, we don't need to hit this path. That's fine, given that + // findOrCreateSubset() will be called from processSubsets because it'll be triggered by + // either hosts_added or hosts_removed. That's where the new subsets will be created. + refreshSubsets(priority); + } else { + // This is a regular update with deltas. + update(priority, hosts_added, hosts_removed); + } }); } -SubsetLoadBalancer::~SubsetLoadBalancer() { original_priority_set_callback_handle_->remove(); } +SubsetLoadBalancer::~SubsetLoadBalancer() { + original_priority_set_callback_handle_->remove(); + + // Ensure gauges reflect correct values. + forEachSubset(subsets_, [&](LbSubsetEntryPtr entry) { + if (entry->initialized() && entry->active()) { + stats_.lb_subsets_removed_.inc(); + stats_.lb_subsets_active_.dec(); + } + }); +} + +void SubsetLoadBalancer::refreshSubsets() { + for (auto& host_set : original_priority_set_.hostSetsPerPriority()) { + update(host_set->priority(), host_set->hosts(), {}); + } +} + +void SubsetLoadBalancer::refreshSubsets(uint32_t priority) { + const auto& host_sets = original_priority_set_.hostSetsPerPriority(); + ASSERT(priority < host_sets.size()); + update(priority, host_sets[priority]->hosts(), {}); +} HostConstSharedPtr SubsetLoadBalancer::chooseHost(LoadBalancerContext* context) { if (context) { @@ -148,7 +182,8 @@ void SubsetLoadBalancer::updateFallbackSubset(uint32_t priority, const HostVecto } fallback_subset_.reset(new LbSubsetEntry()); - fallback_subset_->priority_subset_.reset(new PrioritySubsetImpl(*this, predicate)); + fallback_subset_->priority_subset_.reset( + new PrioritySubsetImpl(*this, predicate, locality_weight_aware_)); return; } @@ -240,7 +275,8 @@ void SubsetLoadBalancer::update(uint32_t priority, const HostVector& hosts_added // Initialize new entry with hosts and update stats. (An uninitialized entry // with only removed hosts is a degenerate case and we leave the entry // uninitialized.) - entry->priority_subset_.reset(new PrioritySubsetImpl(*this, predicate)); + entry->priority_subset_.reset( + new PrioritySubsetImpl(*this, predicate, locality_weight_aware_)); stats_.lb_subsets_active_.inc(); stats_.lb_subsets_created_.inc(); } @@ -248,13 +284,24 @@ void SubsetLoadBalancer::update(uint32_t priority, const HostVector& hosts_added } bool SubsetLoadBalancer::hostMatches(const SubsetMetadata& kvs, const Host& host) { - const envoy::api::v2::core::Metadata& host_metadata = host.metadata(); + const envoy::api::v2::core::Metadata& host_metadata = *host.metadata(); + const auto filter_it = + host_metadata.filter_metadata().find(Config::MetadataFilters::get().ENVOY_LB); + + if (filter_it == host_metadata.filter_metadata().end()) { + return kvs.size() == 0; + } + + const ProtobufWkt::Struct& data_struct = filter_it->second; + const auto& fields = data_struct.fields(); for (const auto& kv : kvs) { - const ProtobufWkt::Value& host_value = Config::Metadata::metadataValue( - host_metadata, Config::MetadataFilters::get().ENVOY_LB, kv.first); + const auto entry_it = fields.find(kv.first); + if (entry_it == fields.end()) { + return false; + } - if (!ValueUtil::equal(host_value, kv.second)) { + if (!ValueUtil::equal(entry_it->second, kv.second)) { return false; } } @@ -269,7 +316,7 @@ SubsetLoadBalancer::extractSubsetMetadata(const std::set& subset_ke const Host& host) { SubsetMetadata kvs; - const envoy::api::v2::core::Metadata& metadata = host.metadata(); + const envoy::api::v2::core::Metadata& metadata = *host.metadata(); const auto& filter_it = metadata.filter_metadata().find(Config::MetadataFilters::get().ENVOY_LB); if (filter_it == metadata.filter_metadata().end()) { return kvs; @@ -368,9 +415,10 @@ void SubsetLoadBalancer::forEachSubset(LbSubsetMap& subsets, // Initialize a new HostSubsetImpl and LoadBalancer from the SubsetLoadBalancer, filtering hosts // with the given predicate. SubsetLoadBalancer::PrioritySubsetImpl::PrioritySubsetImpl(const SubsetLoadBalancer& subset_lb, - HostPredicate predicate) + HostPredicate predicate, + bool locality_weight_aware) : PrioritySetImpl(), original_priority_set_(subset_lb.original_priority_set_), - predicate_(predicate) { + predicate_(predicate), locality_weight_aware_(locality_weight_aware) { for (size_t i = 0; i < original_priority_set_.hostSetsPerPriority().size(); ++i) { empty_ &= getOrCreateHostSet(i).hosts().empty(); @@ -421,7 +469,7 @@ SubsetLoadBalancer::PrioritySubsetImpl::PrioritySubsetImpl(const SubsetLoadBalan break; case LoadBalancerType::OriginalDst: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } triggerCallbacks(); @@ -433,9 +481,12 @@ SubsetLoadBalancer::PrioritySubsetImpl::PrioritySubsetImpl(const SubsetLoadBalan void SubsetLoadBalancer::HostSubsetImpl::update(const HostVector& hosts_added, const HostVector& hosts_removed, std::function predicate) { + std::unordered_set predicate_added; + HostVector filtered_added; for (const auto host : hosts_added) { if (predicate(*host)) { + predicate_added.insert(host); filtered_added.emplace_back(host); } } @@ -450,8 +501,12 @@ void SubsetLoadBalancer::HostSubsetImpl::update(const HostVector& hosts_added, HostVectorSharedPtr hosts(new HostVector()); HostVectorSharedPtr healthy_hosts(new HostVector()); + // It's possible that hosts_added == original_host_set_.hosts(), e.g.: when + // calling refreshSubsets() if only metadata change. If so, we can avoid the + // predicate() call. for (const auto host : original_host_set_.hosts()) { - if (predicate(*host)) { + bool host_seen = predicate_added.count(host) == 1; + if (host_seen || predicate(*host)) { hosts->emplace_back(host); if (host->healthy()) { healthy_hosts->emplace_back(host); @@ -459,31 +514,40 @@ void SubsetLoadBalancer::HostSubsetImpl::update(const HostVector& hosts_added, } } - HostsPerLocalityConstSharedPtr hosts_per_locality = - original_host_set_.hostsPerLocality().filter(predicate); + // Calling predicate() is expensive since it involves metadata lookups; so we + // avoid it in the 2nd call to filter() by using the result from the first call + // to filter() as the starting point. + // + // Also, if we only have one locality we can avoid the first call to filter() by + // just creating a new HostsPerLocality from the list of all hosts. + // + // TODO(rgs1): merge these two filter() calls in one loop. + HostsPerLocalityConstSharedPtr hosts_per_locality; + + if (original_host_set_.hostsPerLocality().get().size() == 1) { + hosts_per_locality.reset( + new HostsPerLocalityImpl(*hosts, original_host_set_.hostsPerLocality().hasLocalLocality())); + } else { + hosts_per_locality = original_host_set_.hostsPerLocality().filter(predicate); + } + HostsPerLocalityConstSharedPtr healthy_hosts_per_locality = - original_host_set_.hostsPerLocality().filter( - [&predicate](const Host& host) { return predicate(host) && host.healthy(); }); - - // We pass in an empty list of locality weights here. This effectively disables locality balancing - // for subset LB. - // TODO(htuch): We should consider adding locality awareness here, but we need to do some design - // work first, and this might not even be a desirable thing to do. Consider for example a - // situation in which you have 50/50 split across two localities X/Y which have 100 hosts each - // without subsetting. If the subset LB results in X having only 1 host selected but Y having 100, - // then a lot more load is being dumped on the single host in X than originally anticipated in the - // load balancing assignment delivered via EDS. It might seem you want to further weight by subset - // size in order for this to make sense. However, while the original X/Y weightings can be - // respected this way, those weightings were made by a management server that was not taking into - // consideration subsets (e.g. LRS only reports at locality level). - HostSetImpl::updateHosts(hosts, healthy_hosts, hosts_per_locality, healthy_hosts_per_locality, {}, - filtered_added, filtered_removed); + hosts_per_locality->filter([](const Host& host) { return host.healthy(); }); + + if (locality_weight_aware_) { + HostSetImpl::updateHosts(hosts, healthy_hosts, hosts_per_locality, healthy_hosts_per_locality, + original_host_set_.localityWeights(), filtered_added, + filtered_removed); + } else { + HostSetImpl::updateHosts(hosts, healthy_hosts, hosts_per_locality, healthy_hosts_per_locality, + {}, filtered_added, filtered_removed); + } } HostSetImplPtr SubsetLoadBalancer::PrioritySubsetImpl::createHostSet(uint32_t priority) { - RELEASE_ASSERT(priority < original_priority_set_.hostSetsPerPriority().size()); - return HostSetImplPtr{ - new HostSubsetImpl(*original_priority_set_.hostSetsPerPriority()[priority])}; + RELEASE_ASSERT(priority < original_priority_set_.hostSetsPerPriority().size(), ""); + return HostSetImplPtr{new HostSubsetImpl(*original_priority_set_.hostSetsPerPriority()[priority], + locality_weight_aware_)}; } void SubsetLoadBalancer::PrioritySubsetImpl::update(uint32_t priority, diff --git a/source/common/upstream/subset_lb.h b/source/common/upstream/subset_lb.h index 6bb99b844d3c1..3eef95806f5a8 100644 --- a/source/common/upstream/subset_lb.h +++ b/source/common/upstream/subset_lb.h @@ -38,8 +38,9 @@ class SubsetLoadBalancer : public LoadBalancer, Logger::Loggable new_cluster; // We make this a shared pointer to deal with the distinct ownership @@ -384,17 +382,17 @@ ClusterSharedPtr ClusterImplBase::create(const envoy::api::v2::Cluster& cluster, switch (cluster.type()) { case envoy::api::v2::Cluster::STATIC: - new_cluster.reset( - new StaticClusterImpl(cluster, runtime, stats, ssl_context_manager, cm, added_via_api)); + new_cluster.reset(new StaticClusterImpl(cluster, runtime, stats, ssl_context_manager, + local_info, cm, added_via_api)); break; case envoy::api::v2::Cluster::STRICT_DNS: new_cluster.reset(new StrictDnsClusterImpl(cluster, runtime, stats, ssl_context_manager, - selected_dns_resolver, cm, dispatcher, + local_info, selected_dns_resolver, cm, dispatcher, added_via_api)); break; case envoy::api::v2::Cluster::LOGICAL_DNS: new_cluster.reset(new LogicalDnsCluster(cluster, runtime, stats, ssl_context_manager, - selected_dns_resolver, tls, cm, dispatcher, + local_info, selected_dns_resolver, tls, cm, dispatcher, added_via_api)); break; case envoy::api::v2::Cluster::ORIGINAL_DST: @@ -419,14 +417,17 @@ ClusterSharedPtr ClusterImplBase::create(const envoy::api::v2::Cluster& cluster, cm, dispatcher, random, added_via_api)); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } if (!cluster.health_checks().empty()) { // TODO(htuch): Need to support multiple health checks in v2. - ASSERT(cluster.health_checks().size() == 1); - new_cluster->setHealthChecker(HealthCheckerFactory::create( - cluster.health_checks()[0], *new_cluster, runtime, random, dispatcher)); + if (cluster.health_checks().size() != 1) { + throw EnvoyException("Multiple health checks not supported"); + } else { + new_cluster->setHealthChecker(HealthCheckerFactory::create( + cluster.health_checks()[0], *new_cluster, runtime, random, dispatcher, log_manager)); + } } new_cluster->setOutlierDetector(Outlier::DetectorImplFactory::createForCluster( @@ -622,7 +623,7 @@ ClusterInfoImpl::ResourceManagers::load(const envoy::api::v2::Cluster& config, priority_name = "high"; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } const std::string runtime_prefix = @@ -645,38 +646,165 @@ ClusterInfoImpl::ResourceManagers::load(const envoy::api::v2::Cluster& config, runtime, runtime_prefix, max_connections, max_pending_requests, max_requests, max_retries)}; } +PriorityStateManager::PriorityStateManager(ClusterImplBase& cluster, + const LocalInfo::LocalInfo& local_info) + : parent_(cluster), local_info_node_(local_info.node()) {} + +void PriorityStateManager::initializePriorityFor( + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint) { + const uint32_t priority = locality_lb_endpoint.priority(); + if (priority_state_.size() <= priority) { + priority_state_.resize(priority + 1); + } + if (priority_state_[priority].first == nullptr) { + priority_state_[priority].first.reset(new HostVector()); + } + if (locality_lb_endpoint.has_locality() && locality_lb_endpoint.has_load_balancing_weight()) { + priority_state_[priority].second[locality_lb_endpoint.locality()] = + locality_lb_endpoint.load_balancing_weight().value(); + } +} + +void PriorityStateManager::registerHostForPriority( + const std::string& hostname, Network::Address::InstanceConstSharedPtr address, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint, + const absl::optional health_checker_flag) { + const HostSharedPtr host(new HostImpl(parent_.info(), hostname, address, lb_endpoint.metadata(), + lb_endpoint.load_balancing_weight().value(), + locality_lb_endpoint.locality(), + lb_endpoint.endpoint().health_check_config())); + registerHostForPriority(host, locality_lb_endpoint, lb_endpoint, health_checker_flag); +} + +void PriorityStateManager::registerHostForPriority( + const HostSharedPtr& host, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint, + const absl::optional health_checker_flag) { + const uint32_t priority = locality_lb_endpoint.priority(); + // Should be called after initializePriorityFor. + ASSERT(priority_state_[priority].first); + priority_state_[priority].first->emplace_back(host); + if (health_checker_flag.has_value()) { + const auto& health_status = lb_endpoint.health_status(); + if (health_status == envoy::api::v2::core::HealthStatus::UNHEALTHY || + health_status == envoy::api::v2::core::HealthStatus::DRAINING || + health_status == envoy::api::v2::core::HealthStatus::TIMEOUT) { + priority_state_[priority].first->back()->healthFlagSet(health_checker_flag.value()); + } + } +} + +void PriorityStateManager::updateClusterPrioritySet( + const uint32_t priority, HostVectorSharedPtr&& current_hosts, + const absl::optional& hosts_added, const absl::optional& hosts_removed, + const absl::optional health_checker_flag) { + // If local locality is not defined then skip populating per locality hosts. + const auto& local_locality = local_info_node_.locality(); + ENVOY_LOG(trace, "Local locality: {}", local_locality.DebugString()); + + // For non-EDS, most likely the current hosts are from priority_state_[priority].first. + HostVectorSharedPtr hosts(std::move(current_hosts)); + LocalityWeightsMap empty_locality_map; + LocalityWeightsMap& locality_weights_map = + priority_state_.size() > priority ? priority_state_[priority].second : empty_locality_map; + ASSERT(priority_state_.size() > priority || locality_weights_map.empty()); + LocalityWeightsSharedPtr locality_weights; + std::vector per_locality; + + // If we are configured for locality weighted LB we populate the locality weights. + const bool locality_weighted_lb = parent_.info()->lbConfig().has_locality_weighted_lb_config(); + if (locality_weighted_lb) { + locality_weights = std::make_shared(); + } + + // We use std::map to guarantee a stable ordering for zone aware routing. + std::map hosts_per_locality; + + for (const HostSharedPtr& host : *hosts) { + // Take into consideration when a non-EDS cluster has active health checking, i.e. to mark all + // the hosts unhealthy (host->healthFlagSet(Host::HealthFlag::FAILED_ACTIVE_HC)) and then fire + // update callbacks to start the health checking process. + if (health_checker_flag.has_value()) { + host->healthFlagSet(health_checker_flag.value()); + } + hosts_per_locality[host->locality()].push_back(host); + } + + // Do we have hosts for the local locality? + const bool non_empty_local_locality = + local_info_node_.has_locality() && + hosts_per_locality.find(local_locality) != hosts_per_locality.end(); + + // As per HostsPerLocality::get(), the per_locality vector must have the local locality hosts + // first if non_empty_local_locality. + if (non_empty_local_locality) { + per_locality.emplace_back(hosts_per_locality[local_locality]); + if (locality_weighted_lb) { + locality_weights->emplace_back(locality_weights_map[local_locality]); + } + } + + // After the local locality hosts (if any), we place the remaining locality host groups in + // lexicographic order. This provides a stable ordering for zone aware routing. + for (auto& entry : hosts_per_locality) { + if (!non_empty_local_locality || !LocalityEqualTo()(local_locality, entry.first)) { + per_locality.emplace_back(entry.second); + if (locality_weighted_lb) { + locality_weights->emplace_back(locality_weights_map[entry.first]); + } + } + } + + auto per_locality_shared = + std::make_shared(std::move(per_locality), non_empty_local_locality); + + auto& host_set = + static_cast(parent_.prioritySet()).getOrCreateHostSet(priority); + host_set.updateHosts(hosts, ClusterImplBase::createHealthyHostList(*hosts), per_locality_shared, + ClusterImplBase::createHealthyHostLists(*per_locality_shared), + std::move(locality_weights), hosts_added.value_or(*hosts), + hosts_removed.value_or({})); +} + StaticClusterImpl::StaticClusterImpl(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, - Ssl::ContextManager& ssl_context_manager, ClusterManager& cm, + Ssl::ContextManager& ssl_context_manager, + const LocalInfo::LocalInfo& local_info, ClusterManager& cm, bool added_via_api) : ClusterImplBase(cluster, cm.bindConfig(), runtime, stats, ssl_context_manager, cm.clusterManagerFactory().secretManager(), added_via_api), - initial_hosts_(new HostVector()) { - - for (const auto& host : cluster.hosts()) { - initial_hosts_->emplace_back(HostSharedPtr{new HostImpl( - info_, "", resolveProtoAddress(host), envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())}); + priority_state_manager_(new PriorityStateManager(*this, local_info)) { + // TODO(dio): Use by-reference when cluster.hosts() is removed. + const envoy::api::v2::ClusterLoadAssignment cluster_load_assignment( + cluster.has_load_assignment() ? cluster.load_assignment() + : Config::Utility::translateClusterHosts(cluster.hosts())); + + for (const auto& locality_lb_endpoint : cluster_load_assignment.endpoints()) { + priority_state_manager_->initializePriorityFor(locality_lb_endpoint); + for (const auto& lb_endpoint : locality_lb_endpoint.lb_endpoints()) { + priority_state_manager_->registerHostForPriority( + "", resolveProtoAddress(lb_endpoint.endpoint().address()), locality_lb_endpoint, + lb_endpoint, absl::nullopt); + } } } void StaticClusterImpl::startPreInit() { - // At this point see if we have a health checker. If so, mark all the hosts unhealthy and then - // fire update callbacks to start the health checking process. - if (health_checker_) { - for (const auto& host : *initial_hosts_) { - host->healthFlagSet(Host::HealthFlag::FAILED_ACTIVE_HC); - } + // At this point see if we have a health checker. If so, mark all the hosts unhealthy and + // then fire update callbacks to start the health checking process. + const auto& health_checker_flag = + health_checker_ != nullptr + ? absl::optional(Host::HealthFlag::FAILED_ACTIVE_HC) + : absl::nullopt; + + auto& priority_state = priority_state_manager_->priorityState(); + for (size_t i = 0; i < priority_state.size(); ++i) { + priority_state_manager_->updateClusterPrioritySet( + i, std::move(priority_state[i].first), absl::nullopt, absl::nullopt, health_checker_flag); } - - // Given the current config, only EDS clusters support multiple priorities. - ASSERT(priority_set_.hostSetsPerPriority().size() == 1); - auto& first_host_set = priority_set_.getOrCreateHostSet(0); - first_host_set.updateHosts(initial_hosts_, createHealthyHostList(*initial_hosts_), - HostsPerLocalityImpl::empty(), HostsPerLocalityImpl::empty(), {}, - *initial_hosts_, {}); - initial_hosts_ = nullptr; + priority_state_manager_.reset(); onPreInitComplete(); } @@ -686,16 +814,22 @@ bool BaseDynamicClusterImpl::updateDynamicHostList(const HostVector& new_hosts, HostVector& hosts_added, HostVector& hosts_removed) { uint64_t max_host_weight = 1; + + // Did hosts change? + // // Has the EDS health status changed the health of any endpoint? If so, we // rebuild the hosts vectors. We only do this if the health status of an // endpoint has materially changed (e.g. if previously failing active health // checks, we just note it's now failing EDS health status but don't rebuild). + // + // Likewise, if metadata for an endpoint changed we rebuild the hosts vectors. + // // TODO(htuch): We can be smarter about this potentially, and not force a full // host set update on health status change. The way this would work is to // implement a HealthChecker subclass that provides thread local health - // updates to the Cluster objeect. This will probably make sense to do in + // updates to the Cluster object. This will probably make sense to do in // conjunction with https://github.com/envoyproxy/envoy/issues/2874. - bool health_changed = false; + bool hosts_changed = false; // Go through and see if the list we have is different from what we just got. If it is, we make a // new host list and raise a change notification. This uses an N^2 search given that this does not @@ -727,15 +861,32 @@ bool BaseDynamicClusterImpl::updateDynamicHostList(const HostVector& new_hosts, (*i)->healthFlagSet(Host::HealthFlag::FAILED_EDS_HEALTH); // If the host was previously healthy and we're now unhealthy, we need to // rebuild. - health_changed |= previously_healthy; + hosts_changed |= previously_healthy; } else { (*i)->healthFlagClear(Host::HealthFlag::FAILED_EDS_HEALTH); // If the host was previously unhealthy and now healthy, we need to // rebuild. - health_changed |= !previously_healthy && (*i)->healthy(); + hosts_changed |= !previously_healthy && (*i)->healthy(); } } + // Did metadata change? + const bool metadata_changed = + !Protobuf::util::MessageDifferencer::Equivalent(*host->metadata(), *(*i)->metadata()); + if (metadata_changed) { + // First, update the entire metadata for the endpoint. + (*i)->metadata(*host->metadata()); + + // Also, given that the canary attribute of an endpoint is derived from its metadata + // (e.g.: from envoy.lb/canary), we do a blind update here since it's cheaper than testing + // to see if it actually changed. We must update this besides just updating the metadata, + // because it'll be used by the router filter to compute upstream stats. + (*i)->canary(host->canary()); + + // If metadata changed, we need to rebuild. See github issue #3810. + hosts_changed = true; + } + (*i)->weight(host->weight()); final_hosts.push_back(*i); i = current_hosts.erase(i); @@ -793,23 +944,24 @@ bool BaseDynamicClusterImpl::updateDynamicHostList(const HostVector& new_hosts, // During the search we moved all of the hosts from hosts_ into final_hosts so just // move them back. current_hosts = std::move(final_hosts); - // We return false here in the absence of EDS health status, because we + // We return false here in the absence of EDS health status or metadata changes, because we // have no changes to host vector status (modulo weights). When we have EDS - // health status, we return true, causing updateHosts() to fire in the + // health status or metadata changed, we return true, causing updateHosts() to fire in the // caller. - return health_changed; + return hosts_changed; } } StrictDnsClusterImpl::StrictDnsClusterImpl(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, + const LocalInfo::LocalInfo& local_info, Network::DnsResolverSharedPtr dns_resolver, ClusterManager& cm, Event::Dispatcher& dispatcher, bool added_via_api) : BaseDynamicClusterImpl(cluster, cm.bindConfig(), runtime, stats, ssl_context_manager, cm.clusterManagerFactory().secretManager(), added_via_api), - dns_resolver_(dns_resolver), + local_info_(local_info), dns_resolver_(dns_resolver), dns_refresh_rate_ms_( std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(cluster, dns_refresh_rate, 5000))) { switch (cluster.dns_lookup_family()) { @@ -823,14 +975,21 @@ StrictDnsClusterImpl::StrictDnsClusterImpl(const envoy::api::v2::Cluster& cluste dns_lookup_family_ = Network::DnsLookupFamily::Auto; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } - for (const auto& host : cluster.hosts()) { - resolve_targets_.emplace_back( - new ResolveTarget(*this, dispatcher, - fmt::format("tcp://{}:{}", host.socket_address().address(), - host.socket_address().port_value()))); + const envoy::api::v2::ClusterLoadAssignment load_assignment( + cluster.has_load_assignment() ? cluster.load_assignment() + : Config::Utility::translateClusterHosts(cluster.hosts())); + const auto& locality_lb_endpoints = load_assignment.endpoints(); + for (const auto& locality_lb_endpoint : locality_lb_endpoints) { + for (const auto& lb_endpoint : locality_lb_endpoint.lb_endpoints()) { + const auto& host = lb_endpoint.endpoint().address(); + const std::string& url = fmt::format("tcp://{}:{}", host.socket_address().address(), + host.socket_address().port_value()); + resolve_targets_.emplace_back( + new ResolveTarget(*this, dispatcher, url, locality_lb_endpoint, lb_endpoint)); + } } } @@ -841,29 +1000,34 @@ void StrictDnsClusterImpl::startPreInit() { } void StrictDnsClusterImpl::updateAllHosts(const HostVector& hosts_added, - const HostVector& hosts_removed) { + const HostVector& hosts_removed, + uint32_t current_priority) { + PriorityStateManager priority_state_manager(*this, local_info_); // At this point we know that we are different so make a new host list and notify. - HostVectorSharedPtr new_hosts(new HostVector()); for (const ResolveTargetPtr& target : resolve_targets_) { + priority_state_manager.initializePriorityFor(target->locality_lb_endpoint_); for (const HostSharedPtr& host : target->hosts_) { - new_hosts->emplace_back(host); + if (target->locality_lb_endpoint_.priority() == current_priority) { + priority_state_manager.registerHostForPriority(host, target->locality_lb_endpoint_, + target->lb_endpoint_, absl::nullopt); + } } } - // Given the current config, only EDS clusters support multiple priorities. - ASSERT(priority_set_.hostSetsPerPriority().size() == 1); - auto& first_host_set = priority_set_.getOrCreateHostSet(0); - first_host_set.updateHosts(new_hosts, createHealthyHostList(*new_hosts), - HostsPerLocalityImpl::empty(), HostsPerLocalityImpl::empty(), {}, - hosts_added, hosts_removed); + // TODO(dio): Add assertion in here. + priority_state_manager.updateClusterPrioritySet( + current_priority, std::move(priority_state_manager.priorityState()[current_priority].first), + hosts_added, hosts_removed, absl::nullopt); } -StrictDnsClusterImpl::ResolveTarget::ResolveTarget(StrictDnsClusterImpl& parent, - Event::Dispatcher& dispatcher, - const std::string& url) +StrictDnsClusterImpl::ResolveTarget::ResolveTarget( + StrictDnsClusterImpl& parent, Event::Dispatcher& dispatcher, const std::string& url, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint) : parent_(parent), dns_address_(Network::Utility::hostFromTcpUrl(url)), port_(Network::Utility::portFromTcpUrl(url)), - resolve_timer_(dispatcher.createTimer([this]() -> void { startResolve(); })) {} + resolve_timer_(dispatcher.createTimer([this]() -> void { startResolve(); })), + locality_lb_endpoint_(locality_lb_endpoint), lb_endpoint_(lb_endpoint) {} StrictDnsClusterImpl::ResolveTarget::~ResolveTarget() { if (active_query_) { @@ -884,28 +1048,28 @@ void StrictDnsClusterImpl::ResolveTarget::startResolve() { HostVector new_hosts; for (const Network::Address::InstanceConstSharedPtr& address : address_list) { - // TODO(mattklein123): Currently the DNS interface does not consider port. We need to make - // a new address that has port in it. We need to both support IPv6 as well as potentially - // move port handling into the DNS interface itself, which would work better for SRV. + // TODO(mattklein123): Currently the DNS interface does not consider port. We need to + // make a new address that has port in it. We need to both support IPv6 as well as + // potentially move port handling into the DNS interface itself, which would work better + // for SRV. ASSERT(address != nullptr); new_hosts.emplace_back(new HostImpl( parent_.info_, dns_address_, Network::Utility::getAddressWithPort(*address, port_), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); + lb_endpoint_.metadata(), lb_endpoint_.load_balancing_weight().value(), + locality_lb_endpoint_.locality(), lb_endpoint_.endpoint().health_check_config())); } HostVector hosts_added; HostVector hosts_removed; if (parent_.updateDynamicHostList(new_hosts, hosts_, hosts_added, hosts_removed)) { ENVOY_LOG(debug, "DNS hosts have changed for {}", dns_address_); - parent_.updateAllHosts(hosts_added, hosts_removed); + parent_.updateAllHosts(hosts_added, hosts_removed, locality_lb_endpoint_.priority()); } // If there is an initialize callback, fire it now. Note that if the cluster refers to - // multiple DNS names, this will return initialized after a single DNS resolution completes. - // This is not perfect but is easier to code and unclear if the extra complexity is needed - // so will start with this. + // multiple DNS names, this will return initialized after a single DNS resolution + // completes. This is not perfect but is easier to code and unclear if the extra + // complexity is needed so will start with this. parent_.onPreInitComplete(); resolve_timer_->enableTimer(parent_.dns_refresh_rate_ms_); }); diff --git a/source/common/upstream/upstream_impl.h b/source/common/upstream/upstream_impl.h index 6540b8ac4ca8a..90b7935d55d44 100644 --- a/source/common/upstream/upstream_impl.h +++ b/source/common/upstream/upstream_impl.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -69,13 +70,32 @@ class HostDescriptionImpl : virtual public HostDescription { canary_(Config::Metadata::metadataValue(metadata, Config::MetadataFilters::get().ENVOY_LB, Config::MetadataEnvoyLbKeys::get().CANARY) .bool_value()), - metadata_(metadata), locality_(locality), stats_{ALL_HOST_STATS(POOL_COUNTER(stats_store_), - POOL_GAUGE(stats_store_))} { - } + metadata_(std::make_shared(metadata)), + locality_(locality), stats_{ALL_HOST_STATS(POOL_COUNTER(stats_store_), + POOL_GAUGE(stats_store_))} {} // Upstream::HostDescription bool canary() const override { return canary_; } - const envoy::api::v2::core::Metadata& metadata() const override { return metadata_; } + void canary(bool is_canary) override { canary_ = is_canary; } + + // Metadata getter/setter. + // + // It's possible that the lock that guards the metadata will become highly contended (e.g.: + // endpoints churning during a deploy of a large cluster). A possible improvement + // would be to use TLS and post metadata updates from the main thread. This model would + // possibly benefit other related and expensive computations too (e.g.: updating subsets). + // + // TODO(rgs1): we should move to absl locks, once there's support for R/W locks. We should + // also add lock annotations, once they work correctly with R/W locks. + const std::shared_ptr metadata() const override { + std::shared_lock lock(metadata_mutex_); + return metadata_; + } + virtual void metadata(const envoy::api::v2::core::Metadata& new_metadata) override { + std::unique_lock lock(metadata_mutex_); + metadata_ = std::make_shared(new_metadata); + } + const ClusterInfo& cluster() const override { return *cluster_; } HealthCheckHostMonitor& healthChecker() const override { if (health_checker_) { @@ -108,8 +128,9 @@ class HostDescriptionImpl : virtual public HostDescription { const std::string hostname_; Network::Address::InstanceConstSharedPtr address_; Network::Address::InstanceConstSharedPtr health_check_address_; - const bool canary_; - const envoy::api::v2::core::Metadata metadata_; + std::atomic canary_; + mutable std::shared_timed_mutex metadata_mutex_; + std::shared_ptr metadata_; const envoy::api::v2::core::Locality locality_; Stats::IsolatedStoreImpl stats_store_; HostStats stats_; @@ -413,14 +434,13 @@ class ClusterInfoImpl : public ClusterInfo, class ClusterImplBase : public Cluster, protected Logger::Loggable { public: - static ClusterSharedPtr create(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, - Stats::Store& stats, ThreadLocal::Instance& tls, - Network::DnsResolverSharedPtr dns_resolver, - Ssl::ContextManager& ssl_context_manager, Runtime::Loader& runtime, - Runtime::RandomGenerator& random, Event::Dispatcher& dispatcher, - const LocalInfo::LocalInfo& local_info, - Outlier::EventLoggerSharedPtr outlier_event_logger, - bool added_via_api); + static ClusterSharedPtr + create(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Stats::Store& stats, + ThreadLocal::Instance& tls, Network::DnsResolverSharedPtr dns_resolver, + Ssl::ContextManager& ssl_context_manager, Runtime::Loader& runtime, + Runtime::RandomGenerator& random, Event::Dispatcher& dispatcher, + AccessLog::AccessLogManager& log_manager, const LocalInfo::LocalInfo& local_info, + Outlier::EventLoggerSharedPtr outlier_event_logger, bool added_via_api); // From Upstream::Cluster virtual PrioritySet& prioritySet() override { return priority_set_; } virtual const PrioritySet& prioritySet() const override { return priority_set_; } @@ -447,6 +467,9 @@ class ClusterImplBase : public Cluster, protected Logger::Loggable HostListPtr; +typedef std::unordered_map + LocalityWeightsMap; +typedef std::vector> PriorityState; + +/** + * Manages PriorityState of a cluster. PriorityState is a per-priority binding of a set of hosts + * with its corresponding locality weight map. This is useful to store priorities/hosts/localities + * before updating the cluster priority set. + */ +class PriorityStateManager : protected Logger::Loggable { +public: + PriorityStateManager(ClusterImplBase& cluster, const LocalInfo::LocalInfo& local_info); + + // Initializes the PriorityState vector based on the priority specified in locality_lb_endpoint. + void + initializePriorityFor(const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint); + + // Registers a host based on its address to the PriorityState based on the specified priority (the + // priority is specified by locality_lb_endpoint.priority()). + // + // The specified health_checker_flag is used to set the registered-host's health-flag when the + // lb_endpoint health status is unhealty, draining or timeout. + void + registerHostForPriority(const std::string& hostname, + Network::Address::InstanceConstSharedPtr address, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint, + const absl::optional health_checker_flag); + + void + registerHostForPriority(const HostSharedPtr& host, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint, + const absl::optional health_checker_flag); + + void + updateClusterPrioritySet(const uint32_t priority, HostVectorSharedPtr&& current_hosts, + const absl::optional& hosts_added, + const absl::optional& hosts_removed, + const absl::optional health_checker_flag); + + // Returns the size of the current cluster priority state. + size_t size() const { return priority_state_.size(); } + + // Returns the saved priority state. + PriorityState& priorityState() { return priority_state_; } + +private: + ClusterImplBase& parent_; + PriorityState priority_state_; + const envoy::api::v2::core::Node& local_info_node_; +}; + +typedef std::unique_ptr PriorityStateManagerPtr; + /** * Implementation of Upstream::Cluster for static clusters (clusters that have a fixed number of * hosts with resolved IP addresses). @@ -502,7 +577,7 @@ class StaticClusterImpl : public ClusterImplBase { public: StaticClusterImpl(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, - ClusterManager& cm, bool added_via_api); + const LocalInfo::LocalInfo& local_info, ClusterManager& cm, bool added_via_api); // Upstream::Cluster InitializePhase initializePhase() const override { return InitializePhase::Primary; } @@ -511,7 +586,7 @@ class StaticClusterImpl : public ClusterImplBase { // ClusterImplBase void startPreInit() override; - HostVectorSharedPtr initial_hosts_; + PriorityStateManagerPtr priority_state_manager_; }; /** @@ -533,6 +608,7 @@ class StrictDnsClusterImpl : public BaseDynamicClusterImpl { public: StrictDnsClusterImpl(const envoy::api::v2::Cluster& cluster, Runtime::Loader& runtime, Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, + const LocalInfo::LocalInfo& local_info, Network::DnsResolverSharedPtr dns_resolver, ClusterManager& cm, Event::Dispatcher& dispatcher, bool added_via_api); @@ -542,7 +618,9 @@ class StrictDnsClusterImpl : public BaseDynamicClusterImpl { private: struct ResolveTarget { ResolveTarget(StrictDnsClusterImpl& parent, Event::Dispatcher& dispatcher, - const std::string& url); + const std::string& url, + const envoy::api::v2::endpoint::LocalityLbEndpoints& locality_lb_endpoint, + const envoy::api::v2::endpoint::LbEndpoint& lb_endpoint); ~ResolveTarget(); void startResolve(); @@ -552,15 +630,19 @@ class StrictDnsClusterImpl : public BaseDynamicClusterImpl { uint32_t port_; Event::TimerPtr resolve_timer_; HostVector hosts_; + const envoy::api::v2::endpoint::LocalityLbEndpoints locality_lb_endpoint_; + const envoy::api::v2::endpoint::LbEndpoint lb_endpoint_; }; typedef std::unique_ptr ResolveTargetPtr; - void updateAllHosts(const HostVector& hosts_added, const HostVector& hosts_removed); + void updateAllHosts(const HostVector& hosts_added, const HostVector& hosts_removed, + uint32_t priority); // ClusterImplBase void startPreInit() override; + const LocalInfo::LocalInfo& local_info_; Network::DnsResolverSharedPtr dns_resolver_; std::list resolve_targets_; const std::chrono::milliseconds dns_refresh_rate_ms_; diff --git a/source/exe/BUILD b/source/exe/BUILD index b9fa77f1663a3..8d1e932de1d10 100644 --- a/source/exe/BUILD +++ b/source/exe/BUILD @@ -9,6 +9,7 @@ load( load( "//source/extensions:all_extensions.bzl", "envoy_all_extensions", + "envoy_windows_extensions", ) envoy_package() @@ -37,12 +38,18 @@ envoy_cc_library( "//source/server:options_lib", "//source/server:server_lib", "//source/server:test_hooks_lib", - ] + envoy_all_extensions(), + ] + select({ + "//bazel:windows_x86_64": envoy_windows_extensions(), + "//conditions:default": envoy_all_extensions(), + }), ) envoy_cc_library( name = "envoy_main_entry_lib", srcs = ["main.cc"], + external_deps = [ + "abseil_symbolize", + ], deps = [ ":envoy_main_common_lib", ], diff --git a/source/exe/main.cc b/source/exe/main.cc index e1cb0ed4a651f..fae6d4b6585ef 100644 --- a/source/exe/main.cc +++ b/source/exe/main.cc @@ -1,5 +1,7 @@ #include "exe/main_common.h" +#include "absl/debugging/symbolize.h" + // NOLINT(namespace-envoy) /** @@ -10,6 +12,11 @@ * after setting up command line options. */ int main(int argc, char** argv) { +#ifndef __APPLE__ + // absl::Symbolize mostly works without this, but this improves corner case + // handling, such as running in a chroot jail. + absl::InitializeSymbolizer(argv[0]); +#endif std::unique_ptr main_common; // Initialize the server's main context under a try/catch loop and simply return EXIT_FAILURE diff --git a/source/exe/main_common.cc b/source/exe/main_common.cc index 04e6d471a3f46..862c0b95c7404 100644 --- a/source/exe/main_common.cc +++ b/source/exe/main_common.cc @@ -41,9 +41,8 @@ Runtime::LoaderPtr ProdComponentFactory::createRuntime(Server::Instance& server, MainCommonBase::MainCommonBase(OptionsImpl& options) : options_(options) { ares_library_init(ARES_LIB_INIT_ALL); Event::Libevent::Global::initialize(); - RELEASE_ASSERT(Envoy::Server::validateProtoDescriptors()); + RELEASE_ASSERT(Envoy::Server::validateProtoDescriptors(), ""); - Stats::RawStatData::configure(options_); switch (options_.mode()) { case Server::Mode::InitOnly: case Server::Mode::Serve: { @@ -62,7 +61,8 @@ MainCommonBase::MainCommonBase(OptionsImpl& options) : options_(options) { auto local_address = Network::Utility::getLocalAddress(options_.localAddressIpVersion()); Logger::Registry::initialize(options_.logLevel(), options_.logFormat(), log_lock); - stats_store_.reset(new Stats::ThreadLocalStoreImpl(restarter_->statsAllocator())); + stats_store_ = std::make_unique(options_.statsOptions(), + restarter_->statsAllocator()); server_.reset(new Server::InstanceImpl( options_, local_address, default_test_hooks_, *restarter_, *stats_store_, access_log_lock, component_factory_, std::make_unique(), *tls_)); @@ -90,7 +90,7 @@ bool MainCommonBase::run() { PERF_DUMP(); return true; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } MainCommon::MainCommon(int argc, const char* const* argv) diff --git a/source/exe/signal_action.cc b/source/exe/signal_action.cc index ac6211f280654..bcd297e141448 100644 --- a/source/exe/signal_action.cc +++ b/source/exe/signal_action.cc @@ -46,7 +46,7 @@ void SignalAction::installSigHandlers() { stack.ss_size = altstack_size_; // ... guard page at the other stack.ss_flags = 0; - RELEASE_ASSERT(sigaltstack(&stack, &previous_altstack_) == 0); + RELEASE_ASSERT(sigaltstack(&stack, &previous_altstack_) == 0, ""); int hidx = 0; for (const auto& sig : FATAL_SIGS) { @@ -56,7 +56,7 @@ void SignalAction::installSigHandlers() { saction.sa_flags = (SA_SIGINFO | SA_ONSTACK | SA_RESETHAND | SA_NODEFER); saction.sa_sigaction = sigHandler; auto* handler = &previous_handlers_[hidx++]; - RELEASE_ASSERT(sigaction(sig, &saction, handler) == 0); + RELEASE_ASSERT(sigaction(sig, &saction, handler) == 0, ""); } } @@ -67,12 +67,12 @@ void SignalAction::removeSigHandlers() { previous_altstack_.ss_size = MINSIGSTKSZ; } #endif - RELEASE_ASSERT(sigaltstack(&previous_altstack_, nullptr) == 0); + RELEASE_ASSERT(sigaltstack(&previous_altstack_, nullptr) == 0, ""); int hidx = 0; for (const auto& sig : FATAL_SIGS) { auto* handler = &previous_handlers_[hidx++]; - RELEASE_ASSERT(sigaction(sig, handler, nullptr) == 0); + RELEASE_ASSERT(sigaction(sig, handler, nullptr) == 0, ""); } } @@ -85,9 +85,10 @@ void SignalAction::mapAndProtectStackMemory() { // library hint that might be used in the future. altstack_ = static_cast(mmap(nullptr, mapSizeWithGuards(), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0)); - RELEASE_ASSERT(altstack_); - RELEASE_ASSERT(mprotect(altstack_, guard_size_, PROT_NONE) == 0); - RELEASE_ASSERT(mprotect(altstack_ + guard_size_ + altstack_size_, guard_size_, PROT_NONE) == 0); + RELEASE_ASSERT(altstack_, ""); + RELEASE_ASSERT(mprotect(altstack_, guard_size_, PROT_NONE) == 0, ""); + RELEASE_ASSERT(mprotect(altstack_ + guard_size_ + altstack_size_, guard_size_, PROT_NONE) == 0, + ""); } void SignalAction::unmapStackMemory() { munmap(altstack_, mapSizeWithGuards()); } diff --git a/source/extensions/access_loggers/file/BUILD b/source/extensions/access_loggers/file/BUILD index 124d5f838e5e5..e83a33f02a9f7 100644 --- a/source/extensions/access_loggers/file/BUILD +++ b/source/extensions/access_loggers/file/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Access log implementation that writes to a file. # Public docs: docs/root/configuration/access_log.rst diff --git a/source/extensions/access_loggers/file/config.cc b/source/extensions/access_loggers/file/config.cc index 8494f39db931f..8490a72eccb04 100644 --- a/source/extensions/access_loggers/file/config.cc +++ b/source/extensions/access_loggers/file/config.cc @@ -35,7 +35,7 @@ ProtobufTypes::MessagePtr FileAccessLogFactory::createEmptyConfigProto() { return ProtobufTypes::MessagePtr{new envoy::config::accesslog::v2::FileAccessLog()}; } -std::string FileAccessLogFactory::name() const { return AccessLogNames::get().FILE; } +std::string FileAccessLogFactory::name() const { return AccessLogNames::get().File; } /** * Static registration for the file access log. @see RegisterFactory. diff --git a/source/extensions/access_loggers/http_grpc/BUILD b/source/extensions/access_loggers/http_grpc/BUILD index 8839d961083e3..8b372ef7b8648 100644 --- a/source/extensions/access_loggers/http_grpc/BUILD +++ b/source/extensions/access_loggers/http_grpc/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Access log implementation that writes to a gRPC service. # Public docs: TODO(rodaine): Docs needed. diff --git a/source/extensions/access_loggers/http_grpc/config.cc b/source/extensions/access_loggers/http_grpc/config.cc index ec24ebaea4072..4953c2bd3d596 100644 --- a/source/extensions/access_loggers/http_grpc/config.cc +++ b/source/extensions/access_loggers/http_grpc/config.cc @@ -44,7 +44,7 @@ ProtobufTypes::MessagePtr HttpGrpcAccessLogFactory::createEmptyConfigProto() { return ProtobufTypes::MessagePtr{new envoy::config::accesslog::v2::HttpGrpcAccessLogConfig()}; } -std::string HttpGrpcAccessLogFactory::name() const { return AccessLogNames::get().HTTP_GRPC; } +std::string HttpGrpcAccessLogFactory::name() const { return AccessLogNames::get().HttpGrpc; } /** * Static registration for the HTTP gRPC access log. @see RegisterFactory. diff --git a/source/extensions/access_loggers/http_grpc/grpc_access_log_impl.cc b/source/extensions/access_loggers/http_grpc/grpc_access_log_impl.cc index b90bbfc352dbf..6804cdd9bd617 100644 --- a/source/extensions/access_loggers/http_grpc/grpc_access_log_impl.cc +++ b/source/extensions/access_loggers/http_grpc/grpc_access_log_impl.cc @@ -88,55 +88,55 @@ void HttpGrpcAccessLog::responseFlagsToAccessLogResponseFlags( static_assert(RequestInfo::ResponseFlag::LastFlag == 0x1000, "A flag has been added. Fix this code."); - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::FailedLocalHealthCheck)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::FailedLocalHealthCheck)) { common_access_log.mutable_response_flags()->set_failed_local_healthcheck(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::NoHealthyUpstream)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::NoHealthyUpstream)) { common_access_log.mutable_response_flags()->set_no_healthy_upstream(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UpstreamRequestTimeout)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UpstreamRequestTimeout)) { common_access_log.mutable_response_flags()->set_upstream_request_timeout(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::LocalReset)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::LocalReset)) { common_access_log.mutable_response_flags()->set_local_reset(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UpstreamRemoteReset)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UpstreamRemoteReset)) { common_access_log.mutable_response_flags()->set_upstream_remote_reset(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionFailure)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionFailure)) { common_access_log.mutable_response_flags()->set_upstream_connection_failure(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionTermination)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UpstreamConnectionTermination)) { common_access_log.mutable_response_flags()->set_upstream_connection_termination(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UpstreamOverflow)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UpstreamOverflow)) { common_access_log.mutable_response_flags()->set_upstream_overflow(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::NoRouteFound)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::NoRouteFound)) { common_access_log.mutable_response_flags()->set_no_route_found(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::DelayInjected)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::DelayInjected)) { common_access_log.mutable_response_flags()->set_delay_injected(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::FaultInjected)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::FaultInjected)) { common_access_log.mutable_response_flags()->set_fault_injected(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::RateLimited)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::RateLimited)) { common_access_log.mutable_response_flags()->set_rate_limited(true); } - if (request_info.getResponseFlag(RequestInfo::ResponseFlag::UnauthorizedExternalService)) { + if (request_info.hasResponseFlag(RequestInfo::ResponseFlag::UnauthorizedExternalService)) { common_access_log.mutable_response_flags()->mutable_unauthorized_details()->set_reason( envoy::data::accesslog::v2::ResponseFlags_Unauthorized_Reason:: ResponseFlags_Unauthorized_Reason_EXTERNAL_SERVICE); diff --git a/source/extensions/access_loggers/well_known_names.h b/source/extensions/access_loggers/well_known_names.h index 1b1c22fa758a6..737a4b9c04d38 100644 --- a/source/extensions/access_loggers/well_known_names.h +++ b/source/extensions/access_loggers/well_known_names.h @@ -13,9 +13,9 @@ namespace AccessLoggers { class AccessLogNameValues { public: // File access log - const std::string FILE = "envoy.file_access_log"; + const std::string File = "envoy.file_access_log"; // HTTP gRPC access log - const std::string HTTP_GRPC = "envoy.http_grpc_access_log"; + const std::string HttpGrpc = "envoy.http_grpc_access_log"; }; typedef ConstSingleton AccessLogNames; diff --git a/source/extensions/all_extensions.bzl b/source/extensions/all_extensions.bzl index 6649a0cfdd526..1dc1a34aad6bc 100644 --- a/source/extensions/all_extensions.bzl +++ b/source/extensions/all_extensions.bzl @@ -1,16 +1,30 @@ -load("@envoy_build_config//:extensions_build_config.bzl", "EXTENSIONS") +load("@envoy_build_config//:extensions_build_config.bzl", "EXTENSIONS", "WINDOWS_EXTENSIONS") # Return all extensions to be compiled into Envoy. def envoy_all_extensions(): - # These extensions are registered using the extension system but are required for the core - # Envoy build. - all_extensions = [ - "//source/extensions/transport_sockets/raw_buffer:config", - "//source/extensions/transport_sockets/ssl:config", - ] + # These extensions are registered using the extension system but are required for the core + # Envoy build. + all_extensions = [ + "//source/extensions/transport_sockets/raw_buffer:config", + "//source/extensions/transport_sockets/ssl:config", + ] - # These extensions can be removed on a site specific basis. - for path in EXTENSIONS.values(): - all_extensions.append(path) + # These extensions can be removed on a site specific basis. + for path in EXTENSIONS.values(): + all_extensions.append(path) - return all_extensions + return all_extensions + +def envoy_windows_extensions(): + # These extensions are registered using the extension system but are required for the core + # Envoy build. + windows_extensions = [ + "//source/extensions/transport_sockets/raw_buffer:config", + "//source/extensions/transport_sockets/ssl:config", + ] + + # These extensions can be removed on a site specific basis. + for path in WINDOWS_EXTENSIONS.values(): + windows_extensions.append(path) + + return windows_extensions diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index 75383a02a578a..736f8c319c72c 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -70,14 +70,27 @@ EXTENSIONS = { "envoy.filters.network.tcp_proxy": "//source/extensions/filters/network/tcp_proxy:config", "envoy.filters.network.thrift_proxy": "//source/extensions/filters/network/thrift_proxy:config", + # + # Resource monitors + # + + "envoy.resource_monitors.fixed_heap": "//source/extensions/resource_monitors/fixed_heap:config", + # # Stat sinks # "envoy.stat_sinks.dog_statsd": "//source/extensions/stat_sinks/dog_statsd:config", + "envoy.stat_sinks.hystrix": "//source/extensions/stat_sinks/hystrix:config", "envoy.stat_sinks.metrics_service": "//source/extensions/stat_sinks/metrics_service:config", "envoy.stat_sinks.statsd": "//source/extensions/stat_sinks/statsd:config", + # + # Thrift filters + # + + "envoy.filters.thrift.router": "//source/extensions/filters/network/thrift_proxy/router:config", + # # Tracers # @@ -94,3 +107,96 @@ EXTENSIONS = { "envoy.transport_sockets.alts": "//source/extensions/transport_sockets/alts:tsi_handshaker", "envoy.transport_sockets.capture": "//source/extensions/transport_sockets/capture:config", } + +WINDOWS_EXTENSIONS = { + # + # Access loggers + # + + "envoy.access_loggers.file": "//source/extensions/access_loggers/file:config", + #"envoy.access_loggers.http_grpc": "//source/extensions/access_loggers/http_grpc:config", + + # + # gRPC Credentials Plugins + # + + #"envoy.grpc_credentials.file_based_metadata": "//source/extensions/grpc_credentials/file_based_metadata:config", + + # + # Health checkers + # + + #"envoy.health_checkers.redis": "//source/extensions/health_checkers/redis:config", + + # + # HTTP filters + # + + #"envoy.filters.http.buffer": "//source/extensions/filters/http/buffer:config", + #"envoy.filters.http.cors": "//source/extensions/filters/http/cors:config", + #"envoy.filters.http.dynamo": "//source/extensions/filters/http/dynamo:config", + #"envoy.filters.http.ext_authz": "//source/extensions/filters/http/ext_authz:config", + #"envoy.filters.http.fault": "//source/extensions/filters/http/fault:config", + #"envoy.filters.http.grpc_http1_bridge": "//source/extensions/filters/http/grpc_http1_bridge:config", + #"envoy.filters.http.grpc_json_transcoder": "//source/extensions/filters/http/grpc_json_transcoder:config", + #"envoy.filters.http.grpc_web": "//source/extensions/filters/http/grpc_web:config", + #"envoy.filters.http.gzip": "//source/extensions/filters/http/gzip:config", + #"envoy.filters.http.health_check": "//source/extensions/filters/http/health_check:config", + #"envoy.filters.http.ip_tagging": "//source/extensions/filters/http/ip_tagging:config", + #"envoy.filters.http.lua": "//source/extensions/filters/http/lua:config", + #"envoy.filters.http.ratelimit": "//source/extensions/filters/http/ratelimit:config", + #"envoy.filters.http.rbac": "//source/extensions/filters/http/rbac:config", + #"envoy.filters.http.router": "//source/extensions/filters/http/router:config", + #"envoy.filters.http.squash": "//source/extensions/filters/http/squash:config", + + # + # Listener filters + # + + # NOTE: The proxy_protocol filter is implicitly loaded if proxy_protocol functionality is + # configured on the listener. Do not remove it in that case or configs will fail to load. + "envoy.filters.listener.proxy_protocol": "//source/extensions/filters/listener/proxy_protocol:config", + + # NOTE: The original_dst filter is implicitly loaded if original_dst functionality is + # configured on the listener. Do not remove it in that case or configs will fail to load. + #"envoy.filters.listener.original_dst": "//source/extensions/filters/listener/original_dst:config", + + "envoy.filters.listener.tls_inspector": "//source/extensions/filters/listener/tls_inspector:config", + + # + # Network filters + # + + "envoy.filters.network.client_ssl_auth": "//source/extensions/filters/network/client_ssl_auth:config", + #"envoy.filters.network.echo": "//source/extensions/filters/network/echo:config", + #"envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config", + #"envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config", + #"envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config", + #"envoy.filters.network.redis_proxy": "//source/extensions/filters/network/redis_proxy:config", + #"envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config", + "envoy.filters.network.tcp_proxy": "//source/extensions/filters/network/tcp_proxy:config", + # TODO(zuercher): switch to config target once a filter exists + #"envoy.filters.network.thrift_proxy": "//source/extensions/filters/network/thrift_proxy:transport_lib", + + # + # Stat sinks + # + + #"envoy.stat_sinks.dog_statsd": "//source/extensions/stat_sinks/dog_statsd:config", + #"envoy.stat_sinks.metrics_service": "//source/extensions/stat_sinks/metrics_service:config", + #"envoy.stat_sinks.statsd": "//source/extensions/stat_sinks/statsd:config", + + # + # Tracers + # + + #"envoy.tracers.dynamic_ot": "//source/extensions/tracers/dynamic_ot:config", + #"envoy.tracers.lightstep": "//source/extensions/tracers/lightstep:config", + #"envoy.tracers.zipkin": "//source/extensions/tracers/zipkin:config", + + # + # Transport sockets + # + + #"envoy.transport_sockets.capture": "//source/extensions/transport_sockets/capture:config", +} diff --git a/source/extensions/filters/common/ext_authz/BUILD b/source/extensions/filters/common/ext_authz/BUILD index c43ff4924e8b6..d954a8acd6b97 100644 --- a/source/extensions/filters/common/ext_authz/BUILD +++ b/source/extensions/filters/common/ext_authz/BUILD @@ -12,16 +12,18 @@ envoy_cc_library( name = "ext_authz_interface", hdrs = ["ext_authz.h"], deps = [ - "//include/envoy/tracing:http_tracer_interface", + "//include/envoy/http:codes_interface", + "//source/common/tracing:http_tracer_lib", "@envoy_api//envoy/service/auth/v2alpha:external_auth_cc", ], ) envoy_cc_library( - name = "ext_authz_lib", - srcs = ["ext_authz_impl.cc"], - hdrs = ["ext_authz_impl.h"], + name = "ext_authz_grpc_lib", + srcs = ["ext_authz_grpc_impl.cc"], + hdrs = ["ext_authz_grpc_impl.h"], deps = [ + ":check_request_utils_lib", ":ext_authz_interface", "//include/envoy/grpc:async_client_interface", "//include/envoy/grpc:async_client_manager_interface", @@ -31,7 +33,6 @@ envoy_cc_library( "//include/envoy/network:address_interface", "//include/envoy/network:connection_interface", "//include/envoy/network:filter_interface", - "//include/envoy/ssl:connection_interface", "//include/envoy/upstream:cluster_manager_interface", "//source/common/common:assert_lib", "//source/common/grpc:async_client_lib", @@ -42,3 +43,29 @@ envoy_cc_library( "//source/common/tracing:http_tracer_lib", ], ) + +envoy_cc_library( + name = "ext_authz_http_lib", + srcs = ["ext_authz_http_impl.cc"], + hdrs = ["ext_authz_http_impl.h"], + deps = [ + ":check_request_utils_lib", + ":ext_authz_interface", + "//source/common/common:minimal_logger_lib", + "//source/common/http:async_client_lib", + ], +) + +envoy_cc_library( + name = "check_request_utils_lib", + srcs = ["check_request_utils.cc"], + hdrs = ["check_request_utils.h"], + deps = [ + "//include/envoy/grpc:async_client_interface", + "//include/envoy/grpc:async_client_manager_interface", + "//include/envoy/http:filter_interface", + "//include/envoy/upstream:cluster_manager_interface", + "//source/common/grpc:async_client_lib", + "@envoy_api//envoy/service/auth/v2alpha:external_auth_cc", + ], +) diff --git a/source/extensions/filters/common/ext_authz/ext_authz_impl.cc b/source/extensions/filters/common/ext_authz/check_request_utils.cc similarity index 72% rename from source/extensions/filters/common/ext_authz/ext_authz_impl.cc rename to source/extensions/filters/common/ext_authz/check_request_utils.cc index 49fe65c1385b9..90f82d372def4 100644 --- a/source/extensions/filters/common/ext_authz/ext_authz_impl.cc +++ b/source/extensions/filters/common/ext_authz/check_request_utils.cc @@ -1,72 +1,30 @@ -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" +#include "extensions/filters/common/ext_authz/check_request_utils.h" #include #include #include #include -#include "envoy/access_log/access_log.h" #include "envoy/ssl/connection.h" +#include "common/buffer/buffer_impl.h" #include "common/common/assert.h" +#include "common/common/enum_to_int.h" #include "common/grpc/async_client_impl.h" +#include "common/http/codes.h" #include "common/http/headers.h" #include "common/http/utility.h" #include "common/network/utility.h" #include "common/protobuf/protobuf.h" +#include "absl/strings/str_cat.h" + namespace Envoy { namespace Extensions { namespace Filters { namespace Common { namespace ExtAuthz { -GrpcClientImpl::GrpcClientImpl(Grpc::AsyncClientPtr&& async_client, - const absl::optional& timeout) - : service_method_(*Protobuf::DescriptorPool::generated_pool()->FindMethodByName( - // TODO(dio): Define the following service method name as a constant value. - "envoy.service.auth.v2alpha.Authorization.Check")), - async_client_(std::move(async_client)), timeout_(timeout) {} - -GrpcClientImpl::~GrpcClientImpl() { ASSERT(!callbacks_); } - -void GrpcClientImpl::cancel() { - ASSERT(callbacks_ != nullptr); - request_->cancel(); - callbacks_ = nullptr; -} - -void GrpcClientImpl::check(RequestCallbacks& callbacks, - const envoy::service::auth::v2alpha::CheckRequest& request, - Tracing::Span& parent_span) { - ASSERT(callbacks_ == nullptr); - callbacks_ = &callbacks; - - request_ = async_client_->send(service_method_, request, *this, parent_span, timeout_); -} - -void GrpcClientImpl::onSuccess( - std::unique_ptr&& response, Tracing::Span& span) { - CheckStatus status = CheckStatus::OK; - ASSERT(response->status().code() != Grpc::Status::GrpcStatus::Unknown); - if (response->status().code() != Grpc::Status::GrpcStatus::Ok) { - status = CheckStatus::Denied; - span.setTag(Constants::get().TraceStatus, Constants::get().TraceUnauthz); - } else { - span.setTag(Constants::get().TraceStatus, Constants::get().TraceOk); - } - - callbacks_->onComplete(status); - callbacks_ = nullptr; -} - -void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string&, - Tracing::Span&) { - ASSERT(status != Grpc::Status::GrpcStatus::Ok); - callbacks_->onComplete(CheckStatus::Error); - callbacks_ = nullptr; -} - void CheckRequestUtils::setAttrContextPeer( envoy::service::auth::v2alpha::AttributeContext_Peer& peer, const Network::Connection& connection, const std::string& service, const bool local) { diff --git a/source/extensions/filters/common/ext_authz/ext_authz_impl.h b/source/extensions/filters/common/ext_authz/check_request_utils.h similarity index 66% rename from source/extensions/filters/common/ext_authz/ext_authz_impl.h rename to source/extensions/filters/common/ext_authz/check_request_utils.h index f3266c4875b2e..fa94c45fbc3e7 100644 --- a/source/extensions/filters/common/ext_authz/ext_authz_impl.h +++ b/source/extensions/filters/common/ext_authz/check_request_utils.h @@ -13,59 +13,19 @@ #include "envoy/network/address.h" #include "envoy/network/connection.h" #include "envoy/network/filter.h" +#include "envoy/service/auth/v2alpha/external_auth.pb.h" #include "envoy/tracing/http_tracer.h" #include "envoy/upstream/cluster_manager.h" +#include "common/http/async_client_impl.h" #include "common/singleton/const_singleton.h" -#include "extensions/filters/common/ext_authz/ext_authz.h" - namespace Envoy { namespace Extensions { namespace Filters { namespace Common { namespace ExtAuthz { -typedef Grpc::TypedAsyncRequestCallbacks - ExtAuthzAsyncCallbacks; - -struct ConstantValues { - const std::string TraceStatus = "ext_authz_status"; - const std::string TraceUnauthz = "ext_authz_unauthorized"; - const std::string TraceOk = "ext_authz_ok"; -}; - -typedef ConstSingleton Constants; - -// NOTE: We create gRPC client for each filter stack instead of a client per thread. -// That is ok since this is unary RPC and the cost of doing this is minimal. -class GrpcClientImpl : public Client, public ExtAuthzAsyncCallbacks { -public: - GrpcClientImpl(Grpc::AsyncClientPtr&& async_client, - const absl::optional& timeout); - ~GrpcClientImpl(); - - // ExtAuthz::Client - void cancel() override; - void check(RequestCallbacks& callbacks, - const envoy::service::auth::v2alpha::CheckRequest& request, - Tracing::Span& parent_span) override; - - // Grpc::AsyncRequestCallbacks - void onCreateInitialMetadata(Http::HeaderMap&) override {} - void onSuccess(std::unique_ptr&& response, - Tracing::Span& span) override; - void onFailure(Grpc::Status::GrpcStatus status, const std::string& message, - Tracing::Span& span) override; - -private: - const Protobuf::MethodDescriptor& service_method_; - Grpc::AsyncClientPtr async_client_; - Grpc::AsyncRequest* request_{}; - absl::optional timeout_; - RequestCallbacks* callbacks_{}; -}; - /** * For creating ext_authz.proto (authorization) request. * CheckRequestUtils is used to extract attributes from the TCP/HTTP request diff --git a/source/extensions/filters/common/ext_authz/ext_authz.h b/source/extensions/filters/common/ext_authz/ext_authz.h index 53c67b749b908..0d242f622822a 100644 --- a/source/extensions/filters/common/ext_authz/ext_authz.h +++ b/source/extensions/filters/common/ext_authz/ext_authz.h @@ -6,6 +6,7 @@ #include #include "envoy/common/pure.h" +#include "envoy/http/codes.h" #include "envoy/service/auth/v2alpha/external_auth.pb.h" #include "envoy/tracing/http_tracer.h" @@ -27,6 +28,24 @@ enum class CheckStatus { Denied }; +/** + * Authorization response object for a RequestCallback. + */ +struct Response { + // Call status. + CheckStatus status; + // Optional http headers used on either denied or ok responses. + Http::HeaderVector headers_to_append; + // Optional http headers used on either denied or ok responses. + Http::HeaderVector headers_to_add; + // Optional http body used only on denied response. + std::string body; + // Optional http status used only on denied response. + Http::Code status_code{}; +}; + +typedef std::unique_ptr ResponsePtr; + /** * Async callbacks used during check() calls. */ @@ -35,9 +54,9 @@ class RequestCallbacks { virtual ~RequestCallbacks() {} /** - * Called when a check request is complete. The resulting status is supplied. + * Called when a check request is complete. The resulting ResponsePtr is supplied. */ - virtual void onComplete(CheckStatus status) PURE; + virtual void onComplete(ResponsePtr&& response) PURE; }; class Client { diff --git a/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.cc b/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.cc new file mode 100644 index 0000000000000..2dc5cd33023d0 --- /dev/null +++ b/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.cc @@ -0,0 +1,95 @@ +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" + +#include "common/common/assert.h" +#include "common/grpc/async_client_impl.h" +#include "common/http/headers.h" +#include "common/http/utility.h" +#include "common/network/utility.h" +#include "common/protobuf/protobuf.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +GrpcClientImpl::GrpcClientImpl(Grpc::AsyncClientPtr&& async_client, + const absl::optional& timeout) + : service_method_(*Protobuf::DescriptorPool::generated_pool()->FindMethodByName( + // TODO(dio): Define the following service method name as a constant value. + "envoy.service.auth.v2alpha.Authorization.Check")), + async_client_(std::move(async_client)), timeout_(timeout) {} + +GrpcClientImpl::~GrpcClientImpl() { ASSERT(!callbacks_); } + +void GrpcClientImpl::cancel() { + ASSERT(callbacks_ != nullptr); + request_->cancel(); + callbacks_ = nullptr; +} + +void GrpcClientImpl::check(RequestCallbacks& callbacks, + const envoy::service::auth::v2alpha::CheckRequest& request, + Tracing::Span& parent_span) { + ASSERT(callbacks_ == nullptr); + callbacks_ = &callbacks; + + request_ = async_client_->send(service_method_, request, *this, parent_span, timeout_); +} + +void GrpcClientImpl::onSuccess( + std::unique_ptr&& response, Tracing::Span& span) { + ASSERT(response->status().code() != Grpc::Status::GrpcStatus::Unknown); + ResponsePtr authz_response = std::make_unique(Response{}); + + if (response->status().code() == Grpc::Status::GrpcStatus::Ok) { + span.setTag(Constants::get().TraceStatus, Constants::get().TraceOk); + authz_response->status = CheckStatus::OK; + if (response->has_ok_response()) { + toAuthzResponseHeader(authz_response, response->ok_response().headers()); + } + } else { + span.setTag(Constants::get().TraceStatus, Constants::get().TraceUnauthz); + authz_response->status = CheckStatus::Denied; + if (response->has_denied_response()) { + toAuthzResponseHeader(authz_response, response->denied_response().headers()); + authz_response->status_code = + static_cast(response->denied_response().status().code()); + authz_response->body = response->denied_response().body(); + } else { + authz_response->status_code = Http::Code::Forbidden; + } + } + + callbacks_->onComplete(std::move(authz_response)); + callbacks_ = nullptr; +} + +void GrpcClientImpl::onFailure(Grpc::Status::GrpcStatus status, const std::string&, + Tracing::Span&) { + ASSERT(status != Grpc::Status::GrpcStatus::Ok); + ResponsePtr authz_response = std::make_unique(Response{}); + authz_response->status = CheckStatus::Error; + callbacks_->onComplete(std::move(authz_response)); + callbacks_ = nullptr; +} + +void GrpcClientImpl::toAuthzResponseHeader( + ResponsePtr& response, + const Protobuf::RepeatedPtrField& headers) { + for (const auto& header : headers) { + if (header.append().value()) { + response->headers_to_append.emplace_back(Http::LowerCaseString(header.header().key()), + header.header().value()); + } else { + response->headers_to_add.emplace_back(Http::LowerCaseString(header.header().key()), + header.header().value()); + } + } +} + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.h b/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.h new file mode 100644 index 0000000000000..0c44a92aebdc2 --- /dev/null +++ b/source/extensions/filters/common/ext_authz/ext_authz_grpc_impl.h @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include +#include + +#include "envoy/grpc/async_client.h" +#include "envoy/grpc/async_client_manager.h" +#include "envoy/http/filter.h" +#include "envoy/http/header_map.h" +#include "envoy/http/protocol.h" +#include "envoy/network/address.h" +#include "envoy/network/connection.h" +#include "envoy/network/filter.h" +#include "envoy/tracing/http_tracer.h" +#include "envoy/upstream/cluster_manager.h" + +#include "common/singleton/const_singleton.h" + +#include "extensions/filters/common/ext_authz/check_request_utils.h" +#include "extensions/filters/common/ext_authz/ext_authz.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +typedef Grpc::TypedAsyncRequestCallbacks + ExtAuthzAsyncCallbacks; + +struct ConstantValues { + const std::string TraceStatus = "ext_authz_status"; + const std::string TraceUnauthz = "ext_authz_unauthorized"; + const std::string TraceOk = "ext_authz_ok"; +}; + +typedef ConstSingleton Constants; + +/* + * This client implementation is used when the Ext_Authz filter needs to communicate with an gRPC + * authorization server. Unlike the HTTP client, the gRPC allows the server to define response + * objects which contain the HTTP attributes to be sent to the upstream or to the downstream client. + * The gRPC client does not rewrite path. NOTE: We create gRPC client for each filter stack instead + * of a client per thread. That is ok since this is unary RPC and the cost of doing this is minimal. + */ +class GrpcClientImpl : public Client, public ExtAuthzAsyncCallbacks { +public: + GrpcClientImpl(Grpc::AsyncClientPtr&& async_client, + const absl::optional& timeout); + ~GrpcClientImpl(); + + // ExtAuthz::Client + void cancel() override; + void check(RequestCallbacks& callbacks, + const envoy::service::auth::v2alpha::CheckRequest& request, + Tracing::Span& parent_span) override; + + // Grpc::AsyncRequestCallbacks + void onCreateInitialMetadata(Http::HeaderMap&) override {} + void onSuccess(std::unique_ptr&& response, + Tracing::Span& span) override; + void onFailure(Grpc::Status::GrpcStatus status, const std::string& message, + Tracing::Span& span) override; + +private: + void toAuthzResponseHeader( + ResponsePtr& response, + const Protobuf::RepeatedPtrField& headers); + const Protobuf::MethodDescriptor& service_method_; + Grpc::AsyncClientPtr async_client_; + Grpc::AsyncRequest* request_{}; + absl::optional timeout_; + RequestCallbacks* callbacks_{}; +}; + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc new file mode 100644 index 0000000000000..705b368a6605e --- /dev/null +++ b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.cc @@ -0,0 +1,117 @@ +#include "extensions/filters/common/ext_authz/ext_authz_http_impl.h" + +#include "common/common/enum_to_int.h" +#include "common/http/async_client_impl.h" + +#include "absl/strings/str_cat.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +namespace { + +const Http::HeaderMap* getZeroContentLengthHeader() { + static const Http::HeaderMap* header_map = + new Http::HeaderMapImpl{{Http::Headers::get().ContentLength, std::to_string(0)}}; + return header_map; +} +} // namespace + +RawHttpClientImpl::RawHttpClientImpl( + const std::string& cluster_name, Upstream::ClusterManager& cluster_manager, + const absl::optional& timeout, const std::string& path_prefix, + const std::vector& response_headers_to_remove) + : cluster_name_(cluster_name), path_prefix_(path_prefix), + response_headers_to_remove_(response_headers_to_remove), timeout_(timeout), + cm_(cluster_manager) {} + +RawHttpClientImpl::~RawHttpClientImpl() { ASSERT(!callbacks_); } + +void RawHttpClientImpl::cancel() { + ASSERT(callbacks_ != nullptr); + request_->cancel(); + callbacks_ = nullptr; +} + +void RawHttpClientImpl::check(RequestCallbacks& callbacks, + const envoy::service::auth::v2alpha::CheckRequest& request, + Tracing::Span&) { + ASSERT(callbacks_ == nullptr); + callbacks_ = &callbacks; + + Http::HeaderMapPtr headers = std::make_unique(*getZeroContentLengthHeader()); + for (const auto& header : request.attributes().request().http().headers()) { + + const Http::LowerCaseString key{header.first}; + if (key == Http::Headers::get().Path && !path_prefix_.empty()) { + std::string value; + absl::StrAppend(&value, path_prefix_, header.second); + headers->addCopy(key, value); + } else { + headers->addCopy(key, header.second); + } + } + + request_ = cm_.httpAsyncClientForCluster(cluster_name_) + .send(std::make_unique(std::move(headers)), *this, + timeout_); +} + +void RawHttpClientImpl::onSuccess(Http::MessagePtr&& response) { + ResponsePtr authz_response = std::make_unique(Response{}); + + uint64_t status_code; + if (StringUtil::atoul(response->headers().Status()->value().c_str(), status_code)) { + if (status_code == enumToInt(Http::Code::OK)) { + // Header that should not be sent to the upstream. + response->headers().removeStatus(); + response->headers().removeMethod(); + response->headers().removePath(); + response->headers().removeContentLength(); + + // Optional/Configurable headers the should not be sent to the upstream. + for (const auto& header_to_remove : response_headers_to_remove_) { + response->headers().remove(header_to_remove); + } + + authz_response->status = CheckStatus::OK; + authz_response->status_code = Http::Code::OK; + } else { + authz_response->status = CheckStatus::Denied; + authz_response->body = response->bodyAsString(); + authz_response->status_code = static_cast(status_code); + } + } else { + ENVOY_LOG(warn, "Authz_Ext failed to parse the HTTP response code."); + authz_response->status_code = Http::Code::Forbidden; + authz_response->status = CheckStatus::Denied; + } + + response->headers().iterate( + [](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate { + static_cast(context)->emplace_back( + Http::LowerCaseString{header.key().c_str()}, std::string{header.value().c_str()}); + return Http::HeaderMap::Iterate::Continue; + }, + &authz_response->headers_to_add); + + callbacks_->onComplete(std::move(authz_response)); + callbacks_ = nullptr; +} + +void RawHttpClientImpl::onFailure(Http::AsyncClient::FailureReason reason) { + ASSERT(reason == Http::AsyncClient::FailureReason::Reset); + Response authz_response{}; + authz_response.status = CheckStatus::Error; + callbacks_->onComplete(std::make_unique(authz_response)); + callbacks_ = nullptr; +} + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h new file mode 100644 index 0000000000000..fc5f26fb4c57e --- /dev/null +++ b/source/extensions/filters/common/ext_authz/ext_authz_http_impl.h @@ -0,0 +1,56 @@ +#pragma once + +#include "envoy/upstream/cluster_manager.h" + +#include "common/common/logger.h" + +#include "extensions/filters/common/ext_authz/ext_authz.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +/** + * This client implementation is used when the Ext_Authz filter needs to communicate with an + * HTTP authorization server. Unlike the gRPC client that allows the server to define the + * response object, in the HTTP client, all headers and body provided in the response are + * dispatched to the downstream, and some headers to the upstream. The HTTP client also allows + * setting a path prefix witch is not available for gRPC. + */ +class RawHttpClientImpl : public Client, + public Http::AsyncClient::Callbacks, + Logger::Loggable { +public: + explicit RawHttpClientImpl(const std::string& cluster_name, + Upstream::ClusterManager& cluster_manager, + const absl::optional& timeout, + const std::string& path_prefix, + const std::vector& response_headers_to_remove); + ~RawHttpClientImpl(); + + // ExtAuthz::Client + void cancel() override; + void check(RequestCallbacks& callbacks, + const envoy::service::auth::v2alpha::CheckRequest& request, Tracing::Span&) override; + + // Http::AsyncClient::Callbacks + void onSuccess(Http::MessagePtr&& response) override; + void onFailure(Http::AsyncClient::FailureReason reason) override; + +private: + const std::string cluster_name_; + const std::string path_prefix_; + const std::vector response_headers_to_remove_; + absl::optional timeout_; + Upstream::ClusterManager& cm_; + Http::AsyncClient::Request* request_{}; + RequestCallbacks* callbacks_{}; +}; + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/common/lua/lua.h b/source/extensions/filters/common/lua/lua.h index e2c1d1fe23d62..9f2ee60328347 100644 --- a/source/extensions/filters/common/lua/lua.h +++ b/source/extensions/filters/common/lua/lua.h @@ -368,6 +368,21 @@ class ThreadLocalState : Logger::Loggable { [this]() { T::registerType(tls_slot_->getTyped().state_.get()); }); } + /** + * Return the number of bytes used by the runtime. + */ + uint64_t runtimeBytesUsed() { + uint64_t bytes_used = + lua_gc(tls_slot_->getTyped().state_.get(), LUA_GCCOUNT, 0) * 1024; + bytes_used += lua_gc(tls_slot_->getTyped().state_.get(), LUA_GCCOUNTB, 0); + return bytes_used; + } + + /** + * Force a full runtime GC. + */ + void runtimeGC() { lua_gc(tls_slot_->getTyped().state_.get(), LUA_GCCOLLECT, 0); } + private: struct LuaThreadLocal : public ThreadLocal::ThreadLocalObject { LuaThreadLocal(const std::string& code); diff --git a/source/extensions/filters/common/lua/wrappers.cc b/source/extensions/filters/common/lua/wrappers.cc index f431bc012fb15..efb522fe18a92 100644 --- a/source/extensions/filters/common/lua/wrappers.cc +++ b/source/extensions/filters/common/lua/wrappers.cc @@ -26,7 +26,7 @@ int BufferWrapper::luaGetBytes(lua_State* state) { return 1; } -void MetadataMapWrapper::setValue(lua_State* state, const ProtobufWkt::Value& value) { +void MetadataMapHelper::setValue(lua_State* state, const ProtobufWkt::Value& value) { ProtobufWkt::Value::KindCase kind = value.kind_case(); switch (kind) { @@ -72,11 +72,11 @@ void MetadataMapWrapper::setValue(lua_State* state, const ProtobufWkt::Value& va } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } -void MetadataMapWrapper::createTable( +void MetadataMapHelper::createTable( lua_State* state, const Protobuf::Map& fields) { lua_createtable(state, 0, fields.size()); @@ -98,7 +98,7 @@ int MetadataMapIterator::luaPairsIterator(lua_State* state) { } lua_pushstring(state, current_->first.c_str()); - parent_.setValue(state, current_->second); + MetadataMapHelper::setValue(state, current_->second); current_++; return 2; @@ -111,7 +111,7 @@ int MetadataMapWrapper::luaGet(lua_State* state) { return 0; } - setValue(state, filter_it->second); + MetadataMapHelper::setValue(state, filter_it->second); return 1; } @@ -125,6 +125,20 @@ int MetadataMapWrapper::luaPairs(lua_State* state) { return 1; } +int ConnectionWrapper::luaSsl(lua_State* state) { + const auto& ssl = connection_->ssl(); + if (ssl != nullptr) { + if (ssl_connection_wrapper_.get() != nullptr) { + ssl_connection_wrapper_.pushStack(); + } else { + ssl_connection_wrapper_.reset(SslConnectionWrapper::create(state, ssl), true); + } + } else { + lua_pushnil(state); + } + return 1; +} + } // namespace Lua } // namespace Common } // namespace Filters diff --git a/source/extensions/filters/common/lua/wrappers.h b/source/extensions/filters/common/lua/wrappers.h index 0cc84deb34bd6..2019a8887961f 100644 --- a/source/extensions/filters/common/lua/wrappers.h +++ b/source/extensions/filters/common/lua/wrappers.h @@ -42,6 +42,13 @@ class BufferWrapper : public BaseLuaObject { class MetadataMapWrapper; +struct MetadataMapHelper { + static void setValue(lua_State* state, const ProtobufWkt::Value& value); + static void + createTable(lua_State* state, + const Protobuf::Map& fields); +}; + /** * Iterator over a metadata map. */ @@ -89,16 +96,45 @@ class MetadataMapWrapper : public BaseLuaObject { iterator_.reset(); } - void setValue(lua_State* state, const ProtobufWkt::Value& value); - void createTable(lua_State* state, - const Protobuf::Map& fields); - const ProtobufWkt::Struct metadata_; LuaDeathRef iterator_; friend class MetadataMapIterator; }; +/** + * Lua wrapper for Ssl::Connection. + */ +class SslConnectionWrapper : public BaseLuaObject { +public: + SslConnectionWrapper(const Ssl::Connection*) {} + static ExportedFunctions exportedFunctions() { return {}; } + + // TODO(dio): Add more Lua APIs around Ssl::Connection. +}; + +/** + * Lua wrapper for Network::Connection. + */ +class ConnectionWrapper : public BaseLuaObject { +public: + ConnectionWrapper(const Network::Connection* connection) : connection_{connection} {} + static ExportedFunctions exportedFunctions() { return {{"ssl", static_luaSsl}}; } + +private: + /** + * Get the Ssl::Connection wrapper + * @return object if secured and nil if not. + */ + DECLARE_LUA_FUNCTION(ConnectionWrapper, luaSsl); + + // Envoy::Lua::BaseLuaObject + void onMarkDead() override { ssl_connection_wrapper_.reset(); } + + const Network::Connection* connection_; + LuaDeathRef ssl_connection_wrapper_; +}; + } // namespace Lua } // namespace Common } // namespace Filters diff --git a/source/extensions/filters/common/rbac/BUILD b/source/extensions/filters/common/rbac/BUILD index 59c0cf7beb7dd..0170fc2d7b8f1 100644 --- a/source/extensions/filters/common/rbac/BUILD +++ b/source/extensions/filters/common/rbac/BUILD @@ -16,8 +16,10 @@ envoy_cc_library( "//include/envoy/http:header_map_interface", "//include/envoy/network:connection_interface", "//source/common/common:assert_lib", + "//source/common/common:matchers_lib", "//source/common/http:header_utility_lib", "//source/common/network:cidr_range_lib", + "@envoy_api//envoy/api/v2/core:base_cc", "@envoy_api//envoy/config/rbac/v2alpha:rbac_cc", ], ) @@ -39,6 +41,7 @@ envoy_cc_library( deps = [ "//source/extensions/filters/common/rbac:engine_interface", "//source/extensions/filters/common/rbac:matchers_lib", + "@envoy_api//envoy/api/v2/core:base_cc", "@envoy_api//envoy/config/filter/http/rbac/v2:rbac_cc", ], ) diff --git a/source/extensions/filters/common/rbac/engine.h b/source/extensions/filters/common/rbac/engine.h index a122f58af5471..1281f6319bcf2 100644 --- a/source/extensions/filters/common/rbac/engine.h +++ b/source/extensions/filters/common/rbac/engine.h @@ -1,5 +1,6 @@ #pragma once +#include "envoy/api/v2/core/base.pb.h" #include "envoy/http/filter.h" #include "envoy/http/header_map.h" #include "envoy/network/connection.h" @@ -23,9 +24,10 @@ class RoleBasedAccessControlEngine { * @param connection the downstream connection used to identify the action/principal. * @param headers the headers of the incoming request used to identify the action/principal. An * empty map should be used if there are no headers available. + * @param metadata the metadata with additional information about the action/principal. */ - virtual bool allowed(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const PURE; + virtual bool allowed(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const PURE; }; } // namespace RBAC diff --git a/source/extensions/filters/common/rbac/engine_impl.cc b/source/extensions/filters/common/rbac/engine_impl.cc index adf62d3eb7ac1..1608476f8efff 100644 --- a/source/extensions/filters/common/rbac/engine_impl.cc +++ b/source/extensions/filters/common/rbac/engine_impl.cc @@ -15,11 +15,12 @@ RoleBasedAccessControlEngineImpl::RoleBasedAccessControlEngineImpl( } } -bool RoleBasedAccessControlEngineImpl::allowed(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const { +bool RoleBasedAccessControlEngineImpl::allowed( + const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const { bool matched = false; for (const auto& policy : policies_) { - if (policy.matches(connection, headers)) { + if (policy.matches(connection, headers, metadata)) { matched = true; break; } diff --git a/source/extensions/filters/common/rbac/engine_impl.h b/source/extensions/filters/common/rbac/engine_impl.h index d0747e5342eb4..6be4d18798d27 100644 --- a/source/extensions/filters/common/rbac/engine_impl.h +++ b/source/extensions/filters/common/rbac/engine_impl.h @@ -15,8 +15,8 @@ class RoleBasedAccessControlEngineImpl : public RoleBasedAccessControlEngine { public: RoleBasedAccessControlEngineImpl(const envoy::config::rbac::v2alpha::RBAC& rules); - bool allowed(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool allowed(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const override; private: const bool allowed_if_matched_; diff --git a/source/extensions/filters/common/rbac/matchers.cc b/source/extensions/filters/common/rbac/matchers.cc index 9cead82116c70..dad7ee9c58b42 100644 --- a/source/extensions/filters/common/rbac/matchers.cc +++ b/source/extensions/filters/common/rbac/matchers.cc @@ -22,8 +22,12 @@ MatcherConstSharedPtr Matcher::create(const envoy::config::rbac::v2alpha::Permis return std::make_shared(permission.destination_port()); case envoy::config::rbac::v2alpha::Permission::RuleCase::kAny: return std::make_shared(); + case envoy::config::rbac::v2alpha::Permission::RuleCase::kMetadata: + return std::make_shared(permission.metadata()); + case envoy::config::rbac::v2alpha::Permission::RuleCase::kNotRule: + return std::make_shared(permission.not_rule()); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -41,8 +45,12 @@ MatcherConstSharedPtr Matcher::create(const envoy::config::rbac::v2alpha::Princi return std::make_shared(principal.header()); case envoy::config::rbac::v2alpha::Principal::IdentifierCase::kAny: return std::make_shared(); + case envoy::config::rbac::v2alpha::Principal::IdentifierCase::kMetadata: + return std::make_shared(principal.metadata()); + case envoy::config::rbac::v2alpha::Principal::IdentifierCase::kNotId: + return std::make_shared(principal.not_id()); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -59,9 +67,10 @@ AndMatcher::AndMatcher(const envoy::config::rbac::v2alpha::Principal_Set& set) { } bool AndMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const { + const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const { for (const auto& matcher : matchers_) { - if (!matcher->matches(connection, headers)) { + if (!matcher->matches(connection, headers, metadata)) { return false; } } @@ -84,9 +93,10 @@ OrMatcher::OrMatcher( } bool OrMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const { + const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const { for (const auto& matcher : matchers_) { - if (matcher->matches(connection, headers)) { + if (matcher->matches(connection, headers, metadata)) { return true; } } @@ -94,27 +104,34 @@ bool OrMatcher::matches(const Network::Connection& connection, return false; } -bool HeaderMatcher::matches(const Network::Connection&, - const Envoy::Http::HeaderMap& headers) const { +bool NotMatcher::matches(const Network::Connection& connection, + const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const { + return !matcher_->matches(connection, headers, metadata); +} + +bool HeaderMatcher::matches(const Network::Connection&, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const { return Envoy::Http::HeaderUtility::matchHeaders(headers, header_); } -bool IPMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap&) const { +bool IPMatcher::matches(const Network::Connection& connection, const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata&) const { const Envoy::Network::Address::InstanceConstSharedPtr& ip = destination_ ? connection.localAddress() : connection.remoteAddress(); return range_.isInRange(*ip.get()); } -bool PortMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap&) const { +bool PortMatcher::matches(const Network::Connection& connection, const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata&) const { const Envoy::Network::Address::Ip* ip = connection.localAddress().get()->ip(); return ip && ip->port() == port_; } bool AuthenticatedMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap&) const { + const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata&) const { const auto* ssl = connection.ssl(); if (!ssl) { // connection was not authenticated return false; @@ -128,9 +145,16 @@ bool AuthenticatedMatcher::matches(const Network::Connection& connection, return principal == name_; } +bool MetadataMatcher::matches(const Network::Connection&, const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata& metadata) const { + return matcher_.match(metadata); +} + bool PolicyMatcher::matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const { - return permissions_.matches(connection, headers) && principals_.matches(connection, headers); + const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const { + return permissions_.matches(connection, headers, metadata) && + principals_.matches(connection, headers, metadata); } } // namespace RBAC diff --git a/source/extensions/filters/common/rbac/matchers.h b/source/extensions/filters/common/rbac/matchers.h index 2c09288a08354..85487cc408f7a 100644 --- a/source/extensions/filters/common/rbac/matchers.h +++ b/source/extensions/filters/common/rbac/matchers.h @@ -2,10 +2,12 @@ #include +#include "envoy/api/v2/core/base.pb.h" #include "envoy/config/rbac/v2alpha/rbac.pb.h" #include "envoy/http/header_map.h" #include "envoy/network/connection.h" +#include "common/common/matchers.h" #include "common/http/header_utility.h" #include "common/network/cidr_range.h" @@ -31,9 +33,10 @@ class Matcher { * @param connection the downstream connection used to match against. * @param headers the request headers used to match against. An empty map should be used if * there are none headers available. + * @param metadata the additional information about the action/principal. */ - virtual bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const PURE; + virtual bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const PURE; /** * Creates a shared instance of a matcher based off the rules defined in the Permission config @@ -53,7 +56,8 @@ class Matcher { */ class AlwaysMatcher : public Matcher { public: - bool matches(const Network::Connection&, const Envoy::Http::HeaderMap&) const override { + bool matches(const Network::Connection&, const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata&) const override { return true; } }; @@ -67,8 +71,8 @@ class AndMatcher : public Matcher { AndMatcher(const envoy::config::rbac::v2alpha::Permission_Set& rules); AndMatcher(const envoy::config::rbac::v2alpha::Principal_Set& ids); - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: std::vector matchers_; @@ -85,13 +89,27 @@ class OrMatcher : public Matcher { OrMatcher(const Protobuf::RepeatedPtrField<::envoy::config::rbac::v2alpha::Permission>& rules); OrMatcher(const Protobuf::RepeatedPtrField<::envoy::config::rbac::v2alpha::Principal>& ids); - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: std::vector matchers_; }; +class NotMatcher : public Matcher { +public: + NotMatcher(const envoy::config::rbac::v2alpha::Permission& permission) + : matcher_(Matcher::create(permission)) {} + NotMatcher(const envoy::config::rbac::v2alpha::Principal& principal) + : matcher_(Matcher::create(principal)) {} + + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; + +private: + MatcherConstSharedPtr matcher_; +}; + /** * Perform a match against any HTTP header (or pseudo-header, such as `:path` or `:authority`). Will * always fail to match on any non-HTTP connection. @@ -100,8 +118,8 @@ class HeaderMatcher : public Matcher { public: HeaderMatcher(const envoy::api::v2::route::HeaderMatcher& matcher) : header_(matcher) {} - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: const Envoy::Http::HeaderUtility::HeaderData header_; @@ -116,8 +134,8 @@ class IPMatcher : public Matcher { IPMatcher(const envoy::api::v2::core::CidrRange& range, bool destination) : range_(Network::Address::CidrRange::create(range)), destination_(destination) {} - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: const Network::Address::CidrRange range_; @@ -131,8 +149,8 @@ class PortMatcher : public Matcher { public: PortMatcher(const uint32_t port) : port_(port) {} - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: const uint32_t port_; @@ -147,8 +165,8 @@ class AuthenticatedMatcher : public Matcher { AuthenticatedMatcher(const envoy::config::rbac::v2alpha::Principal_Authenticated& auth) : name_(auth.name()) {} - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: const std::string name_; @@ -163,14 +181,25 @@ class PolicyMatcher : public Matcher { PolicyMatcher(const envoy::config::rbac::v2alpha::Policy& policy) : permissions_(policy.permissions()), principals_(policy.principals()) {} - bool matches(const Network::Connection& connection, - const Envoy::Http::HeaderMap& headers) const override; + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata&) const override; private: const OrMatcher permissions_; const OrMatcher principals_; }; +class MetadataMatcher : public Matcher { +public: + MetadataMatcher(const Envoy::Matchers::MetadataMatcher& matcher) : matcher_(matcher) {} + + bool matches(const Network::Connection& connection, const Envoy::Http::HeaderMap& headers, + const envoy::api::v2::core::Metadata& metadata) const override; + +private: + const Envoy::Matchers::MetadataMatcher matcher_; +}; + } // namespace RBAC } // namespace Common } // namespace Filters diff --git a/source/extensions/filters/http/buffer/BUILD b/source/extensions/filters/http/buffer/BUILD index c002a21a99c13..9245b81f09c93 100644 --- a/source/extensions/filters/http/buffer/BUILD +++ b/source/extensions/filters/http/buffer/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Request buffering and timeout L7 HTTP filter # Public docs: docs/root/configuration/http_filters/buffer_filter.rst diff --git a/source/extensions/filters/http/buffer/buffer_filter.cc b/source/extensions/filters/http/buffer/buffer_filter.cc index 36b45f31fd938..6db61bdc0745d 100644 --- a/source/extensions/filters/http/buffer/buffer_filter.cc +++ b/source/extensions/filters/http/buffer/buffer_filter.cc @@ -58,7 +58,7 @@ void BufferFilter::initConfig() { return; } - const std::string& name = HttpFilterNames::get().BUFFER; + const std::string& name = HttpFilterNames::get().Buffer; const auto* entry = callbacks_->route()->routeEntry(); const BufferFilterSettings* route_local = diff --git a/source/extensions/filters/http/buffer/config.h b/source/extensions/filters/http/buffer/config.h index 98c14afb3868c..232acd431592d 100644 --- a/source/extensions/filters/http/buffer/config.h +++ b/source/extensions/filters/http/buffer/config.h @@ -17,7 +17,7 @@ class BufferFilterFactory : public Common::FactoryBase { public: - BufferFilterFactory() : FactoryBase(HttpFilterNames::get().BUFFER) {} + BufferFilterFactory() : FactoryBase(HttpFilterNames::get().Buffer) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string& stats_prefix, diff --git a/source/extensions/filters/http/common/factory_base.h b/source/extensions/filters/http/common/factory_base.h index 2d46efddab17e..8e3a9669173b8 100644 --- a/source/extensions/filters/http/common/factory_base.h +++ b/source/extensions/filters/http/common/factory_base.h @@ -18,7 +18,7 @@ class FactoryBase : public Server::Configuration::NamedHttpFilterConfigFactory { Http::FilterFactoryCb createFilterFactory(const Json::Object&, const std::string&, Server::Configuration::FactoryContext&) override { // Only used in v1 filters. - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Http::FilterFactoryCb diff --git a/source/extensions/filters/http/cors/BUILD b/source/extensions/filters/http/cors/BUILD index 8399cfe3f408d..6ddf34f61e087 100644 --- a/source/extensions/filters/http/cors/BUILD +++ b/source/extensions/filters/http/cors/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter which implements CORS processing (https://en.wikipedia.org/wiki/Cross-origin_resource_sharing) # Public docs: docs/root/configuration/http_filters/cors_filter.rst diff --git a/source/extensions/filters/http/cors/config.h b/source/extensions/filters/http/cors/config.h index b1e2f947a08c6..9ce14197fc43b 100644 --- a/source/extensions/filters/http/cors/config.h +++ b/source/extensions/filters/http/cors/config.h @@ -15,7 +15,7 @@ namespace Cors { */ class CorsFilterConfig : public Common::EmptyHttpFilterConfig { public: - CorsFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().CORS) {} + CorsFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().Cors) {} Http::FilterFactoryCb createFilter(const std::string&, Server::Configuration::FactoryContext&) override; diff --git a/source/extensions/filters/http/cors/cors_filter.cc b/source/extensions/filters/http/cors/cors_filter.cc index ae9cada105766..eb1ca99b02fc4 100644 --- a/source/extensions/filters/http/cors/cors_filter.cc +++ b/source/extensions/filters/http/cors/cors_filter.cc @@ -70,10 +70,6 @@ Http::FilterHeadersStatus CorsFilter::decodeHeaders(Http::HeaderMap& headers, bo response_headers->insertAccessControlAllowHeaders().value(allowHeaders()); } - if (!exposeHeaders().empty()) { - response_headers->insertAccessControlExposeHeaders().value(exposeHeaders()); - } - if (!maxAge().empty()) { response_headers->insertAccessControlMaxAge().value(maxAge()); } @@ -95,14 +91,22 @@ Http::FilterHeadersStatus CorsFilter::encodeHeaders(Http::HeaderMap& headers, bo headers.insertAccessControlAllowCredentials().value(Http::Headers::get().CORSValues.True); } + if (!exposeHeaders().empty()) { + headers.insertAccessControlExposeHeaders().value(exposeHeaders()); + } + return Http::FilterHeadersStatus::Continue; } void CorsFilter::setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) { decoder_callbacks_ = &callbacks; -}; +} bool CorsFilter::isOriginAllowed(const Http::HeaderString& origin) { + return isOriginAllowedString(origin) || isOriginAllowedRegex(origin); +} + +bool CorsFilter::isOriginAllowedString(const Http::HeaderString& origin) { if (allowOrigins() == nullptr) { return false; } @@ -114,6 +118,18 @@ bool CorsFilter::isOriginAllowed(const Http::HeaderString& origin) { return false; } +bool CorsFilter::isOriginAllowedRegex(const Http::HeaderString& origin) { + if (allowOriginRegexes() == nullptr) { + return false; + } + for (const auto& regex : *allowOriginRegexes()) { + if (std::regex_match(origin.c_str(), regex)) { + return true; + } + } + return false; +} + const std::list* CorsFilter::allowOrigins() { for (const auto policy : policies_) { if (policy && !policy->allowOrigins().empty()) { @@ -123,6 +139,15 @@ const std::list* CorsFilter::allowOrigins() { return nullptr; } +const std::list* CorsFilter::allowOriginRegexes() { + for (const auto policy : policies_) { + if (policy && !policy->allowOriginRegexes().empty()) { + return &policy->allowOriginRegexes(); + } + } + return nullptr; +} + const std::string& CorsFilter::allowMethods() { for (const auto policy : policies_) { if (policy && !policy->allowMethods().empty()) { diff --git a/source/extensions/filters/http/cors/cors_filter.h b/source/extensions/filters/http/cors/cors_filter.h index 4ea6741208327..0d48abf5be07d 100644 --- a/source/extensions/filters/http/cors/cors_filter.h +++ b/source/extensions/filters/http/cors/cors_filter.h @@ -45,6 +45,7 @@ class CorsFilter : public Http::StreamFilter { friend class CorsFilterTest; const std::list* allowOrigins(); + const std::list* allowOriginRegexes(); const std::string& allowMethods(); const std::string& allowHeaders(); const std::string& exposeHeaders(); @@ -52,6 +53,8 @@ class CorsFilter : public Http::StreamFilter { bool allowCredentials(); bool enabled(); bool isOriginAllowed(const Http::HeaderString& origin); + bool isOriginAllowedString(const Http::HeaderString& origin); + bool isOriginAllowedRegex(const Http::HeaderString& origin); Http::StreamDecoderFilterCallbacks* decoder_callbacks_{}; Http::StreamEncoderFilterCallbacks* encoder_callbacks_{}; diff --git a/source/extensions/filters/http/dynamo/BUILD b/source/extensions/filters/http/dynamo/BUILD index 0e93b18a03178..e1ba977284758 100644 --- a/source/extensions/filters/http/dynamo/BUILD +++ b/source/extensions/filters/http/dynamo/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # AWS DynamoDB L7 HTTP filter (observability): https://aws.amazon.com/dynamodb/ # Public docs: docs/root/configuration/http_filters/dynamodb_filter.rst @@ -40,7 +41,10 @@ envoy_cc_library( name = "dynamo_utility_lib", srcs = ["dynamo_utility.cc"], hdrs = ["dynamo_utility.h"], - deps = ["//source/common/stats:stats_lib"], + deps = [ + "//include/envoy/stats:stats_interface", + "//source/common/stats:stats_lib", + ], ) envoy_cc_library( diff --git a/source/extensions/filters/http/dynamo/config.h b/source/extensions/filters/http/dynamo/config.h index ce75bebb150f0..d452985beb2f9 100644 --- a/source/extensions/filters/http/dynamo/config.h +++ b/source/extensions/filters/http/dynamo/config.h @@ -15,7 +15,7 @@ namespace Dynamo { */ class DynamoFilterConfig : public Common::EmptyHttpFilterConfig { public: - DynamoFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().DYNAMO) {} + DynamoFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().Dynamo) {} Http::FilterFactoryCb createFilter(const std::string& stat_prefix, Server::Configuration::FactoryContext& context) override; diff --git a/source/extensions/filters/http/dynamo/dynamo_filter.cc b/source/extensions/filters/http/dynamo/dynamo_filter.cc index 5a6bd6905f2c5..504fba6dce215 100644 --- a/source/extensions/filters/http/dynamo/dynamo_filter.cc +++ b/source/extensions/filters/http/dynamo/dynamo_filter.cc @@ -236,8 +236,9 @@ void DynamoFilter::chargeTablePartitionIdStats(const Json::Object& json_body) { std::vector partitions = RequestParser::parsePartitions(json_body); for (const RequestParser::PartitionDescriptor& partition : partitions) { - std::string scope_string = Utility::buildPartitionStatString( - stat_prefix_, table_descriptor_.table_name, operation_, partition.partition_id_); + std::string scope_string = + Utility::buildPartitionStatString(stat_prefix_, table_descriptor_.table_name, operation_, + partition.partition_id_, scope_.statsOptions()); scope_.counter(scope_string).add(partition.capacity_); } } diff --git a/source/extensions/filters/http/dynamo/dynamo_utility.cc b/source/extensions/filters/http/dynamo/dynamo_utility.cc index ceaa4085aafa1..aa9cf79f7585d 100644 --- a/source/extensions/filters/http/dynamo/dynamo_utility.cc +++ b/source/extensions/filters/http/dynamo/dynamo_utility.cc @@ -12,14 +12,15 @@ namespace Dynamo { std::string Utility::buildPartitionStatString(const std::string& stat_prefix, const std::string& table_name, const std::string& operation, - const std::string& partition_id) { + const std::string& partition_id, + const Stats::StatsOptions& stats_options) { // Use the last 7 characters of the partition id. std::string stats_partition_postfix = fmt::format(".capacity.{}.__partition_id={}", operation, partition_id.substr(partition_id.size() - 7, partition_id.size())); // Calculate how many characters are available for the table prefix. - size_t remaining_size = Stats::RawStatData::maxNameLength() - stats_partition_postfix.size(); + size_t remaining_size = stats_options.maxNameLength() - stats_partition_postfix.size(); std::string stats_table_prefix = fmt::format("{}table.{}", stat_prefix, table_name); // Truncate the table prefix if the current string is too large. diff --git a/source/extensions/filters/http/dynamo/dynamo_utility.h b/source/extensions/filters/http/dynamo/dynamo_utility.h index 87b4bd0bc119f..c1609290a1ac9 100644 --- a/source/extensions/filters/http/dynamo/dynamo_utility.h +++ b/source/extensions/filters/http/dynamo/dynamo_utility.h @@ -2,6 +2,8 @@ #include +#include "envoy/stats/stats.h" + namespace Envoy { namespace Extensions { namespace HttpFilters { @@ -23,7 +25,8 @@ class Utility { static std::string buildPartitionStatString(const std::string& stat_prefix, const std::string& table_name, const std::string& operation, - const std::string& partition_id); + const std::string& partition_id, + const Stats::StatsOptions& stats_options); }; } // namespace Dynamo diff --git a/source/extensions/filters/http/ext_authz/BUILD b/source/extensions/filters/http/ext_authz/BUILD index d9d352e16e050..5fbcd5b1b242a 100644 --- a/source/extensions/filters/http/ext_authz/BUILD +++ b/source/extensions/filters/http/ext_authz/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # External authorization L7 HTTP filter # Public docs: TODO(saumoh): Docs needed in docs/root/configuration/http_filters @@ -16,12 +17,14 @@ envoy_cc_library( hdrs = ["ext_authz.h"], deps = [ "//include/envoy/http:codes_interface", + "//source/common/buffer:buffer_lib", "//source/common/common:assert_lib", "//source/common/common:empty_string", "//source/common/common:enum_to_int", + "//source/common/common:minimal_logger_lib", "//source/common/http:codes_lib", "//source/common/router:config_lib", - "//source/extensions/filters/common/ext_authz:ext_authz_lib", + "//source/extensions/filters/common/ext_authz:ext_authz_grpc_lib", "@envoy_api//envoy/config/filter/http/ext_authz/v2alpha:ext_authz_cc", ], ) @@ -34,6 +37,7 @@ envoy_cc_library( ":ext_authz", "//include/envoy/registry", "//source/common/protobuf:utility_lib", + "//source/extensions/filters/common/ext_authz:ext_authz_http_lib", "//source/extensions/filters/http:well_known_names", "//source/extensions/filters/http/common:factory_base_lib", ], diff --git a/source/extensions/filters/http/ext_authz/config.cc b/source/extensions/filters/http/ext_authz/config.cc index a7585c3e22513..353a67fc872ff 100644 --- a/source/extensions/filters/http/ext_authz/config.cc +++ b/source/extensions/filters/http/ext_authz/config.cc @@ -8,7 +8,8 @@ #include "common/protobuf/utility.h" -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" +#include "extensions/filters/common/ext_authz/ext_authz_http_impl.h" #include "extensions/filters/http/ext_authz/ext_authz.h" namespace Envoy { @@ -19,14 +20,32 @@ namespace ExtAuthz { Http::FilterFactoryCb ExtAuthzFilterConfig::createFilterFactoryFromProtoTyped( const envoy::config::filter::http::ext_authz::v2alpha::ExtAuthz& proto_config, const std::string&, Server::Configuration::FactoryContext& context) { - auto filter_config = + + const auto filter_config = std::make_shared(proto_config, context.localInfo(), context.scope(), context.runtime(), context.clusterManager()); - const uint32_t timeout_ms = PROTOBUF_GET_MS_OR_DEFAULT(proto_config.grpc_service(), timeout, 200); + + if (proto_config.has_http_service()) { + const uint32_t timeout_ms = PROTOBUF_GET_MS_OR_DEFAULT(proto_config.http_service().server_uri(), + timeout, DefaultTimeout); + return [ + filter_config, timeout_ms, cluster_name = proto_config.http_service().server_uri().cluster(), + path_prefix = proto_config.http_service().path_prefix() + ](Http::FilterChainFactoryCallbacks & callbacks) { + auto client = std::make_unique( + cluster_name, filter_config->cm(), std::chrono::milliseconds(timeout_ms), path_prefix, + filter_config->responseHeadersToRemove()); + callbacks.addStreamDecoderFilter(Http::StreamDecoderFilterSharedPtr{ + std::make_shared(filter_config, std::move(client))}); + }; + } + + const uint32_t timeout_ms = + PROTOBUF_GET_MS_OR_DEFAULT(proto_config.grpc_service(), timeout, DefaultTimeout); return [ grpc_service = proto_config.grpc_service(), &context, filter_config, timeout_ms ](Http::FilterChainFactoryCallbacks & callbacks) { - auto async_client_factory = + const auto async_client_factory = context.clusterManager().grpcAsyncClientManager().factoryForGrpcService( grpc_service, context.scope(), true); auto client = std::make_unique( @@ -34,7 +53,7 @@ Http::FilterFactoryCb ExtAuthzFilterConfig::createFilterFactoryFromProtoTyped( callbacks.addStreamDecoderFilter(Http::StreamDecoderFilterSharedPtr{ std::make_shared(filter_config, std::move(client))}); }; -} +}; /** * Static registration for the external authorization filter. @see RegisterFactory. diff --git a/source/extensions/filters/http/ext_authz/config.h b/source/extensions/filters/http/ext_authz/config.h index aa8b36d9c169c..71f003c8fe976 100644 --- a/source/extensions/filters/http/ext_authz/config.h +++ b/source/extensions/filters/http/ext_authz/config.h @@ -16,9 +16,10 @@ namespace ExtAuthz { class ExtAuthzFilterConfig : public Common::FactoryBase { public: - ExtAuthzFilterConfig() : FactoryBase(HttpFilterNames::get().EXT_AUTHORIZATION) {} + ExtAuthzFilterConfig() : FactoryBase(HttpFilterNames::get().ExtAuthorization) {} private: + static constexpr uint64_t DefaultTimeout = 200; Http::FilterFactoryCb createFilterFactoryFromProtoTyped( const envoy::config::filter::http::ext_authz::v2alpha::ExtAuthz& proto_config, const std::string& stats_prefix, Server::Configuration::FactoryContext& context) override; diff --git a/source/extensions/filters/http/ext_authz/ext_authz.cc b/source/extensions/filters/http/ext_authz/ext_authz.cc index d67faa1e5d25d..20846dd34d2bf 100644 --- a/source/extensions/filters/http/ext_authz/ext_authz.cc +++ b/source/extensions/filters/http/ext_authz/ext_authz.cc @@ -1,19 +1,10 @@ #include "extensions/filters/http/ext_authz/ext_authz.h" -#include -#include - -#include "envoy/http/codes.h" - #include "common/common/assert.h" #include "common/common/enum_to_int.h" #include "common/http/codes.h" #include "common/router/config_impl.h" -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" - -#include "fmt/format.h" - namespace Envoy { namespace Extensions { namespace HttpFilters { @@ -39,11 +30,13 @@ void Filter::initiateCall(const Http::HeaderMap& headers) { // Don't let the filter chain continue as we are going to invoke check call. filter_return_ = FilterReturn::StopDecoding; initiating_call_ = true; + ENVOY_STREAM_LOG(trace, "Ext_authz filter calling authorization server", *callbacks_); client_->check(*this, check_request_, callbacks_->activeSpan()); initiating_call_ = false; } Http::FilterHeadersStatus Filter::decodeHeaders(Http::HeaderMap& headers, bool) { + request_headers_ = &headers; initiateCall(headers); return filter_return_ == FilterReturn::StopDecoding ? Http::FilterHeadersStatus::StopIteration : Http::FilterHeadersStatus::Continue; @@ -71,14 +64,13 @@ void Filter::onDestroy() { } } -void Filter::onComplete(Filters::Common::ExtAuthz::CheckStatus status) { +void Filter::onComplete(Filters::Common::ExtAuthz::ResponsePtr&& response) { ASSERT(cluster_); - state_ = State::Complete; using Filters::Common::ExtAuthz::CheckStatus; - switch (status) { + switch (response->status) { case CheckStatus::OK: cluster_->statsScope().counter("ext_authz.ok").inc(); break; @@ -91,7 +83,7 @@ void Filter::onComplete(Filters::Common::ExtAuthz::CheckStatus status) { Http::CodeUtility::ResponseStatInfo info{config_->scope(), cluster_->statsScope(), EMPTY_STRING, - enumToInt(Http::Code::Forbidden), + enumToInt(response->status_code), true, EMPTY_STRING, EMPTY_STRING, @@ -102,20 +94,57 @@ void Filter::onComplete(Filters::Common::ExtAuthz::CheckStatus status) { break; } + ENVOY_STREAM_LOG(trace, "Ext_authz received status code {}", *callbacks_, + enumToInt(response->status_code)); + // We fail open/fail close based of filter config // if there is an error contacting the service. - if (status == CheckStatus::Denied || - (status == CheckStatus::Error && !config_->failureModeAllow())) { - callbacks_->sendLocalReply(Http::Code::Forbidden, "", nullptr); + if (response->status == CheckStatus::Denied || + (response->status == CheckStatus::Error && !config_->failureModeAllow())) { + ENVOY_STREAM_LOG(debug, "Ext_authz rejected the request", *callbacks_); + ENVOY_STREAM_LOG(trace, "Ext_authz downstream header(s):", *callbacks_); + callbacks_->sendLocalReply(response->status_code, response->body, + [& headers = response->headers_to_add, + &callbacks = *callbacks_ ](Http::HeaderMap & response_headers) + ->void { + for (const auto& header : headers) { + response_headers.remove(header.first); + response_headers.addCopy(header.first, header.second); + ENVOY_STREAM_LOG(trace, " '{}':'{}'", callbacks, + header.first.get(), header.second); + } + }); callbacks_->requestInfo().setResponseFlag( RequestInfo::ResponseFlag::UnauthorizedExternalService); } else { + ENVOY_STREAM_LOG(debug, "Ext_authz accepted the request", *callbacks_); // Let the filter chain continue. filter_return_ = FilterReturn::ContinueDecoding; - if (config_->failureModeAllow() && status == CheckStatus::Error) { + if (config_->failureModeAllow() && response->status == CheckStatus::Error) { // Status is Error and yet we are allowing the request. Click a counter. cluster_->statsScope().counter("ext_authz.failure_mode_allowed").inc(); } + // Only send headers if the response is ok. + if (response->status == CheckStatus::OK) { + ENVOY_STREAM_LOG(trace, "Ext_authz upstream header(s):", *callbacks_); + for (const auto& header : response->headers_to_add) { + Http::HeaderEntry* header_to_modify = request_headers_->get(header.first); + if (header_to_modify) { + header_to_modify->value(header.second.c_str(), header.second.size()); + } else { + request_headers_->addCopy(header.first, header.second); + } + ENVOY_STREAM_LOG(trace, " '{}':'{}'", *callbacks_, header.first.get(), header.second); + } + for (const auto& header : response->headers_to_append) { + Http::HeaderEntry* header_to_modify = request_headers_->get(header.first); + if (header_to_modify) { + Http::HeaderMapImpl::appendToHeader(header_to_modify->value(), header.second); + ENVOY_STREAM_LOG(trace, " '{}':'{}'", *callbacks_, header.first.get(), header.second); + } + } + } + if (!initiating_call_) { // We got completion async. Let the filter chain continue. callbacks_->continueDecoding(); diff --git a/source/extensions/filters/http/ext_authz/ext_authz.h b/source/extensions/filters/http/ext_authz/ext_authz.h index 8220bf4a327c7..81dd788a155bf 100644 --- a/source/extensions/filters/http/ext_authz/ext_authz.h +++ b/source/extensions/filters/http/ext_authz/ext_authz.h @@ -12,10 +12,11 @@ #include "envoy/upstream/cluster_manager.h" #include "common/common/assert.h" +#include "common/common/logger.h" #include "common/http/header_map_impl.h" #include "extensions/filters/common/ext_authz/ext_authz.h" -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" namespace Envoy { namespace Extensions { @@ -37,6 +38,8 @@ class FilterConfig { Runtime::Loader& runtime, Upstream::ClusterManager& cm) : local_info_(local_info), scope_(scope), runtime_(runtime), cm_(cm), cluster_name_(config.grpc_service().envoy_grpc().cluster_name()), + response_headers_to_remove_(config.http_service().response_headers_to_remove().begin(), + config.http_service().response_headers_to_remove().end()), failure_mode_allow_(config.failure_mode_allow()) {} const LocalInfo::LocalInfo& localInfo() const { return local_info_; } @@ -44,6 +47,9 @@ class FilterConfig { Stats::Scope& scope() { return scope_; } std::string cluster() { return cluster_name_; } Upstream::ClusterManager& cm() { return cm_; } + const std::vector& responseHeadersToRemove() { + return response_headers_to_remove_; + } bool failureModeAllow() const { return failure_mode_allow_; } private: @@ -52,6 +58,7 @@ class FilterConfig { Runtime::Loader& runtime_; Upstream::ClusterManager& cm_; std::string cluster_name_; + std::vector response_headers_to_remove_; bool failure_mode_allow_; }; @@ -61,7 +68,8 @@ typedef std::shared_ptr FilterConfigSharedPtr; * HTTP ext_authz filter. Depending on the route configuration, this filter calls the global * ext_authz service before allowing further filter iteration. */ -class Filter : public Http::StreamDecoderFilter, +class Filter : public Logger::Loggable, + public Http::StreamDecoderFilter, public Filters::Common::ExtAuthz::RequestCallbacks { public: Filter(FilterConfigSharedPtr config, Filters::Common::ExtAuthz::ClientPtr&& client) @@ -77,9 +85,10 @@ class Filter : public Http::StreamDecoderFilter, void setDecoderFilterCallbacks(Http::StreamDecoderFilterCallbacks& callbacks) override; // ExtAuthz::RequestCallbacks - void onComplete(Filters::Common::ExtAuthz::CheckStatus status) override; + void onComplete(Filters::Common::ExtAuthz::ResponsePtr&&) override; private: + void addResponseHeaders(Http::HeaderMap& header_map, const Http::HeaderVector& headers); // State of this filter's communication with the external authorization service. // The filter has either not started calling the external service, in the middle of calling // it or has completed. @@ -89,10 +98,11 @@ class Filter : public Http::StreamDecoderFilter, // the filter chain should stop. Otherwise the filter chain can continue to the next filter. enum class FilterReturn { ContinueDecoding, StopDecoding }; void initiateCall(const Http::HeaderMap& headers); - + Http::HeaderMapPtr getHeaderMap(const Filters::Common::ExtAuthz::ResponsePtr& reponse); FilterConfigSharedPtr config_; Filters::Common::ExtAuthz::ClientPtr client_; Http::StreamDecoderFilterCallbacks* callbacks_{}; + Http::HeaderMap* request_headers_; State state_{State::NotStarted}; FilterReturn filter_return_{FilterReturn::ContinueDecoding}; Upstream::ClusterInfoConstSharedPtr cluster_; diff --git a/source/extensions/filters/http/fault/BUILD b/source/extensions/filters/http/fault/BUILD index 2ad7bba0b0e63..847e32b50e922 100644 --- a/source/extensions/filters/http/fault/BUILD +++ b/source/extensions/filters/http/fault/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # HTTP L7 filter that injects faults into the request flow # Public docs: docs/root/configuration/http_filters/fault_filter.rst diff --git a/source/extensions/filters/http/fault/config.h b/source/extensions/filters/http/fault/config.h index ae41ebc42261a..240c2f7dab520 100644 --- a/source/extensions/filters/http/fault/config.h +++ b/source/extensions/filters/http/fault/config.h @@ -16,7 +16,7 @@ namespace Fault { class FaultFilterFactory : public Common::FactoryBase { public: - FaultFilterFactory() : FactoryBase(HttpFilterNames::get().FAULT) {} + FaultFilterFactory() : FactoryBase(HttpFilterNames::get().Fault) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string& stats_prefix, diff --git a/source/extensions/filters/http/fault/fault_filter.cc b/source/extensions/filters/http/fault/fault_filter.cc index 9669b8b041d71..47bfbec25900b 100644 --- a/source/extensions/filters/http/fault/fault_filter.cc +++ b/source/extensions/filters/http/fault/fault_filter.cc @@ -76,7 +76,7 @@ Http::FilterHeadersStatus FaultFilter::decodeHeaders(Http::HeaderMap& headers, b // configured at the filter level. fault_settings_ = config_->settings(); if (callbacks_->route() && callbacks_->route()->routeEntry()) { - const std::string& name = Extensions::HttpFilters::HttpFilterNames::get().FAULT; + const std::string& name = Extensions::HttpFilters::HttpFilterNames::get().Fault; const auto* route_entry = callbacks_->route()->routeEntry(); const FaultSettings* per_route_settings_ = diff --git a/source/extensions/filters/http/grpc_http1_bridge/BUILD b/source/extensions/filters/http/grpc_http1_bridge/BUILD index 91325b4143daf..c7ceb6cfd4b11 100644 --- a/source/extensions/filters/http/grpc_http1_bridge/BUILD +++ b/source/extensions/filters/http/grpc_http1_bridge/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter that bridges HTTP/1.1 unary "gRPC" to compliant HTTP/2 gRPC. # Public docs: docs/root/configuration/http_filters/grpc_http1_bridge_filter.rst diff --git a/source/extensions/filters/http/grpc_http1_bridge/config.h b/source/extensions/filters/http/grpc_http1_bridge/config.h index 7edb9106adec1..b8a698f0d7852 100644 --- a/source/extensions/filters/http/grpc_http1_bridge/config.h +++ b/source/extensions/filters/http/grpc_http1_bridge/config.h @@ -16,7 +16,7 @@ namespace GrpcHttp1Bridge { class GrpcHttp1BridgeFilterConfig : public Common::EmptyHttpFilterConfig { public: GrpcHttp1BridgeFilterConfig() - : Common::EmptyHttpFilterConfig(HttpFilterNames::get().GRPC_HTTP1_BRIDGE) {} + : Common::EmptyHttpFilterConfig(HttpFilterNames::get().GrpcHttp1Bridge) {} Http::FilterFactoryCb createFilter(const std::string&, Server::Configuration::FactoryContext& context) override; diff --git a/source/extensions/filters/http/grpc_json_transcoder/BUILD b/source/extensions/filters/http/grpc_json_transcoder/BUILD index 5bf972b6ac8ea..3128edcf4c52b 100644 --- a/source/extensions/filters/http/grpc_json_transcoder/BUILD +++ b/source/extensions/filters/http/grpc_json_transcoder/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter that implements binary gRPC to JSON transcoding # Public docs: docs/root/configuration/http_filters/grpc_json_transcoder_filter.rst @@ -19,11 +20,11 @@ envoy_cc_library( "path_matcher", "grpc_transcoding", "http_api_protos", + "api_httpbody_protos", ], deps = [ ":transcoder_input_stream_lib", "//include/envoy/http:filter_interface", - "//source/common/common:base64_lib", "//source/common/grpc:codec_lib", "//source/common/grpc:common_lib", "//source/common/http:headers_lib", diff --git a/source/extensions/filters/http/grpc_json_transcoder/config.h b/source/extensions/filters/http/grpc_json_transcoder/config.h index 8c332d342a3d5..8b910725c6901 100644 --- a/source/extensions/filters/http/grpc_json_transcoder/config.h +++ b/source/extensions/filters/http/grpc_json_transcoder/config.h @@ -16,7 +16,7 @@ namespace GrpcJsonTranscoder { class GrpcJsonTranscoderFilterConfig : public Common::FactoryBase { public: - GrpcJsonTranscoderFilterConfig() : FactoryBase(HttpFilterNames::get().GRPC_JSON_TRANSCODER) {} + GrpcJsonTranscoderFilterConfig() : FactoryBase(HttpFilterNames::get().GrpcJsonTranscoder) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string& stats_prefix, diff --git a/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.cc b/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.cc index 4328fb2ab760c..4b13a9eee1283 100644 --- a/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.cc +++ b/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.cc @@ -14,6 +14,7 @@ #include "google/api/annotations.pb.h" #include "google/api/http.pb.h" +#include "google/api/httpbody.pb.h" #include "grpc_transcoding/json_request_translator.h" #include "grpc_transcoding/path_matcher_utility.h" #include "grpc_transcoding/response_to_json_translator.h" @@ -91,7 +92,7 @@ JsonTranscoderConfig::JsonTranscoderConfig( } break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } for (const auto& file : descriptor_set.file()) { @@ -222,6 +223,7 @@ Http::FilterHeadersStatus JsonTranscoderFilter::decodeHeaders(Http::HeaderMap& h // just pass-through the request to upstream. return Http::FilterHeadersStatus::Continue; } + has_http_body_output_ = !method_->server_streaming() && hasHttpBodyAsOutputType(); headers.removeContentLength(); headers.insertContentType().value().setReference(Http::Headers::get().ContentTypeValues.Grpc); @@ -340,6 +342,12 @@ Http::FilterDataStatus JsonTranscoderFilter::encodeData(Buffer::Instance& data, return Http::FilterDataStatus::Continue; } + // TODO(dio): Add support for streaming case. + if (has_http_body_output_) { + buildResponseFromHttpBodyOutput(*response_headers_, data); + return Http::FilterDataStatus::StopIterationAndBuffer; + } + response_in_.move(data); if (end_stream) { @@ -415,6 +423,36 @@ bool JsonTranscoderFilter::readToBuffer(Protobuf::io::ZeroCopyInputStream& strea return false; } +void JsonTranscoderFilter::buildResponseFromHttpBodyOutput(Http::HeaderMap& response_headers, + Buffer::Instance& data) { + std::vector frames; + decoder_.decode(data, frames); + if (frames.empty()) { + return; + } + + google::api::HttpBody http_body; + for (auto& frame : frames) { + if (frame.length_ > 0) { + Buffer::ZeroCopyInputStreamImpl stream(std::move(frame.data_)); + http_body.ParseFromZeroCopyStream(&stream); + const auto& body = http_body.data(); + + // TODO(mrice32): This string conversion is currently required because body has a different + // type within Google. Remove when the string types merge. + data.add(ProtobufTypes::String(body)); + + response_headers.insertContentType().value(http_body.content_type()); + response_headers.insertContentLength().value(body.size()); + return; + } + } +} + +bool JsonTranscoderFilter::hasHttpBodyAsOutputType() { + return method_->output_type()->full_name() == google::api::HttpBody::descriptor()->full_name(); +} + } // namespace GrpcJsonTranscoder } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.h b/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.h index e8544ea2600ff..5849bcc96b2fb 100644 --- a/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.h +++ b/source/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter.h @@ -7,6 +7,7 @@ #include "envoy/json/json_object.h" #include "common/common/logger.h" +#include "common/grpc/codec.h" #include "common/protobuf/protobuf.h" #include "extensions/filters/http/grpc_json_transcoder/transcoder_input_stream_impl.h" @@ -117,6 +118,8 @@ class JsonTranscoderFilter : public Http::StreamFilter, public Logger::Loggable< private: bool readToBuffer(Protobuf::io::ZeroCopyInputStream& stream, Buffer::Instance& data); + void buildResponseFromHttpBodyOutput(Http::HeaderMap& response_headers, Buffer::Instance& data); + bool hasHttpBodyAsOutputType(); JsonTranscoderConfig& config_; std::unique_ptr transcoder_; @@ -126,8 +129,10 @@ class JsonTranscoderFilter : public Http::StreamFilter, public Logger::Loggable< Http::StreamEncoderFilterCallbacks* encoder_callbacks_{nullptr}; const Protobuf::MethodDescriptor* method_{nullptr}; Http::HeaderMap* response_headers_{nullptr}; + Grpc::Decoder decoder_; bool error_{false}; + bool has_http_body_output_{false}; }; } // namespace GrpcJsonTranscoder diff --git a/source/extensions/filters/http/grpc_web/BUILD b/source/extensions/filters/http/grpc_web/BUILD index e9552b2bc70fb..28d7317948e0c 100644 --- a/source/extensions/filters/http/grpc_web/BUILD +++ b/source/extensions/filters/http/grpc_web/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter that implements the grpc-web protocol (https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-WEB.md) # Public docs: docs/root/configuration/http_filters/grpc_web_filter.rst diff --git a/source/extensions/filters/http/grpc_web/config.h b/source/extensions/filters/http/grpc_web/config.h index 98fabcf17c815..a1402b15fce54 100644 --- a/source/extensions/filters/http/grpc_web/config.h +++ b/source/extensions/filters/http/grpc_web/config.h @@ -12,7 +12,7 @@ namespace GrpcWeb { class GrpcWebFilterConfig : public Common::EmptyHttpFilterConfig { public: - GrpcWebFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().GRPC_WEB) {} + GrpcWebFilterConfig() : Common::EmptyHttpFilterConfig(HttpFilterNames::get().GrpcWeb) {} Http::FilterFactoryCb createFilter(const std::string&, Server::Configuration::FactoryContext& context) override; diff --git a/source/extensions/filters/http/gzip/BUILD b/source/extensions/filters/http/gzip/BUILD index a31a6f93fee31..91efd59658911 100644 --- a/source/extensions/filters/http/gzip/BUILD +++ b/source/extensions/filters/http/gzip/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # HTTP L7 filter that performs gzip compression # Public docs: docs/root/configuration/http_filters/gzip_filter.rst diff --git a/source/extensions/filters/http/gzip/config.h b/source/extensions/filters/http/gzip/config.h index 06e2e70a5ae69..8dd48de622d36 100644 --- a/source/extensions/filters/http/gzip/config.h +++ b/source/extensions/filters/http/gzip/config.h @@ -15,7 +15,7 @@ namespace Gzip { */ class GzipFilterFactory : public Common::FactoryBase { public: - GzipFilterFactory() : FactoryBase(HttpFilterNames::get().ENVOY_GZIP) {} + GzipFilterFactory() : FactoryBase(HttpFilterNames::get().EnvoyGzip) {} private: Http::FilterFactoryCb diff --git a/source/extensions/filters/http/header_to_metadata/BUILD b/source/extensions/filters/http/header_to_metadata/BUILD index c32b2255b8f38..b31109c6333c4 100644 --- a/source/extensions/filters/http/header_to_metadata/BUILD +++ b/source/extensions/filters/http/header_to_metadata/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # HTTP L7 filter that transforms request data into dynamic metadata # Public docs: docs/root/configuration/http_filters/header_to_metadata_filter.rst diff --git a/source/extensions/filters/http/header_to_metadata/config.h b/source/extensions/filters/http/header_to_metadata/config.h index bbc548376bc70..11ced03b1cbf3 100644 --- a/source/extensions/filters/http/header_to_metadata/config.h +++ b/source/extensions/filters/http/header_to_metadata/config.h @@ -16,7 +16,7 @@ namespace HeaderToMetadataFilter { class HeaderToMetadataConfig : public Common::FactoryBase { public: - HeaderToMetadataConfig() : FactoryBase(HttpFilterNames::get().HEADER_TO_METADATA) {} + HeaderToMetadataConfig() : FactoryBase(HttpFilterNames::get().HeaderToMetadata) {} private: Http::FilterFactoryCb createFilterFactoryFromProtoTyped( diff --git a/source/extensions/filters/http/header_to_metadata/header_to_metadata_filter.cc b/source/extensions/filters/http/header_to_metadata/header_to_metadata_filter.cc index 83f7a356cb94a..722c0d16a59da 100644 --- a/source/extensions/filters/http/header_to_metadata/header_to_metadata_filter.cc +++ b/source/extensions/filters/http/header_to_metadata/header_to_metadata_filter.cc @@ -132,7 +132,7 @@ bool HeaderToMetadataFilter::addMetadata(StructMap& map, const std::string& meta } const std::string& HeaderToMetadataFilter::decideNamespace(const std::string& nspace) const { - return nspace.empty() ? HttpFilterNames::get().HEADER_TO_METADATA : nspace; + return nspace.empty() ? HttpFilterNames::get().HeaderToMetadata : nspace; } void HeaderToMetadataFilter::writeHeaderToMetadata(Http::HeaderMap& headers, diff --git a/source/extensions/filters/http/health_check/BUILD b/source/extensions/filters/http/health_check/BUILD index d8cc758715b32..6ecc7f75df0a5 100644 --- a/source/extensions/filters/http/health_check/BUILD +++ b/source/extensions/filters/http/health_check/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter that implements health check responses # Public docs: docs/root/configuration/http_filters/health_check_filter.rst diff --git a/source/extensions/filters/http/health_check/config.cc b/source/extensions/filters/http/health_check/config.cc index 04e28992198f2..56b682dafa382 100644 --- a/source/extensions/filters/http/health_check/config.cc +++ b/source/extensions/filters/http/health_check/config.cc @@ -20,25 +20,12 @@ Http::FilterFactoryCb HealthCheckFilterConfig::createFilterFactoryFromProtoTyped const bool pass_through_mode = proto_config.pass_through_mode().value(); const int64_t cache_time_ms = PROTOBUF_GET_MS_OR_DEFAULT(proto_config, cache_time, 0); - const std::string hc_endpoint = proto_config.endpoint(); auto header_match_data = std::make_shared>(); - // TODO(mrice32): remove endpoint field at the end of the 1.7.0 deprecation cycle. - const bool endpoint_set = !proto_config.endpoint().empty(); - if (endpoint_set) { - envoy::api::v2::route::HeaderMatcher matcher; - matcher.set_name(Http::Headers::get().Path.get()); - matcher.set_exact_match(proto_config.endpoint()); - header_match_data->emplace_back(matcher); - } - for (const envoy::api::v2::route::HeaderMatcher& matcher : proto_config.headers()) { Http::HeaderUtility::HeaderData single_header_match(matcher); - // Ignore any path header matchers if the endpoint field has been set. - if (!(endpoint_set && single_header_match.name_ == Http::Headers::get().Path)) { - header_match_data->push_back(std::move(single_header_match)); - } + header_match_data->push_back(std::move(single_header_match)); } if (!pass_through_mode && cache_time_ms) { diff --git a/source/extensions/filters/http/health_check/config.h b/source/extensions/filters/http/health_check/config.h index 71bd8c91baad7..d7763d6f0e96c 100644 --- a/source/extensions/filters/http/health_check/config.h +++ b/source/extensions/filters/http/health_check/config.h @@ -13,7 +13,7 @@ namespace HealthCheck { class HealthCheckFilterConfig : public Common::FactoryBase { public: - HealthCheckFilterConfig() : FactoryBase(HttpFilterNames::get().HEALTH_CHECK) {} + HealthCheckFilterConfig() : FactoryBase(HttpFilterNames::get().HealthCheck) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string&, diff --git a/source/extensions/filters/http/ip_tagging/BUILD b/source/extensions/filters/http/ip_tagging/BUILD index 12e72155c14f1..583893eadb238 100644 --- a/source/extensions/filters/http/ip_tagging/BUILD +++ b/source/extensions/filters/http/ip_tagging/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # HTTP L7 filter that writes an IP tagging header based on IP trie data # Public docs: docs/root/configuration/http_filters/ip_tagging_filter.rst diff --git a/source/extensions/filters/http/ip_tagging/config.h b/source/extensions/filters/http/ip_tagging/config.h index 98bd55cb6304d..01da576853bc9 100644 --- a/source/extensions/filters/http/ip_tagging/config.h +++ b/source/extensions/filters/http/ip_tagging/config.h @@ -16,7 +16,7 @@ namespace IpTagging { class IpTaggingFilterFactory : public Common::FactoryBase { public: - IpTaggingFilterFactory() : FactoryBase(HttpFilterNames::get().IP_TAGGING) {} + IpTaggingFilterFactory() : FactoryBase(HttpFilterNames::get().IpTagging) {} private: Http::FilterFactoryCb createFilterFactoryFromProtoTyped( diff --git a/source/extensions/filters/http/ip_tagging/ip_tagging_filter.cc b/source/extensions/filters/http/ip_tagging/ip_tagging_filter.cc index 7ebe7949b6d01..22be63eeb4527 100644 --- a/source/extensions/filters/http/ip_tagging/ip_tagging_filter.cc +++ b/source/extensions/filters/http/ip_tagging/ip_tagging_filter.cc @@ -28,7 +28,7 @@ Http::FilterHeadersStatus IpTaggingFilter::decodeHeaders(Http::HeaderMap& header } std::vector tags = - config_->trie().getTags(callbacks_->requestInfo().downstreamRemoteAddress()); + config_->trie().getData(callbacks_->requestInfo().downstreamRemoteAddress()); if (!tags.empty()) { const std::string tags_join = absl::StrJoin(tags, ","); diff --git a/source/extensions/filters/http/ip_tagging/ip_tagging_filter.h b/source/extensions/filters/http/ip_tagging/ip_tagging_filter.h index bbc72ef18a1cc..ffe88c81a12e9 100644 --- a/source/extensions/filters/http/ip_tagging/ip_tagging_filter.h +++ b/source/extensions/filters/http/ip_tagging/ip_tagging_filter.h @@ -66,13 +66,13 @@ class IpTaggingFilterConfig { ip_tag_pair.second = cidr_set; tag_data.emplace_back(ip_tag_pair); } - trie_.reset(new Network::LcTrie::LcTrie(tag_data)); + trie_.reset(new Network::LcTrie::LcTrie(tag_data)); } Runtime::Loader& runtime() { return runtime_; } Stats::Scope& scope() { return scope_; } FilterRequestType requestType() const { return request_type_; } - const Network::LcTrie::LcTrie& trie() const { return *trie_; } + const Network::LcTrie::LcTrie& trie() const { return *trie_; } const std::string& statsPrefix() const { return stats_prefix_; } private: @@ -86,7 +86,7 @@ class IpTaggingFilterConfig { case envoy::config::filter::http::ip_tagging::v2::IPTagging_RequestType_EXTERNAL: return FilterRequestType::EXTERNAL; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -94,7 +94,7 @@ class IpTaggingFilterConfig { Stats::Scope& scope_; Runtime::Loader& runtime_; const std::string stats_prefix_; - std::unique_ptr trie_; + std::unique_ptr> trie_; }; typedef std::shared_ptr IpTaggingFilterConfigSharedPtr; diff --git a/source/extensions/filters/http/jwt_authn/filter_factory.h b/source/extensions/filters/http/jwt_authn/filter_factory.h index 67f653bf053a3..6bdee97bae503 100644 --- a/source/extensions/filters/http/jwt_authn/filter_factory.h +++ b/source/extensions/filters/http/jwt_authn/filter_factory.h @@ -17,7 +17,7 @@ namespace JwtAuthn { class FilterFactory : public Common::FactoryBase< ::envoy::config::filter::http::jwt_authn::v2alpha::JwtAuthentication> { public: - FilterFactory() : FactoryBase(HttpFilterNames::get().JWT_AUTHN) {} + FilterFactory() : FactoryBase(HttpFilterNames::get().JwtAuthn) {} private: Http::FilterFactoryCb createFilterFactoryFromProtoTyped( diff --git a/source/extensions/filters/http/jwt_authn/jwks_cache.h b/source/extensions/filters/http/jwt_authn/jwks_cache.h index d0684d7aaac90..a36b081f52fad 100644 --- a/source/extensions/filters/http/jwt_authn/jwks_cache.h +++ b/source/extensions/filters/http/jwt_authn/jwks_cache.h @@ -22,7 +22,7 @@ typedef std::unique_ptr JwksCachePtr; * * // for a given jwt * auto jwks_data = jwks_cache->findByIssuer(jwt->getIssuer()); - * if (!jwks_data->isAudidenceAllowed(jwt->getAudiences())) reject; + * if (!jwks_data->areAudiencesAllowed(jwt->getAudiences())) reject; * * if (jwks_data->getJwksObj() == nullptr || jwks_data->isExpired()) { * // Fetch remote Jwks. diff --git a/source/extensions/filters/http/lua/BUILD b/source/extensions/filters/http/lua/BUILD index bc306c23731d5..fbb03ed551fef 100644 --- a/source/extensions/filters/http/lua/BUILD +++ b/source/extensions/filters/http/lua/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Lua scripting L7 HTTP filter (https://www.lua.org/, http://luajit.org/) # Public docs: docs/root/configuration/http_filters/lua_filter.rst @@ -34,7 +35,10 @@ envoy_cc_library( hdrs = ["wrappers.h"], deps = [ "//include/envoy/http:header_map_interface", + "//include/envoy/request_info:request_info_interface", + "//source/common/http:utility_lib", "//source/extensions/filters/common/lua:lua_lib", + "//source/extensions/filters/common/lua:wrappers_lib", ], ) diff --git a/source/extensions/filters/http/lua/config.h b/source/extensions/filters/http/lua/config.h index 82d9ec287fbda..6921fc6b349a0 100644 --- a/source/extensions/filters/http/lua/config.h +++ b/source/extensions/filters/http/lua/config.h @@ -15,7 +15,7 @@ namespace Lua { */ class LuaFilterConfig : public Common::FactoryBase { public: - LuaFilterConfig() : FactoryBase(HttpFilterNames::get().LUA) {} + LuaFilterConfig() : FactoryBase(HttpFilterNames::get().Lua) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string& stats_prefix, diff --git a/source/extensions/filters/http/lua/lua_filter.cc b/source/extensions/filters/http/lua/lua_filter.cc index 73c7a5ff2f010..fe6b8ab286bb3 100644 --- a/source/extensions/filters/http/lua/lua_filter.cc +++ b/source/extensions/filters/http/lua/lua_filter.cc @@ -12,10 +12,10 @@ namespace Extensions { namespace HttpFilters { namespace Lua { -StreamHandleWrapper::StreamHandleWrapper(Filters::Common::Lua::CoroutinePtr&& coroutine, +StreamHandleWrapper::StreamHandleWrapper(Filters::Common::Lua::Coroutine& coroutine, Http::HeaderMap& headers, bool end_stream, Filter& filter, FilterCallbacks& callbacks) - : coroutine_(std::move(coroutine)), headers_(headers), end_stream_(end_stream), filter_(filter), + : coroutine_(coroutine), headers_(headers), end_stream_(end_stream), filter_(filter), callbacks_(callbacks), yield_callback_([this]() { if (state_ == State::Running) { throw Filters::Common::Lua::LuaException("script performed an unexpected yield"); @@ -24,7 +24,7 @@ StreamHandleWrapper::StreamHandleWrapper(Filters::Common::Lua::CoroutinePtr&& co Http::FilterHeadersStatus StreamHandleWrapper::start(int function_ref) { // We are on the top of the stack. - coroutine_->start(function_ref, 1, yield_callback_); + coroutine_.start(function_ref, 1, yield_callback_); Http::FilterHeadersStatus status = (state_ == State::WaitForBody || state_ == State::HttpCall || state_ == State::Responded) ? Http::FilterHeadersStatus::StopIteration @@ -45,18 +45,18 @@ Http::FilterDataStatus StreamHandleWrapper::onData(Buffer::Instance& data, bool if (state_ == State::WaitForBodyChunk) { ENVOY_LOG(trace, "resuming for next body chunk"); Filters::Common::Lua::LuaDeathRef wrapper( - Filters::Common::Lua::BufferWrapper::create(coroutine_->luaState(), data), true); + Filters::Common::Lua::BufferWrapper::create(coroutine_.luaState(), data), true); state_ = State::Running; - coroutine_->resume(1, yield_callback_); + coroutine_.resume(1, yield_callback_); } else if (state_ == State::WaitForBody && end_stream_) { ENVOY_LOG(debug, "resuming body due to end stream"); callbacks_.addData(data); state_ = State::Running; - coroutine_->resume(luaBody(coroutine_->luaState()), yield_callback_); + coroutine_.resume(luaBody(coroutine_.luaState()), yield_callback_); } else if (state_ == State::WaitForTrailers && end_stream_) { ENVOY_LOG(debug, "resuming nil trailers due to end stream"); state_ = State::Running; - coroutine_->resume(0, yield_callback_); + coroutine_.resume(0, yield_callback_); } if (state_ == State::HttpCall || state_ == State::WaitForBody) { @@ -78,17 +78,17 @@ Http::FilterTrailersStatus StreamHandleWrapper::onTrailers(Http::HeaderMap& trai if (state_ == State::WaitForBodyChunk) { ENVOY_LOG(debug, "resuming nil body chunk due to trailers"); state_ = State::Running; - coroutine_->resume(0, yield_callback_); + coroutine_.resume(0, yield_callback_); } else if (state_ == State::WaitForBody) { ENVOY_LOG(debug, "resuming body due to trailers"); state_ = State::Running; - coroutine_->resume(luaBody(coroutine_->luaState()), yield_callback_); + coroutine_.resume(luaBody(coroutine_.luaState()), yield_callback_); } if (state_ == State::WaitForTrailers) { // Mimic a call to trailers which will push the trailers onto the stack and then resume. state_ = State::Running; - coroutine_->resume(luaTrailers(coroutine_->luaState()), yield_callback_); + coroutine_.resume(luaTrailers(coroutine_.luaState()), yield_callback_); } Http::FilterTrailersStatus status = (state_ == State::HttpCall || state_ == State::Responded) @@ -206,7 +206,7 @@ void StreamHandleWrapper::onSuccess(Http::MessagePtr&& response) { http_request_ = nullptr; // We need to build a table with the headers as return param 1. The body will be return param 2. - lua_newtable(coroutine_->luaState()); + lua_newtable(coroutine_.luaState()); response->headers().iterate( [](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate { lua_State* state = static_cast(context); @@ -215,13 +215,13 @@ void StreamHandleWrapper::onSuccess(Http::MessagePtr&& response) { lua_settable(state, -3); return Http::HeaderMap::Iterate::Continue; }, - coroutine_->luaState()); + coroutine_.luaState()); // TODO(mattklein123): Avoid double copy here. if (response->body() != nullptr) { - lua_pushstring(coroutine_->luaState(), response->bodyAsString().c_str()); + lua_pushstring(coroutine_.luaState(), response->bodyAsString().c_str()); } else { - lua_pushnil(coroutine_->luaState()); + lua_pushnil(coroutine_.luaState()); } // In the immediate failure case, we are just going to immediately return to the script. We @@ -231,7 +231,7 @@ void StreamHandleWrapper::onSuccess(Http::MessagePtr&& response) { markLive(); try { - coroutine_->resume(2, yield_callback_); + coroutine_.resume(2, yield_callback_); markDead(); } catch (const Filters::Common::Lua::LuaException& e) { filter_.scriptError(e); @@ -367,6 +367,27 @@ int StreamHandleWrapper::luaMetadata(lua_State* state) { return 1; } +int StreamHandleWrapper::luaRequestInfo(lua_State* state) { + ASSERT(state_ == State::Running); + if (request_info_wrapper_.get() != nullptr) { + request_info_wrapper_.pushStack(); + } else { + request_info_wrapper_.reset(RequestInfoWrapper::create(state, callbacks_.requestInfo()), true); + } + return 1; +} + +int StreamHandleWrapper::luaConnection(lua_State* state) { + ASSERT(state_ == State::Running); + if (connection_wrapper_.get() != nullptr) { + connection_wrapper_.pushStack(); + } else { + connection_wrapper_.reset( + Filters::Common::Lua::ConnectionWrapper::create(state, callbacks_.connection()), true); + } + return 1; +} + int StreamHandleWrapper::luaLogTrace(lua_State* state) { const char* message = luaL_checkstring(state, 2); filter_.scriptLog(spdlog::level::trace, message); @@ -409,8 +430,13 @@ FilterConfig::FilterConfig(const std::string& lua_code, ThreadLocal::SlotAllocat lua_state_.registerType(); lua_state_.registerType(); lua_state_.registerType(); + lua_state_.registerType(); + lua_state_.registerType(); lua_state_.registerType(); lua_state_.registerType(); + lua_state_.registerType(); + lua_state_.registerType(); + lua_state_.registerType(); lua_state_.registerType(); request_function_slot_ = lua_state_.registerGlobal("envoy_on_request"); @@ -434,16 +460,17 @@ void Filter::onDestroy() { } } -Http::FilterHeadersStatus Filter::doHeaders(StreamHandleRef& handle, FilterCallbacks& callbacks, - int function_ref, Http::HeaderMap& headers, - bool end_stream) { +Http::FilterHeadersStatus Filter::doHeaders(StreamHandleRef& handle, + Filters::Common::Lua::CoroutinePtr& coroutine, + FilterCallbacks& callbacks, int function_ref, + Http::HeaderMap& headers, bool end_stream) { if (function_ref == LUA_REFNIL) { return Http::FilterHeadersStatus::Continue; } - Filters::Common::Lua::CoroutinePtr coroutine = config_->createCoroutine(); - handle.reset(StreamHandleWrapper::create(coroutine->luaState(), std::move(coroutine), headers, - end_stream, *this, callbacks), + coroutine = config_->createCoroutine(); + handle.reset(StreamHandleWrapper::create(coroutine->luaState(), *coroutine, headers, end_stream, + *this, callbacks), true); Http::FilterHeadersStatus status = Http::FilterHeadersStatus::Continue; @@ -515,7 +542,7 @@ void Filter::scriptLog(spdlog::level::level_enum level, const char* message) { ENVOY_LOG(critical, "script log: {}", message); return; case spdlog::level::off: - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } diff --git a/source/extensions/filters/http/lua/lua_filter.h b/source/extensions/filters/http/lua/lua_filter.h index 868229686737d..bd306801606ae 100644 --- a/source/extensions/filters/http/lua/lua_filter.h +++ b/source/extensions/filters/http/lua/lua_filter.h @@ -18,7 +18,7 @@ const ProtobufWkt::Struct& getMetadata(Http::StreamFilterCallbacks* callbacks) { return ProtobufWkt::Struct::default_instance(); } const auto& metadata = callbacks->route()->routeEntry()->metadata(); - const auto& filter_it = metadata.filter_metadata().find(HttpFilterNames::get().LUA); + const auto& filter_it = metadata.filter_metadata().find(HttpFilterNames::get().Lua); if (filter_it == metadata.filter_metadata().end()) { return ProtobufWkt::Struct::default_instance(); } @@ -68,6 +68,17 @@ class FilterCallbacks { * route entry. */ virtual const ProtobufWkt::Struct& metadata() const PURE; + + /** + * @return RequestInfo::RequestInfo& the current request info handle. This handle is mutable to + * accomodate write API e.g. setDynamicMetadata(). + */ + virtual RequestInfo::RequestInfo& requestInfo() PURE; + + /** + * @return const const Network::Connection* the current network connection handle. + */ + virtual const Network::Connection* connection() const PURE; }; class Filter; @@ -101,7 +112,7 @@ class StreamHandleWrapper : public Filters::Common::Lua::BaseLuaObject body_wrapper_; Filters::Common::Lua::LuaDeathRef trailers_wrapper_; Filters::Common::Lua::LuaDeathRef metadata_wrapper_; + Filters::Common::Lua::LuaDeathRef request_info_wrapper_; + Filters::Common::Lua::LuaDeathRef connection_wrapper_; State state_{State::Running}; std::function yield_callback_; Http::AsyncClient::Request* http_request_{}; @@ -236,6 +262,8 @@ class FilterConfig : Logger::Loggable { Filters::Common::Lua::CoroutinePtr createCoroutine() { return lua_state_.createCoroutine(); } int requestFunctionRef() { return lua_state_.getGlobalRef(request_function_slot_); } int responseFunctionRef() { return lua_state_.getGlobalRef(response_function_slot_); } + uint64_t runtimeBytesUsed() { return lua_state_.runtimeBytesUsed(); } + void runtimeGC() { return lua_state_.runtimeGC(); } Upstream::ClusterManager& cluster_manager_; @@ -265,8 +293,8 @@ class Filter : public Http::StreamFilter, Logger::Loggable { // Http::StreamDecoderFilter Http::FilterHeadersStatus decodeHeaders(Http::HeaderMap& headers, bool end_stream) override { - return doHeaders(request_stream_wrapper_, decoder_callbacks_, config_->requestFunctionRef(), - headers, end_stream); + return doHeaders(request_stream_wrapper_, request_coroutine_, decoder_callbacks_, + config_->requestFunctionRef(), headers, end_stream); } Http::FilterDataStatus decodeData(Buffer::Instance& data, bool end_stream) override { return doData(request_stream_wrapper_, data, end_stream); @@ -283,8 +311,8 @@ class Filter : public Http::StreamFilter, Logger::Loggable { return Http::FilterHeadersStatus::Continue; } Http::FilterHeadersStatus encodeHeaders(Http::HeaderMap& headers, bool end_stream) override { - return doHeaders(response_stream_wrapper_, encoder_callbacks_, config_->responseFunctionRef(), - headers, end_stream); + return doHeaders(response_stream_wrapper_, response_coroutine_, encoder_callbacks_, + config_->responseFunctionRef(), headers, end_stream); } Http::FilterDataStatus encodeData(Buffer::Instance& data, bool end_stream) override { return doData(response_stream_wrapper_, data, end_stream); @@ -310,6 +338,8 @@ class Filter : public Http::StreamFilter, Logger::Loggable { void respond(Http::HeaderMapPtr&& headers, Buffer::Instance* body, lua_State* state) override; const ProtobufWkt::Struct& metadata() const override { return getMetadata(callbacks_); } + RequestInfo::RequestInfo& requestInfo() override { return callbacks_->requestInfo(); } + const Network::Connection* connection() const override { return callbacks_->connection(); } Filter& parent_; Http::StreamDecoderFilterCallbacks* callbacks_{}; @@ -328,6 +358,8 @@ class Filter : public Http::StreamFilter, Logger::Loggable { void respond(Http::HeaderMapPtr&& headers, Buffer::Instance* body, lua_State* state) override; const ProtobufWkt::Struct& metadata() const override { return getMetadata(callbacks_); } + RequestInfo::RequestInfo& requestInfo() override { return callbacks_->requestInfo(); } + const Network::Connection* connection() const override { return callbacks_->connection(); } Filter& parent_; Http::StreamEncoderFilterCallbacks* callbacks_{}; @@ -335,8 +367,10 @@ class Filter : public Http::StreamFilter, Logger::Loggable { typedef Filters::Common::Lua::LuaDeathRef StreamHandleRef; - Http::FilterHeadersStatus doHeaders(StreamHandleRef& handle, FilterCallbacks& callbacks, - int function_ref, Http::HeaderMap& headers, bool end_stream); + Http::FilterHeadersStatus doHeaders(StreamHandleRef& handle, + Filters::Common::Lua::CoroutinePtr& coroutine, + FilterCallbacks& callbacks, int function_ref, + Http::HeaderMap& headers, bool end_stream); Http::FilterDataStatus doData(StreamHandleRef& handle, Buffer::Instance& data, bool end_stream); Http::FilterTrailersStatus doTrailers(StreamHandleRef& handle, Http::HeaderMap& trailers); @@ -346,6 +380,21 @@ class Filter : public Http::StreamFilter, Logger::Loggable { StreamHandleRef request_stream_wrapper_; StreamHandleRef response_stream_wrapper_; bool destroyed_{}; + + // These coroutines used to be owned by the stream handles. After investigating #3570, it + // became clear that there is a circular memory reference when a coroutine yields. Basically, + // the coroutine holds a reference to the stream wrapper. I'm not completely sure why this is, + // but I think it is because the yield happens via a stream handle method, so the runtime must + // hold a reference so that it can return out of the yield through the object. So now we hold + // the coroutine references at the same level as the stream handles so that when the filter is + // destroyed the circular reference is broken and both objects are cleaned up. + // + // Note that the above explanation probably means that we don't need to hold a reference to the + // coroutine at all and it would be taken care of automatically via a runtime internal reference + // when a yield happens. However, given that I don't fully understand the runtime internals, this + // seems like a safer fix for now. + Filters::Common::Lua::CoroutinePtr request_coroutine_; + Filters::Common::Lua::CoroutinePtr response_coroutine_; }; } // namespace Lua diff --git a/source/extensions/filters/http/lua/wrappers.cc b/source/extensions/filters/http/lua/wrappers.cc index 642a444a612d2..62924f4f4fa0b 100644 --- a/source/extensions/filters/http/lua/wrappers.cc +++ b/source/extensions/filters/http/lua/wrappers.cc @@ -1,5 +1,9 @@ #include "extensions/filters/http/lua/wrappers.h" +#include "common/http/utility.h" + +#include "extensions/filters/common/lua/wrappers.h" + namespace Envoy { namespace Extensions { namespace HttpFilters { @@ -100,6 +104,74 @@ void HeaderMapWrapper::checkModifiable(lua_State* state) { } } +int RequestInfoWrapper::luaProtocol(lua_State* state) { + lua_pushstring(state, Http::Utility::getProtocolString(request_info_.protocol().value()).c_str()); + return 1; +} + +int RequestInfoWrapper::luaDynamicMetadata(lua_State* state) { + if (dynamic_metadata_wrapper_.get() != nullptr) { + dynamic_metadata_wrapper_.pushStack(); + } else { + dynamic_metadata_wrapper_.reset(DynamicMetadataMapWrapper::create(state, *this), true); + } + return 1; +} + +DynamicMetadataMapIterator::DynamicMetadataMapIterator(DynamicMetadataMapWrapper& parent) + : parent_{parent}, current_{parent_.requestInfo().dynamicMetadata().filter_metadata().begin()} { +} + +RequestInfo::RequestInfo& DynamicMetadataMapWrapper::requestInfo() { return parent_.request_info_; } + +int DynamicMetadataMapIterator::luaPairsIterator(lua_State* state) { + if (current_ == parent_.requestInfo().dynamicMetadata().filter_metadata().end()) { + parent_.iterator_.reset(); + return 0; + } + + lua_pushstring(state, current_->first.c_str()); + Filters::Common::Lua::MetadataMapHelper::createTable(state, current_->second.fields()); + + current_++; + return 2; +} + +int DynamicMetadataMapWrapper::luaGet(lua_State* state) { + const char* filter_name = luaL_checkstring(state, 2); + const auto& metadata = requestInfo().dynamicMetadata().filter_metadata(); + const auto filter_it = metadata.find(filter_name); + if (filter_it == metadata.end()) { + return 0; + } + + Filters::Common::Lua::MetadataMapHelper::createTable(state, filter_it->second.fields()); + return 1; +} + +int DynamicMetadataMapWrapper::luaSet(lua_State* state) { + if (iterator_.get() != nullptr) { + luaL_error(state, "dynamic metadata map cannot be modified while iterating"); + } + + // TODO(dio): Allow to set dynamic metadata using a table. + const char* filter_name = luaL_checkstring(state, 2); + const char* key = luaL_checkstring(state, 3); + const char* value = luaL_checkstring(state, 4); + requestInfo().setDynamicMetadata(filter_name, MessageUtil::keyValueStruct(key, value)); + return 0; +} + +int DynamicMetadataMapWrapper::luaPairs(lua_State* state) { + if (iterator_.get() != nullptr) { + luaL_error(state, "cannot create a second iterator before completing the first"); + } + + iterator_.reset(DynamicMetadataMapIterator::create(state, *this), true); + lua_pushcclosure(state, DynamicMetadataMapIterator::static_luaPairsIterator, 1); + return 1; +} + } // namespace Lua } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/lua/wrappers.h b/source/extensions/filters/http/lua/wrappers.h index 57a5b6ce7cbc1..694170b78c2ae 100644 --- a/source/extensions/filters/http/lua/wrappers.h +++ b/source/extensions/filters/http/lua/wrappers.h @@ -1,6 +1,7 @@ #pragma once #include "envoy/http/header_map.h" +#include "envoy/request_info/request_info.h" #include "extensions/filters/common/lua/lua.h" @@ -98,6 +99,105 @@ class HeaderMapWrapper : public Filters::Common::Lua::BaseLuaObject { +public: + DynamicMetadataMapIterator(DynamicMetadataMapWrapper& parent); + + static ExportedFunctions exportedFunctions() { return {}; } + + DECLARE_LUA_CLOSURE(DynamicMetadataMapIterator, luaPairsIterator); + +private: + DynamicMetadataMapWrapper& parent_; + Protobuf::Map::const_iterator current_; +}; + +/** + * Lua wrapper for a dynamic metadata. + */ +class DynamicMetadataMapWrapper + : public Filters::Common::Lua::BaseLuaObject { +public: + DynamicMetadataMapWrapper(RequestInfoWrapper& parent) : parent_{parent} {} + + static ExportedFunctions exportedFunctions() { + return {{"get", static_luaGet}, {"set", static_luaSet}, {"__pairs", static_luaPairs}}; + } + +private: + /** + * Get a metadata value from the map. + * @param 1 (string): filter name. + * @return value if found or nil. + */ + DECLARE_LUA_FUNCTION(DynamicMetadataMapWrapper, luaGet); + + /** + * Get a metadata value from the map. + * @param 1 (string): filter name. + * @param 2 (string or table): key. + * @param 3 (string or table): value. + * @return nil. + */ + DECLARE_LUA_FUNCTION(DynamicMetadataMapWrapper, luaSet); + + /** + * Implementation of the __pairs metamethod so a dynamic metadata wrapper can be iterated over + * using pairs(). + */ + DECLARE_LUA_FUNCTION(DynamicMetadataMapWrapper, luaPairs); + + // Envoy::Lua::BaseLuaObject + void onMarkDead() override { + // Iterators do not survive yields. + iterator_.reset(); + } + + // To get reference to parent's (RequestInfoWrapper) request info member. + RequestInfo::RequestInfo& requestInfo(); + + RequestInfoWrapper& parent_; + Filters::Common::Lua::LuaDeathRef iterator_; + + friend class DynamicMetadataMapIterator; +}; + +/** + * Lua wrapper for a request info. + */ +class RequestInfoWrapper : public Filters::Common::Lua::BaseLuaObject { +public: + RequestInfoWrapper(RequestInfo::RequestInfo& request_info) : request_info_{request_info} {} + static ExportedFunctions exportedFunctions() { + return {{"protocol", static_luaProtocol}, {"dynamicMetadata", static_luaDynamicMetadata}}; + } + +private: + /** + * Get current protocol being used. + * @return string representation of Http::Protocol. + */ + DECLARE_LUA_FUNCTION(RequestInfoWrapper, luaProtocol); + + /** + * Get reference to request info dynamic metadata object. + * @return DynamicMetadataMapWrapper representation of RequestInfo dynamic metadata. + */ + DECLARE_LUA_FUNCTION(RequestInfoWrapper, luaDynamicMetadata); + + RequestInfo::RequestInfo& request_info_; + Filters::Common::Lua::LuaDeathRef dynamic_metadata_wrapper_; + + friend class DynamicMetadataMapWrapper; +}; + } // namespace Lua } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/ratelimit/BUILD b/source/extensions/filters/http/ratelimit/BUILD index 3f85e70b294c3..f656f229a0d79 100644 --- a/source/extensions/filters/http/ratelimit/BUILD +++ b/source/extensions/filters/http/ratelimit/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Ratelimit L7 HTTP filter # Public docs: docs/root/configuration/http_filters/rate_limit_filter.rst diff --git a/source/extensions/filters/http/ratelimit/config.h b/source/extensions/filters/http/ratelimit/config.h index 91ec643de135d..bf9ebadfe2560 100644 --- a/source/extensions/filters/http/ratelimit/config.h +++ b/source/extensions/filters/http/ratelimit/config.h @@ -16,7 +16,7 @@ namespace RateLimitFilter { class RateLimitFilterConfig : public Common::FactoryBase { public: - RateLimitFilterConfig() : FactoryBase(HttpFilterNames::get().RATE_LIMIT) {} + RateLimitFilterConfig() : FactoryBase(HttpFilterNames::get().RateLimit) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string&, diff --git a/source/extensions/filters/http/rbac/config.h b/source/extensions/filters/http/rbac/config.h index c6002fb284018..f24235eb783fb 100644 --- a/source/extensions/filters/http/rbac/config.h +++ b/source/extensions/filters/http/rbac/config.h @@ -18,7 +18,7 @@ class RoleBasedAccessControlFilterConfigFactory : public Common::FactoryBase { public: - RoleBasedAccessControlFilterConfigFactory() : FactoryBase(HttpFilterNames::get().RBAC) {} + RoleBasedAccessControlFilterConfigFactory() : FactoryBase(HttpFilterNames::get().Rbac) {} private: Http::FilterFactoryCb diff --git a/source/extensions/filters/http/rbac/rbac_filter.cc b/source/extensions/filters/http/rbac/rbac_filter.cc index 9ef629f6e6841..08a347a79aa9b 100644 --- a/source/extensions/filters/http/rbac/rbac_filter.cc +++ b/source/extensions/filters/http/rbac/rbac_filter.cc @@ -36,7 +36,7 @@ RoleBasedAccessControlFilterConfig::engine(const Router::RouteConstSharedPtr rou return engine(mode); } - const std::string& name = HttpFilterNames::get().RBAC; + const std::string& name = HttpFilterNames::get().Rbac; const auto* entry = route->routeEntry(); const auto* route_local = @@ -65,12 +65,27 @@ RoleBasedAccessControlRouteSpecificFilterConfig::RoleBasedAccessControlRouteSpec Http::FilterHeadersStatus RoleBasedAccessControlFilter::decodeHeaders(Http::HeaderMap& headers, bool) { + ENVOY_LOG( + debug, + "checking request: remoteAddress: {}, localAddress: {}, ssl: {}, headers: {}, " + "dynamicMetadata: {}", + callbacks_->connection()->remoteAddress()->asString(), + callbacks_->connection()->localAddress()->asString(), + callbacks_->connection()->ssl() + ? "uriSanPeerCertificate: " + callbacks_->connection()->ssl()->uriSanPeerCertificate() + + ", subjectPeerCertificate: " + + callbacks_->connection()->ssl()->subjectPeerCertificate() + : "none", + headers, callbacks_->requestInfo().dynamicMetadata().DebugString()); const absl::optional& shadow_engine = config_->engine(callbacks_->route(), EnforcementMode::Shadow); if (shadow_engine.has_value()) { - if (shadow_engine->allowed(*callbacks_->connection(), headers)) { + if (shadow_engine->allowed(*callbacks_->connection(), headers, + callbacks_->requestInfo().dynamicMetadata())) { + ENVOY_LOG(debug, "shadow allowed"); config_->stats().shadow_allowed_.inc(); } else { + ENVOY_LOG(debug, "shadow denied"); config_->stats().shadow_denied_.inc(); } } @@ -78,16 +93,20 @@ Http::FilterHeadersStatus RoleBasedAccessControlFilter::decodeHeaders(Http::Head const absl::optional& engine = config_->engine(callbacks_->route(), EnforcementMode::Enforced); if (engine.has_value()) { - if (engine->allowed(*callbacks_->connection(), headers)) { + if (engine->allowed(*callbacks_->connection(), headers, + callbacks_->requestInfo().dynamicMetadata())) { + ENVOY_LOG(debug, "enforced allowed"); config_->stats().allowed_.inc(); return Http::FilterHeadersStatus::Continue; } else { + ENVOY_LOG(debug, "enforced denied"); callbacks_->sendLocalReply(Http::Code::Forbidden, "RBAC: access denied", nullptr); config_->stats().denied_.inc(); return Http::FilterHeadersStatus::StopIteration; } } + ENVOY_LOG(debug, "no engine, allowed by default"); return Http::FilterHeadersStatus::Continue; } diff --git a/source/extensions/filters/http/rbac/rbac_filter.h b/source/extensions/filters/http/rbac/rbac_filter.h index 48a0c9703b5d4..72770303b6e44 100644 --- a/source/extensions/filters/http/rbac/rbac_filter.h +++ b/source/extensions/filters/http/rbac/rbac_filter.h @@ -6,6 +6,8 @@ #include "envoy/http/filter.h" #include "envoy/stats/stats_macros.h" +#include "common/common/logger.h" + #include "extensions/filters/common/rbac/engine_impl.h" namespace Envoy { @@ -80,7 +82,8 @@ typedef std::shared_ptr /** * A filter that provides role-based access control authorization for HTTP requests. */ -class RoleBasedAccessControlFilter : public Http::StreamDecoderFilter { +class RoleBasedAccessControlFilter : public Http::StreamDecoderFilter, + public Logger::Loggable { public: RoleBasedAccessControlFilter(RoleBasedAccessControlFilterConfigSharedPtr config) : config_(config) {} diff --git a/source/extensions/filters/http/router/BUILD b/source/extensions/filters/http/router/BUILD index 74561ea6545c7..ddffe3458ebd9 100644 --- a/source/extensions/filters/http/router/BUILD +++ b/source/extensions/filters/http/router/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # HTTP L7 filter responsible for routing to upstream connection pools # Public docs: docs/root/configuration/http_filters/router_filter.rst diff --git a/source/extensions/filters/http/router/config.h b/source/extensions/filters/http/router/config.h index dcf0957ca4b1a..11c4d74f8a453 100644 --- a/source/extensions/filters/http/router/config.h +++ b/source/extensions/filters/http/router/config.h @@ -18,7 +18,7 @@ namespace RouterFilter { class RouterFilterConfig : public Common::FactoryBase { public: - RouterFilterConfig() : FactoryBase(HttpFilterNames::get().ROUTER) {} + RouterFilterConfig() : FactoryBase(HttpFilterNames::get().Router) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string& stat_prefix, diff --git a/source/extensions/filters/http/squash/BUILD b/source/extensions/filters/http/squash/BUILD index b0cbb13c1083d..ab986ca184438 100644 --- a/source/extensions/filters/http/squash/BUILD +++ b/source/extensions/filters/http/squash/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L7 HTTP filter that implements the Squash microservice debugger # Public docs: docs/root/configuration/http_filters/squash_filter.rst diff --git a/source/extensions/filters/http/squash/config.h b/source/extensions/filters/http/squash/config.h index 9cb83b0096f23..10647ebdb6b05 100644 --- a/source/extensions/filters/http/squash/config.h +++ b/source/extensions/filters/http/squash/config.h @@ -16,7 +16,7 @@ namespace Squash { class SquashFilterConfigFactory : public Common::FactoryBase { public: - SquashFilterConfigFactory() : FactoryBase(HttpFilterNames::get().SQUASH) {} + SquashFilterConfigFactory() : FactoryBase(HttpFilterNames::get().Squash) {} Http::FilterFactoryCb createFilterFactory(const Json::Object& json_config, const std::string&, diff --git a/source/extensions/filters/http/well_known_names.h b/source/extensions/filters/http/well_known_names.h index f71960625e0ef..238671e6da289 100644 --- a/source/extensions/filters/http/well_known_names.h +++ b/source/extensions/filters/http/well_known_names.h @@ -13,50 +13,50 @@ namespace HttpFilters { class HttpFilterNameValues { public: // Buffer filter - const std::string BUFFER = "envoy.buffer"; + const std::string Buffer = "envoy.buffer"; // CORS filter - const std::string CORS = "envoy.cors"; + const std::string Cors = "envoy.cors"; // Dynamo filter - const std::string DYNAMO = "envoy.http_dynamo_filter"; + const std::string Dynamo = "envoy.http_dynamo_filter"; // Fault filter - const std::string FAULT = "envoy.fault"; + const std::string Fault = "envoy.fault"; // GRPC http1 bridge filter - const std::string GRPC_HTTP1_BRIDGE = "envoy.grpc_http1_bridge"; + const std::string GrpcHttp1Bridge = "envoy.grpc_http1_bridge"; // GRPC json transcoder filter - const std::string GRPC_JSON_TRANSCODER = "envoy.grpc_json_transcoder"; + const std::string GrpcJsonTranscoder = "envoy.grpc_json_transcoder"; // GRPC web filter - const std::string GRPC_WEB = "envoy.grpc_web"; + const std::string GrpcWeb = "envoy.grpc_web"; // Gzip filter - const std::string ENVOY_GZIP = "envoy.gzip"; + const std::string EnvoyGzip = "envoy.gzip"; // IP tagging filter - const std::string IP_TAGGING = "envoy.ip_tagging"; + const std::string IpTagging = "envoy.ip_tagging"; // Rate limit filter - const std::string RATE_LIMIT = "envoy.rate_limit"; + const std::string RateLimit = "envoy.rate_limit"; // Router filter - const std::string ROUTER = "envoy.router"; + const std::string Router = "envoy.router"; // Health checking filter - const std::string HEALTH_CHECK = "envoy.health_check"; + const std::string HealthCheck = "envoy.health_check"; // Lua filter - const std::string LUA = "envoy.lua"; + const std::string Lua = "envoy.lua"; // Squash filter - const std::string SQUASH = "envoy.squash"; + const std::string Squash = "envoy.squash"; // External Authorization filter - const std::string EXT_AUTHORIZATION = "envoy.ext_authz"; + const std::string ExtAuthorization = "envoy.ext_authz"; // RBAC HTTP Authorization filter - const std::string RBAC = "envoy.filters.http.rbac"; + const std::string Rbac = "envoy.filters.http.rbac"; // JWT authentication filter - const std::string JWT_AUTHN = "envoy.filters.http.jwt_authn"; + const std::string JwtAuthn = "envoy.filters.http.jwt_authn"; // Header to metadata filter - const std::string HEADER_TO_METADATA = "envoy.filters.http.header_to_metadata"; + const std::string HeaderToMetadata = "envoy.filters.http.header_to_metadata"; // Converts names from v1 to v2 const Config::V1Converter v1_converter_; // NOTE: Do not add any new filters to this list. All future filters are v2 only. HttpFilterNameValues() - : v1_converter_({BUFFER, CORS, DYNAMO, FAULT, GRPC_HTTP1_BRIDGE, GRPC_JSON_TRANSCODER, - GRPC_WEB, HEADER_TO_METADATA, HEALTH_CHECK, IP_TAGGING, RATE_LIMIT, ROUTER, - LUA, EXT_AUTHORIZATION}) {} + : v1_converter_({Buffer, Cors, Dynamo, Fault, GrpcHttp1Bridge, GrpcJsonTranscoder, GrpcWeb, + HeaderToMetadata, HealthCheck, IpTagging, RateLimit, Router, Lua, + ExtAuthorization}) {} }; typedef ConstSingleton HttpFilterNames; diff --git a/source/extensions/filters/listener/original_dst/BUILD b/source/extensions/filters/listener/original_dst/BUILD index 969370454e75a..b3843b7744dd9 100644 --- a/source/extensions/filters/listener/original_dst/BUILD +++ b/source/extensions/filters/listener/original_dst/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # ORIGINAL_DST iptables redirection listener filter # Public docs: docs/root/configuration/listener_filters/original_dst_filter.rst diff --git a/source/extensions/filters/listener/original_dst/config.cc b/source/extensions/filters/listener/original_dst/config.cc index 1176508cb1033..63e858ee463fe 100644 --- a/source/extensions/filters/listener/original_dst/config.cc +++ b/source/extensions/filters/listener/original_dst/config.cc @@ -29,7 +29,7 @@ class OriginalDstConfigFactory : public Server::Configuration::NamedListenerFilt return std::make_unique(); } - std::string name() override { return ListenerFilterNames::get().ORIGINAL_DST; } + std::string name() override { return ListenerFilterNames::get().OriginalDst; } }; /** diff --git a/source/extensions/filters/listener/proxy_protocol/BUILD b/source/extensions/filters/listener/proxy_protocol/BUILD index 14f8b475ee34c..5bbefff20f0d8 100644 --- a/source/extensions/filters/listener/proxy_protocol/BUILD +++ b/source/extensions/filters/listener/proxy_protocol/BUILD @@ -1,8 +1,6 @@ licenses(["notice"]) # Apache 2 -# Proxy protocol V1 listener filter: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt -# Note: Currently there are no public docs for this filter as it is implicitly loaded by -# configuration options. In the future it will likely be configurable on a per filter chain -# basis and will need public docs. + +# Proxy protocol listener filter: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt load( "//bazel:envoy_build_system.bzl", @@ -15,11 +13,15 @@ envoy_package() envoy_cc_library( name = "proxy_protocol_lib", srcs = ["proxy_protocol.cc"], - hdrs = ["proxy_protocol.h"], + hdrs = [ + "proxy_protocol.h", + "proxy_protocol_header.h", + ], deps = [ "//include/envoy/event:dispatcher_interface", "//include/envoy/network:filter_interface", "//include/envoy/network:listen_socket_interface", + "//source/common/api:os_sys_calls_lib", "//source/common/common:assert_lib", "//source/common/common:empty_string", "//source/common/common:minimal_logger_lib", diff --git a/source/extensions/filters/listener/proxy_protocol/config.cc b/source/extensions/filters/listener/proxy_protocol/config.cc index 3ce2d851f8087..6cbe967666e02 100644 --- a/source/extensions/filters/listener/proxy_protocol/config.cc +++ b/source/extensions/filters/listener/proxy_protocol/config.cc @@ -28,7 +28,7 @@ class ProxyProtocolConfigFactory : public Server::Configuration::NamedListenerFi return std::make_unique(); } - std::string name() override { return ListenerFilterNames::get().PROXY_PROTOCOL; } + std::string name() override { return ListenerFilterNames::get().ProxyProtocol; } }; /** diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc index b303a966ab1ed..a0ffa4128f3a0 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.cc @@ -1,7 +1,10 @@ #include "extensions/filters/listener/proxy_protocol/proxy_protocol.h" +#include +#include #include +#include #include #include #include @@ -11,6 +14,7 @@ #include "envoy/network/listen_socket.h" #include "envoy/stats/stats.h" +#include "common/api/os_sys_calls_impl.h" #include "common/common/assert.h" #include "common/common/empty_string.h" #include "common/common/utility.h" @@ -50,11 +54,145 @@ void Filter::onRead() { void Filter::onReadWorker() { Network::ConnectionSocket& socket = cb_->socket(); - std::string proxy_line; - if (!readLine(socket.fd(), proxy_line)) { + + if ((!proxy_protocol_header_.has_value() && !readProxyHeader(socket.fd())) || + (proxy_protocol_header_.has_value() && !parseExtensions(socket.fd()))) { + // We return if a) we do not yet have the header, or b) we have the header but not yet all + // the extension data. In both cases we'll be called again when the socket is ready to read + // and pick up where we left off. return; } + if (proxy_protocol_header_.has_value() && !proxy_protocol_header_.value().local_command_) { + // If this is a local_command, we are not to override address + // Error check the source and destination fields. Most errors are caught by the address + // parsing above, but a malformed IPv6 address may combine with a malformed port and parse as + // an IPv6 address when parsing for an IPv4 address(for v1 mode). Remote address refers to the + // source address. + const auto remote_version = proxy_protocol_header_.value().remote_address_->ip()->version(); + const auto local_version = proxy_protocol_header_.value().local_address_->ip()->version(); + if (remote_version != proxy_protocol_header_.value().protocol_version_ || + local_version != proxy_protocol_header_.value().protocol_version_) { + throw EnvoyException("failed to read proxy protocol"); + } + // Check that both addresses are valid unicast addresses, as required for TCP + if (!proxy_protocol_header_.value().remote_address_->ip()->isUnicastAddress() || + !proxy_protocol_header_.value().local_address_->ip()->isUnicastAddress()) { + throw EnvoyException("failed to read proxy protocol"); + } + + // Only set the local address if it really changed, and mark it as address being restored. + if (*proxy_protocol_header_.value().local_address_ != *socket.localAddress()) { + socket.setLocalAddress(proxy_protocol_header_.value().local_address_, true); + } + socket.setRemoteAddress(proxy_protocol_header_.value().remote_address_); + } + + // Release the file event so that we do not interfere with the connection read events. + file_event_.reset(); + cb_->continueFilterChain(true); +} + +size_t Filter::lenV2Address(char* buf) { + const uint8_t proto_family = buf[PROXY_PROTO_V2_SIGNATURE_LEN + 1]; + const int ver_cmd = buf[PROXY_PROTO_V2_SIGNATURE_LEN]; + size_t len; + + if ((ver_cmd & 0xf) == PROXY_PROTO_V2_LOCAL) { + // According to the spec there is no address encoded, len=0, and we must ignore + return 0; + } + + switch ((proto_family & 0xf0) >> 4) { + case PROXY_PROTO_V2_AF_INET: + len = PROXY_PROTO_V2_ADDR_LEN_INET; + break; + case PROXY_PROTO_V2_AF_INET6: + len = PROXY_PROTO_V2_ADDR_LEN_INET6; + break; + default: + throw EnvoyException("Unsupported V2 proxy protocol address family"); + } + return len; +} + +void Filter::parseV2Header(char* buf) { + const int ver_cmd = buf[PROXY_PROTO_V2_SIGNATURE_LEN]; + uint8_t upper_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 2]; + uint8_t lower_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 1]; + size_t hdr_addr_len = (upper_byte << 8) + lower_byte; + + if ((ver_cmd & 0xf) == PROXY_PROTO_V2_LOCAL) { + // This is locally-initiated, e.g. health-check, and should not override remote address + proxy_protocol_header_.emplace(WireHeader{hdr_addr_len}); + return; + } + + // Only do connections on behalf of another user, not internally-driven health-checks. If + // its not on behalf of someone, or its not AF_INET{6} / STREAM/DGRAM, ignore and + /// use the real-remote info + if ((ver_cmd & 0xf) == PROXY_PROTO_V2_ONBEHALF_OF) { + uint8_t proto_family = buf[PROXY_PROTO_V2_SIGNATURE_LEN + 1]; + if (((proto_family & 0x0f) == PROXY_PROTO_V2_TRANSPORT_STREAM) || + ((proto_family & 0x0f) == PROXY_PROTO_V2_TRANSPORT_DGRAM)) { + if (((proto_family & 0xf0) >> 4) == PROXY_PROTO_V2_AF_INET) { + typedef struct { + uint32_t src_addr; + uint32_t dst_addr; + uint16_t src_port; + uint16_t dst_port; + } __attribute__((packed)) pp_ipv4_addr; + pp_ipv4_addr* v4; + v4 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); + sockaddr_in ra4, la4; + memset(&ra4, 0, sizeof(ra4)); + memset(&la4, 0, sizeof(la4)); + ra4.sin_family = AF_INET; + ra4.sin_port = v4->src_port; + ra4.sin_addr.s_addr = v4->src_addr; + + la4.sin_family = AF_INET; + la4.sin_port = v4->dst_port; + la4.sin_addr.s_addr = v4->dst_addr; + proxy_protocol_header_.emplace( + WireHeader{hdr_addr_len - PROXY_PROTO_V2_ADDR_LEN_INET, Network::Address::IpVersion::v4, + std::make_shared(&ra4), + std::make_shared(&la4)}); + return; + } else if (((proto_family & 0xf0) >> 4) == PROXY_PROTO_V2_AF_INET6) { + typedef struct { + uint8_t src_addr[16]; + uint8_t dst_addr[16]; + uint16_t src_port; + uint16_t dst_port; + } __attribute__((packed)) pp_ipv6_addr; + pp_ipv6_addr* v6; + v6 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); + sockaddr_in6 ra6, la6; + memset(&ra6, 0, sizeof(ra6)); + memset(&la6, 0, sizeof(la6)); + ra6.sin6_family = AF_INET6; + ra6.sin6_port = v6->src_port; + memcpy(ra6.sin6_addr.s6_addr, v6->src_addr, sizeof(ra6.sin6_addr.s6_addr)); + + la6.sin6_family = AF_INET6; + la6.sin6_port = v6->dst_port; + memcpy(la6.sin6_addr.s6_addr, v6->dst_addr, sizeof(la6.sin6_addr.s6_addr)); + + proxy_protocol_header_.emplace(WireHeader{ + hdr_addr_len - PROXY_PROTO_V2_ADDR_LEN_INET6, Network::Address::IpVersion::v6, + std::make_shared(ra6), + std::make_shared(la6)}); + return; + } + } + } + throw EnvoyException("Unsupported command or address family or transport"); +} + +void Filter::parseV1Header(char* buf, size_t len) { + std::string proxy_line; + proxy_line.assign(buf, len); const auto trimmed_proxy_line = StringUtil::rtrim(proxy_line); // Parse proxy protocol line with format: PROXY TCP4/TCP6/UNKNOWN SOURCE_ADDRESS @@ -73,89 +211,163 @@ void Filter::onReadWorker() { throw EnvoyException("failed to read proxy protocol"); } - Network::Address::IpVersion protocol_version; - Network::Address::InstanceConstSharedPtr remote_address; - Network::Address::InstanceConstSharedPtr local_address; - // TODO(gsagula): parseInternetAddressAndPort() could be modified to take two string_view // arguments, so we can eliminate allocation here. if (line_parts[1] == "TCP4") { - protocol_version = Network::Address::IpVersion::v4; - remote_address = Network::Utility::parseInternetAddressAndPort( - std::string{line_parts[2]} + ":" + std::string{line_parts[4]}); - local_address = Network::Utility::parseInternetAddressAndPort( - std::string{line_parts[3]} + ":" + std::string{line_parts[5]}); + proxy_protocol_header_.emplace( + WireHeader{0, Network::Address::IpVersion::v4, + Network::Utility::parseInternetAddressAndPort( + std::string{line_parts[2]} + ":" + std::string{line_parts[4]}), + Network::Utility::parseInternetAddressAndPort( + std::string{line_parts[3]} + ":" + std::string{line_parts[5]})}); } else if (line_parts[1] == "TCP6") { - protocol_version = Network::Address::IpVersion::v6; - remote_address = Network::Utility::parseInternetAddressAndPort( - "[" + std::string{line_parts[2]} + "]:" + std::string{line_parts[4]}); - local_address = Network::Utility::parseInternetAddressAndPort( - "[" + std::string{line_parts[3]} + "]:" + std::string{line_parts[5]}); + proxy_protocol_header_.emplace( + WireHeader{0, Network::Address::IpVersion::v6, + Network::Utility::parseInternetAddressAndPort( + "[" + std::string{line_parts[2]} + "]:" + std::string{line_parts[4]}), + Network::Utility::parseInternetAddressAndPort( + "[" + std::string{line_parts[3]} + "]:" + std::string{line_parts[5]})}); } else { throw EnvoyException("failed to read proxy protocol"); } + } +} - // Error check the source and destination fields. Most errors are caught by the address - // parsing above, but a malformed IPv6 address may combine with a malformed port and parse as - // an IPv6 address when parsing for an IPv4 address. Remote address refers to the source - // address. - const auto remote_version = remote_address->ip()->version(); - const auto local_version = local_address->ip()->version(); - if (remote_version != protocol_version || local_version != protocol_version) { - throw EnvoyException("failed to read proxy protocol"); +bool Filter::parseExtensions(int fd) { + // If we ever implement extensions elsewhere, be sure to + // continue to skip and ignore those for LOCAL. + while (proxy_protocol_header_.value().extensions_length_) { + // buf_ is no longer in use so we re-use it to read/discard + int bytes_avail; + auto& os_syscalls = Api::OsSysCallsSingleton::get(); + if (os_syscalls.ioctl(fd, FIONREAD, &bytes_avail) < 0) { + throw EnvoyException("failed to read proxy protocol (no bytes avail)"); } - // Check that both addresses are valid unicast addresses, as required for TCP - if (!remote_address->ip()->isUnicastAddress() || !local_address->ip()->isUnicastAddress()) { - throw EnvoyException("failed to read proxy protocol"); + if (bytes_avail == 0) { + return false; } - - // Only set the local address if it really changed, and mark it as address being restored. - if (*local_address != *socket.localAddress()) { - socket.setLocalAddress(local_address, true); + bytes_avail = std::min(size_t(bytes_avail), sizeof(buf_)); + bytes_avail = std::min(size_t(bytes_avail), proxy_protocol_header_.value().extensions_length_); + ssize_t nread = os_syscalls.recv(fd, buf_, bytes_avail, 0); + if (nread != bytes_avail) { + throw EnvoyException("failed to read proxy protocol extension"); } - socket.setRemoteAddress(remote_address); + proxy_protocol_header_.value().extensions_length_ -= nread; } - - // Release the file event so that we do not interfere with the connection read events. - file_event_.reset(); - cb_->continueFilterChain(true); + return true; } -bool Filter::readLine(int fd, std::string& s) { - while (buf_off_ < MAX_PROXY_PROTO_LEN) { - ssize_t nread = recv(fd, buf_ + buf_off_, MAX_PROXY_PROTO_LEN - buf_off_, MSG_PEEK); +bool Filter::readProxyHeader(int fd) { + while (buf_off_ < MAX_PROXY_PROTO_LEN_V2) { + int bytes_avail; + auto& os_syscalls = Api::OsSysCallsSingleton::get(); - if (nread == -1 && errno == EAGAIN) { + if (os_syscalls.ioctl(fd, FIONREAD, &bytes_avail) < 0) { + throw EnvoyException("failed to read proxy protocol (no bytes avail)"); + } + + if (bytes_avail == 0) { return false; - } else if (nread < 1) { - throw EnvoyException("failed to read proxy protocol"); } - bool found = false; - // continue searching buf_ from where we left off - for (; search_index_ < buf_off_ + nread; search_index_++) { - if (buf_[search_index_] == '\n' && buf_[search_index_ - 1] == '\r') { - search_index_++; - found = true; - break; + bytes_avail = std::min(size_t(bytes_avail), MAX_PROXY_PROTO_LEN_V2 - buf_off_); + + ssize_t nread = os_syscalls.recv(fd, buf_ + buf_off_, bytes_avail, MSG_PEEK); + + if (nread < 1) { + throw EnvoyException("failed to read proxy protocol (no bytes read)"); + } + + if (buf_off_ + nread >= PROXY_PROTO_V2_HEADER_LEN) { + const char* sig = PROXY_PROTO_V2_SIGNATURE; + if (!memcmp(buf_, sig, PROXY_PROTO_V2_SIGNATURE_LEN)) { + header_version_ = V2; + } else if (memcmp(buf_, PROXY_PROTO_V1_SIGNATURE, PROXY_PROTO_V1_SIGNATURE_LEN)) { + // It is not v2, and can't be v1, so no sense hanging around: it is invalid + throw EnvoyException("failed to read proxy protocol (exceed max v1 header len)"); } } - // Read the data upto and including the line feed, if available, but not past it. - // This should never fail, as search_index_ - buf_off_ <= nread, so we're asking - // only for bytes we have already seen. - nread = recv(fd, buf_ + buf_off_, search_index_ - buf_off_, 0); - ASSERT(size_t(nread) == search_index_ - buf_off_); + if (header_version_ == V2) { + const int ver_cmd = buf_[PROXY_PROTO_V2_SIGNATURE_LEN]; + if (((ver_cmd & 0xf0) >> 4) != PROXY_PROTO_V2_VERSION) { + throw EnvoyException("Unsupported V2 proxy protocol version"); + } + if (buf_off_ < PROXY_PROTO_V2_HEADER_LEN) { + ssize_t lread; + ssize_t exp = PROXY_PROTO_V2_HEADER_LEN - buf_off_; + lread = os_syscalls.recv(fd, buf_ + buf_off_, exp, 0); + if (lread != exp) { + throw EnvoyException("failed to read proxy protocol (remote closed)"); + } + buf_off_ += lread; + nread -= lread; + } + ssize_t addr_len = lenV2Address(buf_); + uint8_t upper_byte = buf_[PROXY_PROTO_V2_HEADER_LEN - 2]; + uint8_t lower_byte = buf_[PROXY_PROTO_V2_HEADER_LEN - 1]; + ssize_t hdr_addr_len = (upper_byte << 8) + lower_byte; + if (hdr_addr_len < addr_len) { + throw EnvoyException("failed to read proxy protocol (insufficient data)"); + } + if (ssize_t(buf_off_) + nread >= PROXY_PROTO_V2_HEADER_LEN + addr_len) { + ssize_t lread; + ssize_t missing = (PROXY_PROTO_V2_HEADER_LEN + addr_len) - buf_off_; + lread = os_syscalls.recv(fd, buf_ + buf_off_, missing, 0); + if (lread != missing) { + throw EnvoyException("failed to read proxy protocol (remote closed)"); + } + buf_off_ += lread; + parseV2Header(buf_); + // The TLV remain, they are read/discard in parseExtensions() which is called from the + // parent (if needed). + return true; + } else { + nread = os_syscalls.recv(fd, buf_ + buf_off_, nread, 0); + if (nread < 0) { + throw EnvoyException("failed to read proxy protocol (remote closed)"); + } + buf_off_ += nread; + } + } else { + // continue searching buf_ from where we left off + for (; search_index_ < buf_off_ + nread; search_index_++) { + if (buf_[search_index_] == '\n' && buf_[search_index_ - 1] == '\r') { + if (search_index_ == 1) { + // This could be the binary protocol. It cannot be the ascii protocol + header_version_ = InProgress; + } else { + header_version_ = V1; + search_index_++; + } + break; + } + } - buf_off_ += nread; + // If we bailed on the first char, we might be v2, but are for sure not v1. Thus we + // can read up to min(PROXY_PROTO_V2_HEADER_LEN, bytes_avail). If we bailed after first + // char, but before we hit \r\n, read up to search_index_. We're asking only for + // bytes we've already seen so there should be no block or fail + size_t ntoread; + if (header_version_ == InProgress) { + ntoread = bytes_avail; + } else { + ntoread = search_index_ - buf_off_; + } + + nread = os_syscalls.recv(fd, buf_ + buf_off_, ntoread, 0); + ASSERT(size_t(nread) == ntoread); - if (found) { - s.assign(buf_, buf_off_); - return true; + buf_off_ += nread; + + if (header_version_ == V1) { + parseV1Header(buf_, buf_off_); + return true; + } } } - throw EnvoyException("failed to read proxy protocol"); + throw EnvoyException("failed to read proxy protocol (exceed max v2 header len)"); } } // namespace ProxyProtocol diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h index 3815dadab2880..08ff0054420fb 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h @@ -6,6 +6,8 @@ #include "common/common/logger.h" +#include "proxy_protocol_header.h" + namespace Envoy { namespace Extensions { namespace ListenerFilters { @@ -38,9 +40,17 @@ class Config { typedef std::shared_ptr ConfigSharedPtr; +enum ProxyProtocolVersion { Unknown = -1, InProgress = -2, V1 = 1, V2 = 2 }; + /** - * Implementation the PROXY Protocol V1 listener filter - * (http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt) + * Implementation the PROXY Protocol listener filter + * (https://github.com/haproxy/haproxy/blob/master/doc/proxy-protocol.txt) + * + * This implementation supports Proxy Protocol v1 (TCP/UDP, v4/v6), + * and Proxy Protocol v2 (TCP/UDP, v4/v6). + * + * Non INET (AF_UNIX) address family in v2 is not supported, will throw an error. + * Extensions (TLV) in v2 are skipped over. */ class Filter : public Network::ListenerFilter, Logger::Loggable { public: @@ -50,17 +60,32 @@ class Filter : public Network::ListenerFilter, Logger::Loggable proxy_protocol_header_; }; } // namespace ProxyProtocol diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h b/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h new file mode 100644 index 0000000000000..97497e9067619 --- /dev/null +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h @@ -0,0 +1,54 @@ +#pragma once +#include "common/common/assert.h" + +namespace Envoy { +namespace Extensions { +namespace ListenerFilters { +namespace ProxyProtocol { + +// See https://github.com/haproxy/haproxy/blob/master/doc/proxy-protocol.txt for definitions + +constexpr char PROXY_PROTO_V1_SIGNATURE[] = "PROXY "; +constexpr uint32_t PROXY_PROTO_V1_SIGNATURE_LEN = 6; +constexpr char PROXY_PROTO_V2_SIGNATURE[] = "\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a"; +constexpr uint32_t PROXY_PROTO_V2_SIGNATURE_LEN = 12; +constexpr uint32_t PROXY_PROTO_V2_HEADER_LEN = 16; +constexpr uint32_t PROXY_PROTO_V2_VERSION = 0x2; +constexpr uint32_t PROXY_PROTO_V2_ONBEHALF_OF = 0x1; +constexpr uint32_t PROXY_PROTO_V2_LOCAL = 0x0; + +constexpr uint32_t PROXY_PROTO_V2_AF_INET = 0x1; +constexpr uint32_t PROXY_PROTO_V2_AF_INET6 = 0x2; +constexpr uint32_t PROXY_PROTO_V2_AF_UNIX = 0x3; + +struct WireHeader { + WireHeader(size_t extensions_length) + : extensions_length_(extensions_length), protocol_version_(Network::Address::IpVersion::v4), + remote_address_(0), local_address_(0), local_command_(true) {} + WireHeader(size_t extensions_length, Network::Address::IpVersion protocol_version, + Network::Address::InstanceConstSharedPtr remote_address, + Network::Address::InstanceConstSharedPtr local_address) + : extensions_length_(extensions_length), protocol_version_(protocol_version), + remote_address_(remote_address), local_address_(local_address), local_command_(false) { + + ASSERT(extensions_length_ <= 65535); + } + size_t extensions_length_; + const Network::Address::IpVersion protocol_version_; + const Network::Address::InstanceConstSharedPtr remote_address_; + const Network::Address::InstanceConstSharedPtr local_address_; + const bool local_command_; +}; + +constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_UNSPEC = 0; +constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_INET = 12; +constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_INET6 = 36; +constexpr uint32_t PROXY_PROTO_V2_ADDR_LEN_UNIX = 216; + +constexpr uint8_t PROXY_PROTO_V2_TRANSPORT_STREAM = 0x1; +constexpr uint8_t PROXY_PROTO_V2_TRANSPORT_DGRAM = 0x2; + +} // namespace ProxyProtocol +} // namespace ListenerFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/listener/tls_inspector/BUILD b/source/extensions/filters/listener/tls_inspector/BUILD index d31d1c7149706..af90ed9fcd4af 100644 --- a/source/extensions/filters/listener/tls_inspector/BUILD +++ b/source/extensions/filters/listener/tls_inspector/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # TLS inspector filter for examining various TLS parameters before routing to a FilterChain. # Public docs: docs/root/configuration/listener_filters/tls_inspector.rst diff --git a/source/extensions/filters/listener/tls_inspector/config.cc b/source/extensions/filters/listener/tls_inspector/config.cc index a75bb315da484..af8f8170683c6 100644 --- a/source/extensions/filters/listener/tls_inspector/config.cc +++ b/source/extensions/filters/listener/tls_inspector/config.cc @@ -30,7 +30,7 @@ class TlsInspectorConfigFactory : public Server::Configuration::NamedListenerFil return std::make_unique(); } - std::string name() override { return ListenerFilterNames::get().TLS_INSPECTOR; } + std::string name() override { return ListenerFilterNames::get().TlsInspector; } }; /** diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc index 14cb468541585..a462a812c8700 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc @@ -63,7 +63,7 @@ bssl::UniquePtr Config::newSsl() { return bssl::UniquePtr{SSL_new(ssl_ thread_local uint8_t Filter::buf_[Config::TLS_MAX_CLIENT_HELLO]; Filter::Filter(const ConfigSharedPtr config) : config_(config), ssl_(config_->newSsl()) { - RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize()); + RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize(), ""); SSL_set_app_data(ssl_.get(), this); SSL_set_accept_state(ssl_.get()); @@ -144,9 +144,10 @@ void Filter::onRead() { // platforms. auto& os_syscalls = Api::OsSysCallsSingleton::get(); ssize_t n = os_syscalls.recv(cb_->socket().fd(), buf_, config_->maxClientHelloSize(), MSG_PEEK); + const int error = errno; // Latch errno right after the recv call. ENVOY_LOG(trace, "tls inspector: recv: {}", n); - if (n == -1 && errno == EAGAIN) { + if (n == -1 && error == EAGAIN) { return; } else if (n < 0) { config_->stats().read_error_.inc(); @@ -209,7 +210,7 @@ void Filter::parseClientHello(const void* data, size_t len) { } else { config_->stats().alpn_not_found_.inc(); } - cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().TLS); + cb_->socket().setDetectedTransportProtocol(TransportSockets::TransportSocketNames::get().Tls); } else { config_->stats().tls_not_found_.inc(); } diff --git a/source/extensions/filters/listener/well_known_names.h b/source/extensions/filters/listener/well_known_names.h index f9a4fbce55180..ae66726a10a22 100644 --- a/source/extensions/filters/listener/well_known_names.h +++ b/source/extensions/filters/listener/well_known_names.h @@ -13,11 +13,11 @@ namespace ListenerFilters { class ListenerFilterNameValues { public: // Original destination listener filter - const std::string ORIGINAL_DST = "envoy.listener.original_dst"; + const std::string OriginalDst = "envoy.listener.original_dst"; // Proxy Protocol listener filter - const std::string PROXY_PROTOCOL = "envoy.listener.proxy_protocol"; + const std::string ProxyProtocol = "envoy.listener.proxy_protocol"; // TLS Inspector listener filter - const std::string TLS_INSPECTOR = "envoy.listener.tls_inspector"; + const std::string TlsInspector = "envoy.listener.tls_inspector"; }; typedef ConstSingleton ListenerFilterNames; diff --git a/source/extensions/filters/network/client_ssl_auth/BUILD b/source/extensions/filters/network/client_ssl_auth/BUILD index 23c515dace98a..3a99d24182314 100644 --- a/source/extensions/filters/network/client_ssl_auth/BUILD +++ b/source/extensions/filters/network/client_ssl_auth/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Client SSL authorization L4 network filter # Public docs: docs/root/configuration/network_filters/client_ssl_auth_filter.rst diff --git a/source/extensions/filters/network/client_ssl_auth/config.h b/source/extensions/filters/network/client_ssl_auth/config.h index 896e29923ba37..8953d3be6b2c2 100644 --- a/source/extensions/filters/network/client_ssl_auth/config.h +++ b/source/extensions/filters/network/client_ssl_auth/config.h @@ -17,7 +17,7 @@ class ClientSslAuthConfigFactory : public Common::FactoryBase< envoy::config::filter::network::client_ssl_auth::v2::ClientSSLAuth> { public: - ClientSslAuthConfigFactory() : FactoryBase(NetworkFilterNames::get().CLIENT_SSL_AUTH) {} + ClientSslAuthConfigFactory() : FactoryBase(NetworkFilterNames::get().ClientSslAuth) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb diff --git a/source/extensions/filters/network/common/factory_base.h b/source/extensions/filters/network/common/factory_base.h index 0d1a29920b5b5..421ea72dcbab1 100644 --- a/source/extensions/filters/network/common/factory_base.h +++ b/source/extensions/filters/network/common/factory_base.h @@ -18,7 +18,7 @@ class FactoryBase : public Server::Configuration::NamedNetworkFilterConfigFactor Network::FilterFactoryCb createFilterFactory(const Json::Object&, Server::Configuration::FactoryContext&) override { // Only used in v1 filters. - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Network::FilterFactoryCb diff --git a/source/extensions/filters/network/echo/BUILD b/source/extensions/filters/network/echo/BUILD index 286f371d107bd..253cfb55935b9 100644 --- a/source/extensions/filters/network/echo/BUILD +++ b/source/extensions/filters/network/echo/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Echo L4 network filter. This is primarily a simplistic example. # Public docs: docs/root/configuration/network_filters/echo_filter.rst diff --git a/source/extensions/filters/network/echo/config.cc b/source/extensions/filters/network/echo/config.cc index d092b074ae2d1..990a3bf5ac01b 100644 --- a/source/extensions/filters/network/echo/config.cc +++ b/source/extensions/filters/network/echo/config.cc @@ -34,7 +34,7 @@ class EchoConfigFactory : public Server::Configuration::NamedNetworkFilterConfig return ProtobufTypes::MessagePtr{new Envoy::ProtobufWkt::Empty()}; } - std::string name() override { return NetworkFilterNames::get().ECHO; } + std::string name() override { return NetworkFilterNames::get().Echo; } }; /** diff --git a/source/extensions/filters/network/ext_authz/BUILD b/source/extensions/filters/network/ext_authz/BUILD index ef0737886dc2d..f0e866d13307e 100644 --- a/source/extensions/filters/network/ext_authz/BUILD +++ b/source/extensions/filters/network/ext_authz/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # External authorization L4 network filter # Public docs: TODO(saumoh): Docs needed in docs/root/configuration/network_filters @@ -22,8 +23,8 @@ envoy_cc_library( "//include/envoy/upstream:cluster_manager_interface", "//source/common/common:assert_lib", "//source/common/tracing:http_tracer_lib", + "//source/extensions/filters/common/ext_authz:ext_authz_grpc_lib", "//source/extensions/filters/common/ext_authz:ext_authz_interface", - "//source/extensions/filters/common/ext_authz:ext_authz_lib", "@envoy_api//envoy/config/filter/network/ext_authz/v2:ext_authz_cc", ], ) diff --git a/source/extensions/filters/network/ext_authz/config.cc b/source/extensions/filters/network/ext_authz/config.cc index cffece5bb2b4b..54e5de374a3d8 100644 --- a/source/extensions/filters/network/ext_authz/config.cc +++ b/source/extensions/filters/network/ext_authz/config.cc @@ -10,7 +10,7 @@ #include "common/protobuf/utility.h" #include "extensions/filters/common/ext_authz/ext_authz.h" -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" #include "extensions/filters/network/ext_authz/ext_authz.h" namespace Envoy { diff --git a/source/extensions/filters/network/ext_authz/config.h b/source/extensions/filters/network/ext_authz/config.h index 6c801436aa889..05fccafbb1f56 100644 --- a/source/extensions/filters/network/ext_authz/config.h +++ b/source/extensions/filters/network/ext_authz/config.h @@ -17,7 +17,7 @@ namespace ExtAuthz { class ExtAuthzConfigFactory : public Common::FactoryBase { public: - ExtAuthzConfigFactory() : FactoryBase(NetworkFilterNames::get().EXT_AUTHORIZATION) {} + ExtAuthzConfigFactory() : FactoryBase(NetworkFilterNames::get().ExtAuthorization) {} private: Network::FilterFactoryCb createFilterFactoryFromProtoTyped( diff --git a/source/extensions/filters/network/ext_authz/ext_authz.cc b/source/extensions/filters/network/ext_authz/ext_authz.cc index d0d2ca335bfe0..73412ee16eb42 100644 --- a/source/extensions/filters/network/ext_authz/ext_authz.cc +++ b/source/extensions/filters/network/ext_authz/ext_authz.cc @@ -56,11 +56,11 @@ void Filter::onEvent(Network::ConnectionEvent event) { } } -void Filter::onComplete(Filters::Common::ExtAuthz::CheckStatus status) { +void Filter::onComplete(Filters::Common::ExtAuthz::ResponsePtr&& response) { status_ = Status::Complete; config_->stats().active_.dec(); - switch (status) { + switch (response->status) { case Filters::Common::ExtAuthz::CheckStatus::OK: config_->stats().ok_.inc(); break; @@ -73,14 +73,16 @@ void Filter::onComplete(Filters::Common::ExtAuthz::CheckStatus status) { } // Fail open only if configured to do so and if the check status was a error. - if (status == Filters::Common::ExtAuthz::CheckStatus::Denied || - (status == Filters::Common::ExtAuthz::CheckStatus::Error && !config_->failureModeAllow())) { + if (response->status == Filters::Common::ExtAuthz::CheckStatus::Denied || + (response->status == Filters::Common::ExtAuthz::CheckStatus::Error && + !config_->failureModeAllow())) { config_->stats().cx_closed_.inc(); filter_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); } else { // Let the filter chain continue. filter_return_ = FilterReturn::Continue; - if (config_->failureModeAllow() && status == Filters::Common::ExtAuthz::CheckStatus::Error) { + if (config_->failureModeAllow() && + response->status == Filters::Common::ExtAuthz::CheckStatus::Error) { // Status is Error and yet we are configured to allow traffic. Click a counter. config_->stats().failure_mode_allowed_.inc(); } diff --git a/source/extensions/filters/network/ext_authz/ext_authz.h b/source/extensions/filters/network/ext_authz/ext_authz.h index ef97b822c2b36..4fecd72b052e0 100644 --- a/source/extensions/filters/network/ext_authz/ext_authz.h +++ b/source/extensions/filters/network/ext_authz/ext_authz.h @@ -13,7 +13,7 @@ #include "envoy/upstream/cluster_manager.h" #include "extensions/filters/common/ext_authz/ext_authz.h" -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" namespace Envoy { namespace Extensions { @@ -90,7 +90,7 @@ class Filter : public Network::ReadFilter, void onBelowWriteBufferLowWatermark() override {} // ExtAuthz::RequestCallbacks - void onComplete(Filters::Common::ExtAuthz::CheckStatus status) override; + void onComplete(Filters::Common::ExtAuthz::ResponsePtr&&) override; private: // State of this filter's communication with the external authorization service. diff --git a/source/extensions/filters/network/http_connection_manager/BUILD b/source/extensions/filters/network/http_connection_manager/BUILD index 8b62933ac9b3d..29f6420e60845 100644 --- a/source/extensions/filters/network/http_connection_manager/BUILD +++ b/source/extensions/filters/network/http_connection_manager/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # L4 network filter that implements HTTP protocol handling and filtering. This filter internally # drives all of the L7 HTTP filters. # Public docs: docs/root/configuration/http_conn_man/http_conn_man.rst diff --git a/source/extensions/filters/network/http_connection_manager/config.cc b/source/extensions/filters/network/http_connection_manager/config.cc index 3b6283d69bb6e..0503e5296b9a6 100644 --- a/source/extensions/filters/network/http_connection_manager/config.cc +++ b/source/extensions/filters/network/http_connection_manager/config.cc @@ -84,7 +84,8 @@ HttpConnectionManagerFilterConfigFactory::createFilterFactoryFromProtoTyped( Network::FilterFactoryCb HttpConnectionManagerFilterConfigFactory::createFilterFactory( const Json::Object& json_config, Server::Configuration::FactoryContext& context) { envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager proto_config; - Config::FilterJson::translateHttpConnectionManager(json_config, proto_config); + Config::FilterJson::translateHttpConnectionManager(json_config, proto_config, + context.scope().statsOptions()); return createFilterFactoryFromProtoTyped(proto_config, context); } @@ -128,6 +129,9 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( route_config_provider_manager_(route_config_provider_manager), http2_settings_(Http::Utility::parseHttp2Settings(config.http2_protocol_options())), http1_settings_(Http::Utility::parseHttp1Settings(config.http_protocol_options())), + idle_timeout_(PROTOBUF_GET_OPTIONAL_MS(config, idle_timeout)), + stream_idle_timeout_( + PROTOBUF_GET_MS_OR_DEFAULT(config, stream_idle_timeout, StreamIdleTimeoutMs)), drain_timeout_(PROTOBUF_GET_MS_OR_DEFAULT(config, drain_timeout, 5000)), generate_request_id_(PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, generate_request_id, true)), date_provider_(date_provider), @@ -159,7 +163,7 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( forward_client_cert_ = Http::ForwardClientCertType::AlwaysForwardOnly; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } const auto& set_current_client_cert_details = config.set_current_client_cert_details(); @@ -169,9 +173,6 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(set_current_client_cert_details, subject, false)) { set_current_client_cert_details_.push_back(Http::ClientCertDetailsType::Subject); } - if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(set_current_client_cert_details, san, false)) { - set_current_client_cert_details_.push_back(Http::ClientCertDetailsType::SAN); - } if (set_current_client_cert_details.uri()) { set_current_client_cert_details_.push_back(Http::ClientCertDetailsType::URI); } @@ -199,7 +200,7 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( tracing_operation_name = Tracing::OperationName::Egress; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } for (const std::string& header : tracing_config.request_headers_for_tags()) { @@ -218,10 +219,6 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( overall_sampling})); } - if (config.has_idle_timeout()) { - idle_timeout_ = std::chrono::milliseconds(PROTOBUF_GET_MS_REQUIRED(config, idle_timeout)); - } - for (const auto& access_log : config.access_log()) { AccessLog::InstanceSharedPtr current_access_log = AccessLog::AccessLogFactory::fromProto(access_log, context_); @@ -245,7 +242,7 @@ HttpConnectionManagerConfig::HttpConnectionManagerConfig( codec_type_ = CodecType::HTTP2; break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } const auto& filters = config.http_filters(); @@ -323,7 +320,7 @@ HttpConnectionManagerConfig::createCodec(Network::Connection& connection, } } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void HttpConnectionManagerConfig::createFilterChain(Http::FilterChainFactoryCallbacks& callbacks) { diff --git a/source/extensions/filters/network/http_connection_manager/config.h b/source/extensions/filters/network/http_connection_manager/config.h index 9e5863b0efcb5..c4cffd2dd458a 100644 --- a/source/extensions/filters/network/http_connection_manager/config.h +++ b/source/extensions/filters/network/http_connection_manager/config.h @@ -32,7 +32,7 @@ class HttpConnectionManagerFilterConfigFactory envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager> { public: HttpConnectionManagerFilterConfigFactory() - : FactoryBase(NetworkFilterNames::get().HTTP_CONNECTION_MANAGER) {} + : FactoryBase(NetworkFilterNames::get().HttpConnectionManager) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb @@ -87,7 +87,8 @@ class HttpConnectionManagerConfig : Logger::Loggable, std::chrono::milliseconds drainTimeout() override { return drain_timeout_; } FilterChainFactory& filterFactory() override { return *this; } bool generateRequestId() override { return generate_request_id_; } - const absl::optional& idleTimeout() override { return idle_timeout_; } + absl::optional idleTimeout() const override { return idle_timeout_; } + std::chrono::milliseconds streamIdleTimeout() const override { return stream_idle_timeout_; } Router::RouteConfigProvider& routeConfigProvider() override { return *route_config_provider_; } const std::string& serverName() override { return server_name_; } Http::ConnectionManagerStats& stats() override { return stats_; } @@ -137,12 +138,16 @@ class HttpConnectionManagerConfig : Logger::Loggable, Http::TracingConnectionManagerConfigPtr tracing_config_; absl::optional user_agent_; absl::optional idle_timeout_; - Router::RouteConfigProviderSharedPtr route_config_provider_; + std::chrono::milliseconds stream_idle_timeout_; + Router::RouteConfigProviderPtr route_config_provider_; std::chrono::milliseconds drain_timeout_; bool generate_request_id_; Http::DateProvider& date_provider_; Http::ConnectionManagerListenerStats listener_stats_; const bool proxy_100_continue_; + + // Default idle timeout is 5 minutes if nothing is specified in the HCM config. + static const uint64_t StreamIdleTimeoutMs = 5 * 60 * 1000; }; } // namespace HttpConnectionManager diff --git a/source/extensions/filters/network/mongo_proxy/BUILD b/source/extensions/filters/network/mongo_proxy/BUILD index 364ec8989b539..e3e0efd147a1b 100644 --- a/source/extensions/filters/network/mongo_proxy/BUILD +++ b/source/extensions/filters/network/mongo_proxy/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Mongo proxy L4 network filter (observability and fault injection). # Public docs: docs/root/configuration/network_filters/mongo_proxy_filter.rst diff --git a/source/extensions/filters/network/mongo_proxy/bson_impl.cc b/source/extensions/filters/network/mongo_proxy/bson_impl.cc index 6f11f858e105a..de675c460c6d7 100644 --- a/source/extensions/filters/network/mongo_proxy/bson_impl.cc +++ b/source/extensions/filters/network/mongo_proxy/bson_impl.cc @@ -199,7 +199,7 @@ int32_t FieldImpl::byteSize() const { } } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void FieldImpl::encode(Buffer::Instance& output) const { @@ -252,7 +252,7 @@ void FieldImpl::encode(Buffer::Instance& output) const { return BufferHelper::writeInt32(output, value_.int32_value_); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } bool FieldImpl::operator==(const Field& rhs) const { @@ -314,7 +314,7 @@ bool FieldImpl::operator==(const Field& rhs) const { } } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } std::string FieldImpl::toString() const { @@ -362,7 +362,7 @@ std::string FieldImpl::toString() const { } } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } void DocumentImpl::fromBuffer(Buffer::Instance& data) { diff --git a/source/extensions/filters/network/mongo_proxy/config.h b/source/extensions/filters/network/mongo_proxy/config.h index 78ae8f8eb8c33..c69ce12d5b2d1 100644 --- a/source/extensions/filters/network/mongo_proxy/config.h +++ b/source/extensions/filters/network/mongo_proxy/config.h @@ -19,7 +19,7 @@ namespace MongoProxy { class MongoProxyFilterConfigFactory : public Common::FactoryBase { public: - MongoProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().MONGO_PROXY) {} + MongoProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().MongoProxy) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb diff --git a/source/extensions/filters/network/ratelimit/BUILD b/source/extensions/filters/network/ratelimit/BUILD index b72d061267cb9..78f9360153447 100644 --- a/source/extensions/filters/network/ratelimit/BUILD +++ b/source/extensions/filters/network/ratelimit/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Ratelimit L4 network filter # Public docs: docs/root/configuration/network_filters/rate_limit_filter.rst diff --git a/source/extensions/filters/network/ratelimit/config.h b/source/extensions/filters/network/ratelimit/config.h index fd68087efd4f9..19c77ddd748e5 100644 --- a/source/extensions/filters/network/ratelimit/config.h +++ b/source/extensions/filters/network/ratelimit/config.h @@ -17,7 +17,7 @@ namespace RateLimitFilter { class RateLimitConfigFactory : public Common::FactoryBase { public: - RateLimitConfigFactory() : FactoryBase(NetworkFilterNames::get().RATE_LIMIT) {} + RateLimitConfigFactory() : FactoryBase(NetworkFilterNames::get().RateLimit) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb diff --git a/source/extensions/filters/network/redis_proxy/BUILD b/source/extensions/filters/network/redis_proxy/BUILD index 44297a2d449ea..7dafa10e3da5a 100644 --- a/source/extensions/filters/network/redis_proxy/BUILD +++ b/source/extensions/filters/network/redis_proxy/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Redis proxy L4 network filter. Implements consistent hashing and observability for large redis # clusters. # Public docs: docs/root/configuration/network_filters/redis_proxy_filter.rst diff --git a/source/extensions/filters/network/redis_proxy/codec_impl.cc b/source/extensions/filters/network/redis_proxy/codec_impl.cc index 15ac012172774..20649b4927ba1 100644 --- a/source/extensions/filters/network/redis_proxy/codec_impl.cc +++ b/source/extensions/filters/network/redis_proxy/codec_impl.cc @@ -35,7 +35,7 @@ std::string RespValue::toString() const { return std::to_string(asInteger()); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } std::vector& RespValue::asArray() { diff --git a/source/extensions/filters/network/redis_proxy/config.h b/source/extensions/filters/network/redis_proxy/config.h index a499edc6acd1f..51562452b0bb7 100644 --- a/source/extensions/filters/network/redis_proxy/config.h +++ b/source/extensions/filters/network/redis_proxy/config.h @@ -19,7 +19,7 @@ namespace RedisProxy { class RedisProxyFilterConfigFactory : public Common::FactoryBase { public: - RedisProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().REDIS_PROXY) {} + RedisProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().RedisProxy) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb diff --git a/source/extensions/filters/network/tcp_proxy/BUILD b/source/extensions/filters/network/tcp_proxy/BUILD index aed1bd07121d9..5767776482380 100644 --- a/source/extensions/filters/network/tcp_proxy/BUILD +++ b/source/extensions/filters/network/tcp_proxy/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # TCP proxy L4 network filter. # Public docs: docs/root/configuration/network_filters/tcp_proxy_filter.rst diff --git a/source/extensions/filters/network/tcp_proxy/config.h b/source/extensions/filters/network/tcp_proxy/config.h index 8e7deba8fc12d..e5664ed45a7f5 100644 --- a/source/extensions/filters/network/tcp_proxy/config.h +++ b/source/extensions/filters/network/tcp_proxy/config.h @@ -16,7 +16,7 @@ namespace TcpProxy { class ConfigFactory : public Common::FactoryBase { public: - ConfigFactory() : FactoryBase(NetworkFilterNames::get().TCP_PROXY) {} + ConfigFactory() : FactoryBase(NetworkFilterNames::get().TcpProxy) {} // NamedNetworkFilterConfigFactory Network::FilterFactoryCb diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 898824667a100..05264d60abac1 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -8,6 +8,17 @@ load( envoy_package() +envoy_cc_library( + name = "app_exception_lib", + srcs = ["app_exception_impl.cc"], + hdrs = ["app_exception_impl.h"], + deps = [ + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "buffer_helper_lib", srcs = ["buffer_helper.cc"], @@ -24,12 +35,43 @@ envoy_cc_library( srcs = ["config.cc"], hdrs = ["config.h"], deps = [ - ":filter_lib", + ":conn_manager_lib", + ":decoder_lib", + ":protocol_lib", "//include/envoy/registry", - "//source/common/config:filter_json_lib", + "//source/common/config:utility_lib", "//source/extensions/filters/network:well_known_names", "//source/extensions/filters/network/common:factory_base_lib", - "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1:thrift_proxy_cc", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", + "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1:thrift_proxy_cc", + ], +) + +envoy_cc_library( + name = "conn_manager_lib", + srcs = ["conn_manager.cc"], + hdrs = ["conn_manager.h"], + deps = [ + ":app_exception_lib", + ":decoder_lib", + ":protocol_converter_lib", + ":protocol_lib", + ":stats_lib", + ":transport_lib", + "//include/envoy/event:deferred_deletable", + "//include/envoy/event:dispatcher_interface", + "//include/envoy/network:connection_interface", + "//include/envoy/network:filter_interface", + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:timespan", + "//source/common/buffer:buffer_lib", + "//source/common/common:assert_lib", + "//source/common/common:linked_object", + "//source/common/common:logger_lib", + "//source/common/network:filter_lib", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", ], ) @@ -39,57 +81,99 @@ envoy_cc_library( hdrs = ["decoder.h"], deps = [ ":protocol_lib", + ":stats_lib", ":transport_lib", "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) envoy_cc_library( - name = "filter_lib", - srcs = ["filter.cc"], - hdrs = ["filter.h"], + name = "protocol_converter_lib", + hdrs = [ + "protocol_converter.h", + ], deps = [ - ":decoder_lib", - "//include/envoy/network:connection_interface", - "//include/envoy/network:filter_interface", - "//include/envoy/stats:stats_interface", - "//include/envoy/stats:stats_macros", - "//include/envoy/stats:timespan", - "//source/common/buffer:buffer_lib", + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + +envoy_cc_library( + name = "protocol_interface", + hdrs = [ + "protocol.h", + ], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", "//source/common/common:assert_lib", - "//source/common/common:logger_lib", - "//source/common/network:filter_lib", + "//source/common/config:utility_lib", + "//source/common/singleton:const_singleton", ], ) envoy_cc_library( name = "protocol_lib", srcs = [ - "binary_protocol.cc", - "compact_protocol.cc", - "protocol.cc", + "binary_protocol_impl.cc", + "compact_protocol_impl.cc", + "protocol_impl.cc", ], hdrs = [ - "binary_protocol.h", - "compact_protocol.h", - "protocol.h", + "binary_protocol_impl.h", + "compact_protocol_impl.h", + "protocol_impl.h", ], external_deps = ["abseil_optional"], deps = [ ":buffer_helper_lib", + ":protocol_interface", "//source/common/singleton:const_singleton", ], ) envoy_cc_library( - name = "transport_lib", - srcs = ["transport.cc"], + name = "stats_lib", + hdrs = ["stats.h"], + deps = [ + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:stats_macros", + ], +) + +envoy_cc_library( + name = "transport_interface", hdrs = ["transport.h"], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/common/config:utility_lib", + "//source/common/singleton:const_singleton", + ], +) + +envoy_cc_library( + name = "transport_lib", + srcs = [ + "framed_transport_impl.cc", + "transport_impl.cc", + "unframed_transport_impl.cc", + ], + hdrs = [ + "framed_transport_impl.h", + "transport_impl.h", + "unframed_transport_impl.h", + ], deps = [ ":buffer_helper_lib", ":protocol_lib", + ":transport_interface", "//source/common/common:assert_lib", - "//source/common/common:utility_lib", "//source/common/singleton:const_singleton", ], ) diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc new file mode 100644 index 0000000000000..65455c12b3609 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +static const std::string TApplicationException = "TApplicationException"; +static const std::string MessageField = "message"; +static const std::string TypeField = "type"; +static const std::string StopField = ""; + +void AppException::encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) { + proto.writeMessageBegin(buffer, method_name_, ThriftProxy::MessageType::Exception, seq_id_); + proto.writeStructBegin(buffer, TApplicationException); + + proto.writeFieldBegin(buffer, MessageField, ThriftProxy::FieldType::String, 1); + proto.writeString(buffer, error_message_); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, TypeField, ThriftProxy::FieldType::I32, 2); + proto.writeInt32(buffer, static_cast(type_)); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, StopField, ThriftProxy::FieldType::Stop, 0); + + proto.writeStructEnd(buffer); + proto.writeMessageEnd(buffer); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.h b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h new file mode 100644 index 0000000000000..4a0335704100a --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h @@ -0,0 +1,44 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Thrift Application Exception types. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +enum class AppExceptionType { + Unknown = 0, + UnknownMethod = 1, + InvalidMessageType = 2, + WrongMethodName = 3, + BadSequenceId = 4, + MissingResult = 5, + InternalError = 6, + ProtocolError = 7, + InvalidTransform = 8, + InvalidProtocol = 9, + UnsupportedClientType = 10, +}; + +struct AppException : public ThriftFilters::DirectResponse { + AppException(const absl::string_view method_name, int32_t seq_id, AppExceptionType type, + const std::string& error_message) + : method_name_(method_name), seq_id_(seq_id), type_(type), error_message_(error_message) {} + + void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) override; + + const std::string method_name_; + const int32_t seq_id_; + const AppExceptionType type_; + const std::string error_message_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol.cc b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc similarity index 61% rename from source/extensions/filters/network/thrift_proxy/binary_protocol.cc rename to source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc index bf9e05dd1235c..ee5a734eda209 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol.cc +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc @@ -1,4 +1,6 @@ -#include "extensions/filters/network/thrift_proxy/binary_protocol.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" + +#include #include "envoy/common/exception.h" @@ -58,26 +60,22 @@ bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& msg_type = type; seq_id = BufferHelper::drainI32(buffer); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool BinaryProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } bool BinaryProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& name) { UNREFERENCED_PARAMETER(buffer); name.clear(); // binary protocol does not transmit struct names - onStructBegin(absl::string_view(name)); return true; } bool BinaryProtocolImpl::readStructEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onStructEnd(); return true; } @@ -97,14 +95,17 @@ bool BinaryProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& n if (buffer.length() < 3) { return false; } - field_id = BufferHelper::peekI16(buffer, 1); + int16_t id = BufferHelper::peekI16(buffer, 1); + if (id < 0) { + throw EnvoyException(fmt::format("invalid binary protocol field id {}", id)); + } + field_id = id; buffer.drain(3); } name.clear(); // binary protocol does not transmit field names field_type = type; - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -261,6 +262,106 @@ bool BinaryProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& value return readString(buffer, value); } +void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, + MessageType msg_type, int32_t seq_id) { + BufferHelper::writeU16(buffer, Magic); + BufferHelper::writeU16(buffer, static_cast(msg_type)); + writeString(buffer, name); + BufferHelper::writeI32(buffer, seq_id); +} + +void BinaryProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(buffer); +} + +void BinaryProtocolImpl::writeStructBegin(Buffer::Instance& buffer, const std::string& name) { + UNREFERENCED_PARAMETER(buffer); + UNREFERENCED_PARAMETER(name); +} + +void BinaryProtocolImpl::writeStructEnd(Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(buffer); +} + +void BinaryProtocolImpl::writeFieldBegin(Buffer::Instance& buffer, const std::string& name, + FieldType field_type, int16_t field_id) { + UNREFERENCED_PARAMETER(name); + + BufferHelper::writeI8(buffer, static_cast(field_type)); + if (field_type == FieldType::Stop) { + return; + } + + BufferHelper::writeI16(buffer, field_id); +} + +void BinaryProtocolImpl::writeFieldEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void BinaryProtocolImpl::writeMapBegin(Buffer::Instance& buffer, FieldType key_type, + FieldType value_type, uint32_t size) { + if (size > std::numeric_limits::max()) { + throw EnvoyException(fmt::format("illegal binary protocol map size {}", size)); + } + + BufferHelper::writeI8(buffer, static_cast(key_type)); + BufferHelper::writeI8(buffer, static_cast(value_type)); + BufferHelper::writeI32(buffer, static_cast(size)); +} + +void BinaryProtocolImpl::writeMapEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void BinaryProtocolImpl::writeListBegin(Buffer::Instance& buffer, FieldType elem_type, + uint32_t size) { + if (size > std::numeric_limits::max()) { + throw EnvoyException(fmt::format("illegal binary protocol list/set size {}", size)); + } + + BufferHelper::writeI8(buffer, static_cast(elem_type)); + BufferHelper::writeI32(buffer, static_cast(size)); +} + +void BinaryProtocolImpl::writeListEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void BinaryProtocolImpl::writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, + uint32_t size) { + writeListBegin(buffer, elem_type, size); +} + +void BinaryProtocolImpl::writeSetEnd(Buffer::Instance& buffer) { writeListEnd(buffer); } + +void BinaryProtocolImpl::writeBool(Buffer::Instance& buffer, bool value) { + BufferHelper::writeI8(buffer, value ? 1 : 0); +} + +void BinaryProtocolImpl::writeByte(Buffer::Instance& buffer, uint8_t value) { + BufferHelper::writeI8(buffer, value); +} + +void BinaryProtocolImpl::writeInt16(Buffer::Instance& buffer, int16_t value) { + BufferHelper::writeI16(buffer, value); +} + +void BinaryProtocolImpl::writeInt32(Buffer::Instance& buffer, int32_t value) { + BufferHelper::writeI32(buffer, value); +} + +void BinaryProtocolImpl::writeInt64(Buffer::Instance& buffer, int64_t value) { + BufferHelper::writeI64(buffer, value); +} + +void BinaryProtocolImpl::writeDouble(Buffer::Instance& buffer, double value) { + BufferHelper::writeDouble(buffer, value); +} + +void BinaryProtocolImpl::writeString(Buffer::Instance& buffer, const std::string& value) { + BufferHelper::writeU32(buffer, value.length()); + buffer.add(value); +} + +void BinaryProtocolImpl::writeBinary(Buffer::Instance& buffer, const std::string& value) { + writeString(buffer, value); +} + bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) { // Minimum message length: @@ -296,10 +397,37 @@ bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::stri seq_id = BufferHelper::peekI32(buffer, 1); buffer.drain(5); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } +void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, + MessageType msg_type, int32_t seq_id) { + writeString(buffer, name); + BufferHelper::writeI8(buffer, static_cast(msg_type)); + BufferHelper::writeI32(buffer, seq_id); +} + +class BinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + BinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().BINARY) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + +class LaxBinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + LaxBinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().LAX_BINARY) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_lax_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol.h b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h similarity index 58% rename from source/extensions/filters/network/thrift_proxy/binary_protocol.h rename to source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h index 520a33e48bad7..e292d6cd036b9 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol.h +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h @@ -5,7 +5,7 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" -#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" namespace Envoy { namespace Extensions { @@ -16,12 +16,13 @@ namespace ThriftProxy { * BinaryProtocolImpl implements the Thrift Binary protocol with strict message encoding. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md */ -class BinaryProtocolImpl : public ProtocolImplBase { +class BinaryProtocolImpl : public Protocol { public: - BinaryProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + BinaryProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().BINARY; } + ProtocolType type() const override { return ProtocolType::Binary; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; @@ -45,6 +46,29 @@ class BinaryProtocolImpl : public ProtocolImplBase { bool readDouble(Buffer::Instance& buffer, double& value) override; bool readString(Buffer::Instance& buffer, std::string& value) override; bool readBinary(Buffer::Instance& buffer, std::string& value) override; + void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, + int32_t seq_id) override; + void writeMessageEnd(Buffer::Instance& buffer) override; + void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override; + void writeStructEnd(Buffer::Instance& buffer) override; + void writeFieldBegin(Buffer::Instance& buffer, const std::string& name, FieldType field_type, + int16_t field_id) override; + void writeFieldEnd(Buffer::Instance& buffer) override; + void writeMapBegin(Buffer::Instance& buffer, FieldType key_type, FieldType value_type, + uint32_t size) override; + void writeMapEnd(Buffer::Instance& buffer) override; + void writeListBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override; + void writeListEnd(Buffer::Instance& buffer) override; + void writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override; + void writeSetEnd(Buffer::Instance& buffer) override; + void writeBool(Buffer::Instance& buffer, bool value) override; + void writeByte(Buffer::Instance& buffer, uint8_t value) override; + void writeInt16(Buffer::Instance& buffer, int16_t value) override; + void writeInt32(Buffer::Instance& buffer, int32_t value) override; + void writeInt64(Buffer::Instance& buffer, int64_t value) override; + void writeDouble(Buffer::Instance& buffer, double value) override; + void writeString(Buffer::Instance& buffer, const std::string& value) override; + void writeBinary(Buffer::Instance& buffer, const std::string& value) override; static bool isMagic(uint16_t word) { return word == Magic; } @@ -58,12 +82,14 @@ class BinaryProtocolImpl : public ProtocolImplBase { */ class LaxBinaryProtocolImpl : public BinaryProtocolImpl { public: - LaxBinaryProtocolImpl(ProtocolCallbacks& callbacks) : BinaryProtocolImpl(callbacks) {} + LaxBinaryProtocolImpl() {} const std::string& name() const override { return ProtocolNames::get().LAX_BINARY; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; + void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, + int32_t seq_id) override; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/buffer_helper.cc b/source/extensions/filters/network/thrift_proxy/buffer_helper.cc index 5badb4e1001eb..d8724af4b1181 100644 --- a/source/extensions/filters/network/thrift_proxy/buffer_helper.cc +++ b/source/extensions/filters/network/thrift_proxy/buffer_helper.cc @@ -221,6 +221,95 @@ int32_t BufferHelper::peekZigZagI32(Buffer::Instance& buffer, uint64_t offset, i return (zz32 >> 1) ^ static_cast(-static_cast(zz32 & 1)); } +void BufferHelper::writeI8(Buffer::Instance& buffer, int8_t value) { buffer.add(&value, 1); } + +void BufferHelper::writeI16(Buffer::Instance& buffer, int16_t value) { + value = htobe16(value); + buffer.add(&value, 2); +} + +void BufferHelper::writeU16(Buffer::Instance& buffer, uint16_t value) { + value = htobe16(value); + buffer.add(&value, 2); +} + +void BufferHelper::writeI32(Buffer::Instance& buffer, int32_t value) { + value = htobe32(value); + buffer.add(&value, 4); +} + +void BufferHelper::writeU32(Buffer::Instance& buffer, uint32_t value) { + value = htobe32(value); + buffer.add(&value, 4); +} + +void BufferHelper::writeI64(Buffer::Instance& buffer, int64_t value) { + value = htobe64(value); + buffer.add(&value, 8); +} + +void BufferHelper::writeDouble(Buffer::Instance& buffer, double value) { + static_assert(sizeof(double) == sizeof(uint64_t), "sizeof(double) != sizeof(uint64_t)"); + static_assert(std::numeric_limits::is_iec559, "non-IEC559 (IEEE 754) double"); + + // See drainDouble for implementation details. + uint64_t i; + std::memcpy(&i, &value, 8); + i = htobe64(i); + buffer.add(&i, 8); +} + +// Thrift's var int encoding is described in +// https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md +void BufferHelper::writeVarIntI32(Buffer::Instance& buffer, int32_t value) { + uint8_t bytes[5]; + uint32_t v = static_cast(value); + int pos = 0; + while (pos < 5) { + if ((v & ~0x7F) == 0) { + bytes[pos++] = static_cast(v); + break; + } + + bytes[pos++] = static_cast(v & 0x7F) | 0x80; + v >>= 7; + } + ASSERT(v < 0x80); + ASSERT(pos <= 5); + + buffer.add(bytes, pos); +} + +void BufferHelper::writeVarIntI64(Buffer::Instance& buffer, int64_t value) { + uint8_t bytes[10]; + uint64_t v = static_cast(value); + int pos = 0; + while (pos < 10) { + if ((v & ~0x7F) == 0) { + bytes[pos++] = static_cast(v); + break; + } + + bytes[pos++] = static_cast(v & 0x7F) | 0x80; + v >>= 7; + } + + ASSERT(v < 0x80); + ASSERT(pos <= 10); + + buffer.add(bytes, pos); +} + +void BufferHelper::writeZigZagI32(Buffer::Instance& buffer, int32_t value) { + uint32_t zz32 = (static_cast(value) << 1) ^ (value >> 31); + writeVarIntI32(buffer, zz32); +} + +void BufferHelper::writeZigZagI64(Buffer::Instance& buffer, int64_t value) { + uint64_t zz64 = (static_cast(value) << 1) ^ (value >> 63); + writeVarIntI64(buffer, zz64); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/buffer_helper.h b/source/extensions/filters/network/thrift_proxy/buffer_helper.h index 162f9cafa9810..c4945cd5da6b5 100644 --- a/source/extensions/filters/network/thrift_proxy/buffer_helper.h +++ b/source/extensions/filters/network/thrift_proxy/buffer_helper.h @@ -10,52 +10,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/** - * BufferWrapper provides a partial implementation of Buffer::Instance that is sufficient for - * BufferHelper to read Thrift protocol data without draining the buffer's contents. - */ -class BufferWrapper : public Buffer::Instance { -public: - BufferWrapper(Buffer::Instance& underlying) : underlying_(underlying) {} - - uint64_t position() { return position_; } - - // Buffer::Instance - void copyOut(size_t start, uint64_t size, void* data) const override { - ASSERT(position_ + start + size <= underlying_.length()); - underlying_.copyOut(start + position_, size, data); - } - void drain(uint64_t size) override { - ASSERT(position_ + size <= underlying_.length()); - position_ += size; - } - uint64_t length() const override { - ASSERT(underlying_.length() >= position_); - return underlying_.length() - position_; - } - void* linearize(uint32_t size) override { - ASSERT(position_ + size <= underlying_.length()); - uint8_t* p = static_cast(underlying_.linearize(position_ + size)); - return p + position_; - } - void add(const void*, uint64_t) override { NOT_IMPLEMENTED; } - void addBufferFragment(Buffer::BufferFragment&) override { NOT_IMPLEMENTED; } - void add(const std::string&) override { NOT_IMPLEMENTED; } - void add(const Buffer::Instance&) override { NOT_IMPLEMENTED; } - void commit(Buffer::RawSlice*, uint64_t) override { NOT_IMPLEMENTED; } - uint64_t getRawSlices(Buffer::RawSlice*, uint64_t) const override { NOT_IMPLEMENTED; } - void move(Buffer::Instance&) override { NOT_IMPLEMENTED; } - void move(Buffer::Instance&, uint64_t) override { NOT_IMPLEMENTED; } - int read(int, uint64_t) override { NOT_IMPLEMENTED; } - uint64_t reserve(uint64_t, Buffer::RawSlice*, uint64_t) override { NOT_IMPLEMENTED; } - ssize_t search(const void*, uint64_t, size_t) const override { NOT_IMPLEMENTED; } - int write(int) override { NOT_IMPLEMENTED; } - -private: - Buffer::Instance& underlying_; - uint64_t position_{0}; -}; - /** * BufferHelper provides buffer operations for reading bytes and numbers in the various encodings * used by Thrift protocols. @@ -211,6 +165,83 @@ class BufferHelper { */ static int32_t peekZigZagI32(Buffer::Instance& buffer, uint64_t offset, int& size); + /** + * Writes an int8_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int8_t to write + */ + static void writeI8(Buffer::Instance& buffer, int8_t value); + + /** + * Writes an int16_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int16_t to write + */ + static void writeI16(Buffer::Instance& buffer, int16_t value); + + /** + * Writes an uint16_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the uint16_t to write + */ + static void writeU16(Buffer::Instance& buffer, uint16_t value); + + /** + * Writes an int32_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int32_t to write + */ + static void writeI32(Buffer::Instance& buffer, int32_t value); + + /** + * Writes an uint32_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the uint32_t to write + */ + static void writeU32(Buffer::Instance& buffer, uint32_t value); + + /** + * Writes an int64_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int64_t to write + */ + static void writeI64(Buffer::Instance& buffer, int64_t value); + + /** + * Writes a double to the buffer. + * @param buffer Buffer::Instance written to + * @param value the double to write + */ + static void writeDouble(Buffer::Instance& buffer, double value); + + /** + * Writes a var-int encoded int32_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int32_t to write + */ + static void writeVarIntI32(Buffer::Instance& buffer, int32_t value); + + /** + * Writes a var-int encoded int64_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int64_t to write + */ + static void writeVarIntI64(Buffer::Instance& buffer, int64_t value); + + /** + * Writes a zig-zag encoded int32_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int32_t to write + */ + static void writeZigZagI32(Buffer::Instance& buffer, int32_t value); + + /** + * Writes a zig-zag encoded int64_t to the buffer. + * @param buffer Buffer::Instance written to + * @param value the int64_t to write + */ + static void writeZigZagI64(Buffer::Instance& buffer, int64_t value); + private: /** * Peeks at a variable-length int of up to 64 bits at offset. Updates size to indicate how many diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol.cc b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc similarity index 53% rename from source/extensions/filters/network/thrift_proxy/compact_protocol.cc rename to source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc index 527be64b76313..417a80d8b6197 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol.cc +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc @@ -1,4 +1,6 @@ -#include "extensions/filters/network/thrift_proxy/compact_protocol.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" + +#include #include "envoy/common/exception.h" @@ -70,13 +72,11 @@ bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string msg_type = type; seq_id = id; - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool CompactProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } @@ -89,7 +89,6 @@ bool CompactProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& last_field_id_stack_.push(last_field_id_); last_field_id_ = 0; - onStructBegin(absl::string_view(name)); return true; } @@ -103,7 +102,6 @@ bool CompactProtocolImpl::readStructEnd(Buffer::Instance& buffer) { last_field_id_ = last_field_id_stack_.top(); last_field_id_stack_.pop(); - onStructEnd(); return true; } @@ -122,12 +120,11 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& field_type = FieldType::Stop; buffer.drain(1); - onStructField(absl::string_view(name), field_type, field_id); return true; } int16_t compact_field_id; - uint8_t compact_field_type; + CompactFieldType compact_field_type; int id_size = 0; if ((delta_and_type >> 4) == 0) { // Field ID delta is zero: this is a long-form field header, followed by zig-zag field id. @@ -140,22 +137,22 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& return false; } - if (id <= 0 || id > INT16_MAX) { + if (id < 0 || id > std::numeric_limits::max()) { throw EnvoyException(fmt::format("invalid compact protocol field id {}", id)); } - compact_field_type = delta_and_type; + compact_field_type = static_cast(delta_and_type); compact_field_id = static_cast(id); } else { // Short form field header: 4 bits of field id delta, 4 bits of field type. - compact_field_type = delta_and_type & 0x0F; + compact_field_type = static_cast(delta_and_type & 0x0F); compact_field_id = last_field_id_ + static_cast(delta_and_type >> 4); } field_type = convertCompactFieldType(compact_field_type); // For simple fields, boolean values are transmitted as a type with no further data. if (field_type == FieldType::Bool) { - bool_value_ = compact_field_type == 1; + bool_value_ = compact_field_type == CompactFieldType::BoolTrue; } name.clear(); // compact protocol does not transmit field names @@ -164,43 +161,9 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& buffer.drain(id_size + 1); - onStructField(absl::string_view(name), field_type, field_id); return true; } -FieldType CompactProtocolImpl::convertCompactFieldType(uint8_t compact_field_type) { - switch (compact_field_type) { - case 0: - return FieldType::Stop; - case 1: - return FieldType::Bool; - case 2: - return FieldType::Bool; - case 3: - return FieldType::Byte; - case 4: - return FieldType::I16; - case 5: - return FieldType::I32; - case 6: - return FieldType::I64; - case 7: - return FieldType::Double; - case 8: - return FieldType::String; - case 9: - return FieldType::List; - case 10: - return FieldType::Set; - case 11: - return FieldType::Map; - case 12: - return FieldType::Struct; - default: - throw EnvoyException(fmt::format("unknown compact protocol field type {}", compact_field_type)); - } -} - bool CompactProtocolImpl::readFieldEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); bool_value_.reset(); @@ -232,8 +195,8 @@ bool CompactProtocolImpl::readMapBegin(Buffer::Instance& buffer, FieldType& key_ } uint8_t types = BufferHelper::peekI8(buffer, s_size); - FieldType ktype = convertCompactFieldType(types >> 4); - FieldType vtype = convertCompactFieldType(types & 0xF); + FieldType ktype = convertCompactFieldType(static_cast(types >> 4)); + FieldType vtype = convertCompactFieldType(static_cast(types & 0xF)); // Drain the size and the types byte. buffer.drain(s_size + 1); @@ -278,7 +241,7 @@ bool CompactProtocolImpl::readListBegin(Buffer::Instance& buffer, FieldType& ele sz = static_cast(s); } - elem_type = convertCompactFieldType(size_and_type & 0x0F); + elem_type = convertCompactFieldType(static_cast(size_and_type & 0x0F)); size = sz; buffer.drain(s_size + 1); @@ -332,7 +295,7 @@ bool CompactProtocolImpl::readInt16(Buffer::Instance& buffer, int16_t& value) { return false; } - if (i < INT16_MIN || i > INT16_MAX) { + if (i < std::numeric_limits::min() || i > std::numeric_limits::max()) { throw EnvoyException(fmt::format("compact protocol i16 exceeds allowable range {}", i)); } @@ -390,7 +353,7 @@ bool CompactProtocolImpl::readString(Buffer::Instance& buffer, std::string& valu } int len_size; - int32_t str_len = BufferHelper::peekZigZagI32(buffer, 0, len_size); + int32_t str_len = BufferHelper::peekVarIntI32(buffer, 0, len_size); if (len_size < 0) { return false; } @@ -419,6 +382,252 @@ bool CompactProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& valu return readString(buffer, value); } +void CompactProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, + MessageType msg_type, int32_t seq_id) { + UNREFERENCED_PARAMETER(name); + + uint16_t ptv = (Magic & MagicMask) | (static_cast(msg_type) << 5); + ASSERT((ptv & MagicMask) == Magic); + ASSERT((ptv & ~MagicMask) >> 5 == static_cast(msg_type)); + + BufferHelper::writeU16(buffer, ptv); + BufferHelper::writeVarIntI32(buffer, seq_id); + writeString(buffer, name); +} + +void CompactProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(buffer); +} + +void CompactProtocolImpl::writeStructBegin(Buffer::Instance& buffer, const std::string& name) { + UNREFERENCED_PARAMETER(buffer); + UNREFERENCED_PARAMETER(name); + + // Field ids are encoded as deltas specific to the field's containing struct. Field ids are + // tracked in a stack to handle nested structs. + last_field_id_stack_.push(last_field_id_); + last_field_id_ = 0; +} + +void CompactProtocolImpl::writeStructEnd(Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(buffer); + + if (last_field_id_stack_.empty()) { + throw EnvoyException("invalid write of compact protocol struct end"); + } + + last_field_id_ = last_field_id_stack_.top(); + last_field_id_stack_.pop(); +} + +void CompactProtocolImpl::writeFieldBegin(Buffer::Instance& buffer, const std::string& name, + FieldType field_type, int16_t field_id) { + UNREFERENCED_PARAMETER(name); + + if (field_type == FieldType::Stop) { + BufferHelper::writeI8(buffer, 0); + return; + } + + if (field_type == FieldType::Bool) { + bool_field_id_ = field_id; + return; + } + + writeFieldBeginInternal(buffer, field_type, field_id, {}); +} + +void CompactProtocolImpl::writeFieldBeginInternal( + Buffer::Instance& buffer, FieldType field_type, int16_t field_id, + absl::optional field_type_override) { + CompactFieldType compact_field_type; + if (field_type_override.has_value()) { + compact_field_type = field_type_override.value(); + } else { + compact_field_type = convertFieldType(field_type); + } + + if (field_id > last_field_id_ && field_id - last_field_id_ <= 15) { + // Encode short-form field header. + BufferHelper::writeI8(buffer, (static_cast(field_id - last_field_id_) << 4) | + static_cast(compact_field_type)); + } else { + BufferHelper::writeI8(buffer, static_cast(compact_field_type)); + BufferHelper::writeZigZagI32(buffer, static_cast(field_id)); + } + + last_field_id_ = field_id; +} + +void CompactProtocolImpl::writeFieldEnd(Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(buffer); + + bool_field_id_.reset(); +} + +void CompactProtocolImpl::writeMapBegin(Buffer::Instance& buffer, FieldType key_type, + FieldType value_type, uint32_t size) { + if (size > std::numeric_limits::max()) { + throw EnvoyException(fmt::format("illegal compact protocol map size {}", size)); + } + + BufferHelper::writeVarIntI32(buffer, static_cast(size)); + if (size == 0) { + return; + } + + CompactFieldType compact_key_type = convertFieldType(key_type); + CompactFieldType compact_value_type = convertFieldType(value_type); + BufferHelper::writeI8(buffer, (static_cast(compact_key_type) << 4) | + static_cast(compact_value_type)); +} + +void CompactProtocolImpl::writeMapEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void CompactProtocolImpl::writeListBegin(Buffer::Instance& buffer, FieldType elem_type, + uint32_t size) { + if (size > std::numeric_limits::max()) { + throw EnvoyException(fmt::format("illegal compact protocol list/set size {}", size)); + } + + CompactFieldType compact_elem_type = convertFieldType(elem_type); + + if (size < 0xF) { + // Short form list/set header + int8_t short_size = static_cast(size & 0xF); + BufferHelper::writeI8(buffer, (short_size << 4) | static_cast(compact_elem_type)); + } else { + BufferHelper::writeI8(buffer, 0xF0 | static_cast(compact_elem_type)); + BufferHelper::writeVarIntI32(buffer, static_cast(size)); + } +} + +void CompactProtocolImpl::writeListEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void CompactProtocolImpl::writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, + uint32_t size) { + writeListBegin(buffer, elem_type, size); +} + +void CompactProtocolImpl::writeSetEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); } + +void CompactProtocolImpl::writeBool(Buffer::Instance& buffer, bool value) { + if (bool_field_id_.has_value()) { + // Boolean fields have their value encoded by type. + CompactFieldType bool_field_type = + value ? CompactFieldType::BoolTrue : CompactFieldType::BoolFalse; + writeFieldBeginInternal(buffer, FieldType::Bool, bool_field_id_.value(), {bool_field_type}); + return; + } + + // Map/Set/List booleans are encoded as bytes. + BufferHelper::writeI8(buffer, value ? 1 : 0); +} + +void CompactProtocolImpl::writeByte(Buffer::Instance& buffer, uint8_t value) { + BufferHelper::writeI8(buffer, value); +} + +void CompactProtocolImpl::writeInt16(Buffer::Instance& buffer, int16_t value) { + int32_t extended = static_cast(value); + BufferHelper::writeZigZagI32(buffer, extended); +} + +void CompactProtocolImpl::writeInt32(Buffer::Instance& buffer, int32_t value) { + BufferHelper::writeZigZagI32(buffer, value); +} + +void CompactProtocolImpl::writeInt64(Buffer::Instance& buffer, int64_t value) { + BufferHelper::writeZigZagI64(buffer, value); +} + +void CompactProtocolImpl::writeDouble(Buffer::Instance& buffer, double value) { + BufferHelper::writeDouble(buffer, value); +} + +void CompactProtocolImpl::writeString(Buffer::Instance& buffer, const std::string& value) { + BufferHelper::writeVarIntI32(buffer, value.length()); + buffer.add(value); +} + +void CompactProtocolImpl::writeBinary(Buffer::Instance& buffer, const std::string& value) { + writeString(buffer, value); +} + +FieldType CompactProtocolImpl::convertCompactFieldType(CompactFieldType compact_field_type) { + switch (compact_field_type) { + case CompactFieldType::BoolTrue: + return FieldType::Bool; + case CompactFieldType::BoolFalse: + return FieldType::Bool; + case CompactFieldType::Byte: + return FieldType::Byte; + case CompactFieldType::I16: + return FieldType::I16; + case CompactFieldType::I32: + return FieldType::I32; + case CompactFieldType::I64: + return FieldType::I64; + case CompactFieldType::Double: + return FieldType::Double; + case CompactFieldType::String: + return FieldType::String; + case CompactFieldType::List: + return FieldType::List; + case CompactFieldType::Set: + return FieldType::Set; + case CompactFieldType::Map: + return FieldType::Map; + case CompactFieldType::Struct: + return FieldType::Struct; + default: + throw EnvoyException(fmt::format("unknown compact protocol field type {}", + static_cast(compact_field_type))); + } +} + +CompactProtocolImpl::CompactFieldType CompactProtocolImpl::convertFieldType(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + // c.f. special handling in writeFieldBegin + return CompactFieldType::BoolTrue; + case FieldType::Byte: + return CompactFieldType::Byte; + case FieldType::I16: + return CompactFieldType::I16; + case FieldType::I32: + return CompactFieldType::I32; + case FieldType::I64: + return CompactFieldType::I64; + case FieldType::Double: + return CompactFieldType::Double; + case FieldType::String: + return CompactFieldType::String; + case FieldType::Struct: + return CompactFieldType::Struct; + case FieldType::Map: + return CompactFieldType::Map; + case FieldType::Set: + return CompactFieldType::Set; + case FieldType::List: + return CompactFieldType::List; + default: + throw EnvoyException( + fmt::format("unknown protocol field type {}", static_cast(field_type))); + } +} + +class CompactProtocolConfigFactory : public ProtocolFactoryBase { +public: + CompactProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().COMPACT) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol.h b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h similarity index 51% rename from source/extensions/filters/network/thrift_proxy/compact_protocol.h rename to source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h index 91d394999fe1d..322d03a3a83da 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol.h +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h @@ -6,7 +6,7 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" -#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" #include "absl/types/optional.h" @@ -19,12 +19,13 @@ namespace ThriftProxy { * CompactProtocolImpl implements the Thrift Compact protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md */ -class CompactProtocolImpl : public ProtocolImplBase { +class CompactProtocolImpl : public Protocol { public: - CompactProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + CompactProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().COMPACT; } + ProtocolType type() const override { return ProtocolType::Compact; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; @@ -48,12 +49,54 @@ class CompactProtocolImpl : public ProtocolImplBase { bool readDouble(Buffer::Instance& buffer, double& value) override; bool readString(Buffer::Instance& buffer, std::string& value) override; bool readBinary(Buffer::Instance& buffer, std::string& value) override; + void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, + int32_t seq_id) override; + void writeMessageEnd(Buffer::Instance& buffer) override; + void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override; + void writeStructEnd(Buffer::Instance& buffer) override; + void writeFieldBegin(Buffer::Instance& buffer, const std::string& name, FieldType field_type, + int16_t field_id) override; + void writeFieldEnd(Buffer::Instance& buffer) override; + void writeMapBegin(Buffer::Instance& buffer, FieldType key_type, FieldType value_type, + uint32_t size) override; + void writeMapEnd(Buffer::Instance& buffer) override; + void writeListBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override; + void writeListEnd(Buffer::Instance& buffer) override; + void writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override; + void writeSetEnd(Buffer::Instance& buffer) override; + void writeBool(Buffer::Instance& buffer, bool value) override; + void writeByte(Buffer::Instance& buffer, uint8_t value) override; + void writeInt16(Buffer::Instance& buffer, int16_t value) override; + void writeInt32(Buffer::Instance& buffer, int32_t value) override; + void writeInt64(Buffer::Instance& buffer, int64_t value) override; + void writeDouble(Buffer::Instance& buffer, double value) override; + void writeString(Buffer::Instance& buffer, const std::string& value) override; + void writeBinary(Buffer::Instance& buffer, const std::string& value) override; static bool isMagic(uint16_t word) { return (word & MagicMask) == Magic; } private: - // Translates compact field type IDs to FieldType. - FieldType convertCompactFieldType(uint8_t compact_field_type); + enum class CompactFieldType { + Stop = 0, + BoolTrue = 1, + BoolFalse = 2, + Byte = 3, + I16 = 4, + I32 = 5, + I64 = 6, + Double = 7, + String = 8, + List = 9, + Set = 10, + Map = 11, + Struct = 12, + }; + + FieldType convertCompactFieldType(CompactFieldType compact_field_type); + CompactFieldType convertFieldType(FieldType field_type); + + void writeFieldBeginInternal(Buffer::Instance& buffer, FieldType field_type, int16_t field_id, + absl::optional field_type_override); std::stack last_field_id_stack_{}; int16_t last_field_id_{0}; @@ -62,6 +105,9 @@ class CompactProtocolImpl : public ProtocolImplBase { // This tracks the last boolean struct field's value for readBool. absl::optional bool_value_{}; + // Similarly, track the field id for writeBool. + absl::optional bool_field_id_{}; + const static uint16_t Magic; const static uint16_t MagicMask; }; diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index 49f169eee4815..22ce4fbf3e163 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -1,26 +1,79 @@ #include "extensions/filters/network/thrift_proxy/config.h" +#include #include #include "envoy/network/connection.h" #include "envoy/registry/registry.h" -#include "extensions/filters/network/thrift_proxy/filter.h" +#include "common/config/utility.h" + +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +namespace { + +typedef std::map + TransportTypeMap; + +static const TransportTypeMap& transportTypeMap() { + CONSTRUCT_ON_FIRST_USE( + TransportTypeMap, + { + {envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_AUTO_TRANSPORT, + TransportType::Auto}, + {envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType_FRAMED, + TransportType::Framed}, + {envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_UNFRAMED, + TransportType::Unframed}, + }); +} + +typedef std::map + ProtocolTypeMap; + +static const ProtocolTypeMap& protocolTypeMap() { + CONSTRUCT_ON_FIRST_USE( + ProtocolTypeMap, + { + {envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_ProtocolType_AUTO_PROTOCOL, + ProtocolType::Auto}, + {envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType_BINARY, + ProtocolType::Binary}, + {envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_ProtocolType_LAX_BINARY, + ProtocolType::LaxBinary}, + {envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType_COMPACT, + ProtocolType::Compact}, + }); +} + +} // namespace Network::FilterFactoryCb ThriftProxyFilterConfigFactory::createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, + const envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, Server::Configuration::FactoryContext& context) { - ASSERT(!proto_config.stat_prefix().empty()); - - const std::string stat_prefix = fmt::format("thrift.{}.", proto_config.stat_prefix()); + std::shared_ptr filter_config(new ConfigImpl(proto_config, context)); - return [stat_prefix, &context](Network::FilterManager& filter_manager) -> void { - filter_manager.addFilter(std::make_shared(stat_prefix, context.scope())); + return [filter_config](Network::FilterManager& filter_manager) -> void { + filter_manager.addReadFilter(std::make_shared(*filter_config)); }; } @@ -31,6 +84,48 @@ static Registry::RegisterFactory registered_; +ConfigImpl::ConfigImpl( + const envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context) + : context_(context), stats_prefix_(fmt::format("thrift.{}.", config.stat_prefix())), + stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())), + transport_(config.transport()), proto_(config.protocol()), + route_matcher_(new Router::RouteMatcher(config.route_config())) { + + // Construct the only Thrift DecoderFilter: the Router + auto& factory = + Envoy::Config::Utility::getAndCheckFactory( + ThriftFilters::ThriftFilterNames::get().ROUTER); + ThriftFilters::FilterFactoryCb callback; + + auto empty_config = factory.createEmptyConfigProto(); + callback = factory.createFilterFactoryFromProto(*empty_config, stats_prefix_, context_); + filter_factories_.push_back(callback); +} + +void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) { + for (const ThriftFilters::FilterFactoryCb& factory : filter_factories_) { + factory(callbacks); + } +} + +DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { + return std::make_unique(createTransport(), createProtocol(), callbacks); +} + +TransportPtr ConfigImpl::createTransport() { + TransportTypeMap::const_iterator i = transportTypeMap().find(transport_); + RELEASE_ASSERT(i != transportTypeMap().end(), "invalid transport type"); + + return NamedTransportConfigFactory::getFactory(i->second).createTransport(); +} + +ProtocolPtr ConfigImpl::createProtocol() { + ProtocolTypeMap::const_iterator i = protocolTypeMap().find(proto_); + RELEASE_ASSERT(i != protocolTypeMap().end(), "invalid protocol type"); + return NamedProtocolConfigFactory::getFactory(i->second).createProtocol(); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index d578bd96591ac..630edfdf00ce6 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -1,11 +1,16 @@ #pragma once +#include #include -#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" -#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" +#include "envoy/stats/stats.h" #include "extensions/filters/network/common/factory_base.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" #include "extensions/filters/network/well_known_names.h" namespace Envoy { @@ -18,16 +23,52 @@ namespace ThriftProxy { */ class ThriftProxyFilterConfigFactory : public Common::FactoryBase< - envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy> { + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy> { public: - ThriftProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().THRIFT_PROXY) {} + ThriftProxyFilterConfigFactory() : FactoryBase(NetworkFilterNames::get().ThriftProxy) {} private: Network::FilterFactoryCb createFilterFactoryFromProtoTyped( - const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, + const envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, Server::Configuration::FactoryContext& context) override; }; +class ConfigImpl : public Config, + public Router::Config, + public ThriftFilters::FilterChainFactory, + Logger::Loggable { +public: + ConfigImpl(const envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context); + + // ThriftFilters::FilterChainFactory + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override; + + // Router::Config + Router::RouteConstSharedPtr route(const std::string& method_name) const override { + return route_matcher_->route(method_name); + } + + // Config + ThriftFilterStats& stats() override { return stats_; } + ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } + DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + Router::Config& routerConfig() override { return *this; } + +private: + TransportPtr createTransport(); + ProtocolPtr createProtocol(); + + Server::Configuration::FactoryContext& context_; + const std::string stats_prefix_; + ThriftFilterStats stats_; + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType transport_; + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType proto_; + std::unique_ptr route_matcher_; + + std::list filter_factories_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc new file mode 100644 index 0000000000000..c94bbeefcec36 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -0,0 +1,290 @@ +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "envoy/common/exception.h" +#include "envoy/event/dispatcher.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +ConnectionManager::ConnectionManager(Config& config) + : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + +ConnectionManager::~ConnectionManager() {} + +Network::FilterStatus ConnectionManager::onData(Buffer::Instance& data, bool end_stream) { + UNREFERENCED_PARAMETER(end_stream); + + request_buffer_.move(data); + dispatch(); + + return Network::FilterStatus::StopIteration; +} + +void ConnectionManager::dispatch() { + if (stopped_) { + ENVOY_LOG(error, "thrift filter stopped"); + return; + } + + try { + bool underflow = false; + while (!underflow) { + ThriftFilters::FilterStatus status = decoder_->onData(request_buffer_, underflow); + if (status == ThriftFilters::FilterStatus::StopIteration) { + stopped_ = true; + break; + } + } + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift error: {}", ex.what()); + stats_.request_decoding_error_.inc(); + + // Use the current rpc to send an error downstream, if possible. + rpcs_.front()->onError(ex.what()); + + resetAllRpcs(); + read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); + } +} + +void ConnectionManager::continueDecoding() { + stopped_ = false; + dispatch(); +} + +void ConnectionManager::doDeferredRpcDestroy(ConnectionManager::ActiveRpc& rpc) { + read_callbacks_->connection().dispatcher().deferredDelete(rpc.removeFromList(rpcs_)); +} + +void ConnectionManager::resetAllRpcs() { + while (!rpcs_.empty()) { + rpcs_.front()->onReset(); + } +} + +void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) { + read_callbacks_ = &callbacks; +} + +void ConnectionManager::onEvent(Network::ConnectionEvent event) { + if (!rpcs_.empty()) { + if (event == Network::ConnectionEvent::RemoteClose) { + stats_.cx_destroy_remote_with_active_rq_.inc(); + } else if (event == Network::ConnectionEvent::LocalClose) { + stats_.cx_destroy_local_with_active_rq_.inc(); + } + + resetAllRpcs(); + } +} + +ThriftFilters::DecoderFilter& ConnectionManager::newDecoderFilter() { + ENVOY_LOG(debug, "new decoder filter"); + + ActiveRpcPtr new_rpc(new ActiveRpc(*this)); + new_rpc->createFilterChain(); + new_rpc->moveIntoList(std::move(new_rpc), rpcs_); + + return **rpcs_.begin(); +} + +bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { + upstream_buffer_.move(data); + + bool underflow = false; + decoder_->onData(upstream_buffer_, underflow); + ASSERT(complete_ || underflow); + return complete_; +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::messageBegin(absl::string_view name, + MessageType msg_type, + int32_t seq_id) { + reply_.emplace(std::string(name), msg_type, seq_id); + first_reply_field_ = (msg_type == MessageType::Reply); + return ProtocolConverter::messageBegin(name, msg_type, seq_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl::string_view name, + FieldType field_type, + int16_t field_id) { + if (first_reply_field_) { + // Reply messages contain a struct where field 0 is the call result and fields 1+ are + // exceptions, if defined. At most one field may be set. Therefore, the very first field we + // encounter in a reply is either field 0 (success) or not (IDL exception returned). + ASSERT(reply_.has_value()); + reply_.value().success_ = field_id == 0 && field_type != FieldType::Stop; + first_reply_field_ = false; + } + + return ProtocolConverter::fieldBegin(name, field_type, field_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { + ConnectionManager& cm = parent_.parent_; + + Buffer::OwnedImpl buffer; + + // Use the factory to get the concrete transport from the decoder transport (as opposed to + // potentially pre-detection auto transport). + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.parent_.decoder_->transportType()) + .createTransport(); + transport->encodeFrame(buffer, parent_.response_buffer_); + complete_ = true; + + cm.read_callbacks_->connection().write(buffer, false); + + cm.stats_.response_.inc(); + + ASSERT(reply_.has_value()); + switch (reply_.value().msg_type_) { + case MessageType::Reply: + cm.stats_.response_reply_.inc(); + if (reply_.value().success_.value_or(false)) { + cm.stats_.response_success_.inc(); + } else { + cm.stats_.response_error_.inc(); + } + + break; + + case MessageType::Exception: + cm.stats_.response_exception_.inc(); + break; + + default: + cm.stats_.response_invalid_type_.inc(); + break; + } + + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus ConnectionManager::ActiveRpc::transportEnd() { + ASSERT(call_.has_value()); + + parent_.stats_.request_.inc(); + + switch (call_.value().msg_type_) { + case MessageType::Call: + parent_.stats_.request_call_.inc(); + break; + + case MessageType::Oneway: + parent_.stats_.request_oneway_.inc(); + + // No response forthcoming, we're done. + parent_.doDeferredRpcDestroy(*this); + break; + + default: + parent_.stats_.request_invalid_type_.inc(); + break; + } + + return decoder_filter_->transportEnd(); +} + +void ConnectionManager::ActiveRpc::createFilterChain() { + parent_.config_.filterFactory().createFilterChain(*this); +} + +void ConnectionManager::ActiveRpc::onReset() { + // TODO(zuercher): e.g., parent_.stats_.named_.downstream_rq_rx_reset_.inc(); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::onError(const std::string& what) { + if (call_.has_value()) { + const Message& msg = call_.value(); + sendLocalReply(std::make_unique(msg.method_name_, msg.seq_id_, + AppExceptionType::ProtocolError, what)); + return; + } + + // Transport or protocol error happened before (or during message begin) parsing. It's not + // possible to provide a valid response, so don't try. +} + +const Network::Connection* ConnectionManager::ActiveRpc::connection() const { + return &parent_.read_callbacks_->connection(); +} + +void ConnectionManager::ActiveRpc::continueDecoding() { parent_.continueDecoding(); } + +Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { + if (!cached_route_) { + if (call_.has_value()) { + Router::RouteConstSharedPtr route = + parent_.config_.routerConfig().route(call_.value().method_name_); + cached_route_ = std::move(route); + } else { + cached_route_ = nullptr; + } + } + + return cached_route_.value(); +} + +void ConnectionManager::ActiveRpc::sendLocalReply(ThriftFilters::DirectResponsePtr&& response) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(parent_.decoder_->protocolType()).createProtocol(); + Buffer::OwnedImpl buffer; + + response->encode(*proto, buffer); + + // Same logic as protocol above. + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.decoder_->transportType()).createTransport(); + transport->encodeFrame(response_buffer_, buffer); + + parent_.read_callbacks_->connection().write(response_buffer_, false); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, + ProtocolType protocol_type) { + ASSERT(response_decoder_ == nullptr); + + response_decoder_ = std::make_unique(*this, transport_type, protocol_type); +} + +bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { + ASSERT(response_decoder_ != nullptr); + + try { + bool complete = response_decoder_->onData(buffer); + if (complete) { + parent_.doDeferredRpcDestroy(*this); + } + return complete; + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift response error: {}", ex.what()); + parent_.stats_.response_decoding_error_.inc(); + + onError(ex.what()); + decoder_filter_->resetUpstreamConnection(); + return true; + } +} + +void ConnectionManager::ActiveRpc::resetDownstreamConnection() { + parent_.read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + parent_.doDeferredRpcDestroy(*this); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h new file mode 100644 index 0000000000000..c366a40c0f2a5 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -0,0 +1,256 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/event/deferred_deletable.h" +#include "envoy/network/connection.h" +#include "envoy/network/filter.h" +#include "envoy/stats/stats.h" +#include "envoy/stats/timespan.h" + +#include "common/buffer/buffer_impl.h" +#include "common/common/linked_object.h" +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_converter.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Config is a configuration interface for ConnectionManager. + */ +class Config { +public: + virtual ~Config() {} + + virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; + virtual ThriftFilterStats& stats() PURE; + virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual Router::Config& routerConfig() PURE; +}; + +/** + * ConnectionManager is a Network::Filter that will perform Thrift request handling on a connection. + */ +class ConnectionManager : public Network::ReadFilter, + public Network::ConnectionCallbacks, + public DecoderCallbacks, + Logger::Loggable { +public: + ConnectionManager(Config& config); + ~ConnectionManager(); + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; + Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } + void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override; + + // Network::ConnectionCallbacks + void onEvent(Network::ConnectionEvent) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override; + +private: + class Message { + public: + Message(const std::string& method_name, MessageType msg_type, int32_t seq_id) + : method_name_(method_name), msg_type_(msg_type), seq_id_(seq_id) {} + + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + absl::optional success_; + }; + + struct ActiveRpc; + + struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { + ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) + : parent_(parent), + decoder_(std::make_unique( + NamedTransportConfigFactory::getFactory(transport_type).createTransport(), + NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + complete_(false), first_reply_field_(false) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) + .createProtocol(), + parent_.response_buffer_); + } + + bool onData(Buffer::Instance& data); + + // ProtocolConverter + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; + } + ThriftFilters::FilterStatus transportEnd() override; + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override { return *this; } + + ActiveRpc& parent_; + DecoderPtr decoder_; + Buffer::OwnedImpl upstream_buffer_; + absl::optional reply_; + bool complete_ : 1; + bool first_reply_field_ : 1; + }; + typedef std::unique_ptr ResponseDecoderPtr; + + // ActiveRpc tracks request/response pairs. + struct ActiveRpc : LinkedObject, + public Event::DeferredDeletable, + public ThriftFilters::DecoderFilter, + public ThriftFilters::DecoderFilterCallbacks, + public ThriftFilters::FilterChainFactoryCallbacks { + ActiveRpc(ConnectionManager& parent) + : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), + stream_id_(parent_.stream_id_++) { + parent_.stats_.request_active_.inc(); + } + ~ActiveRpc() { + request_timer_->complete(); + parent_.stats_.request_active_.dec(); + + if (decoder_filter_ != nullptr) { + decoder_filter_->onDestroy(); + } + } + + // ThriftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + return decoder_filter_->transportBegin(size); + } + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + call_.emplace(std::string(name), msg_type, seq_id); + return decoder_filter_->messageBegin(name, msg_type, seq_id); + } + ThriftFilters::FilterStatus messageEnd() override { return decoder_filter_->messageEnd(); } + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + return decoder_filter_->structBegin(name); + } + ThriftFilters::FilterStatus structEnd() override { return decoder_filter_->structEnd(); } + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + return decoder_filter_->fieldBegin(name, field_type, field_id); + } + ThriftFilters::FilterStatus fieldEnd() override { return decoder_filter_->fieldEnd(); } + ThriftFilters::FilterStatus boolValue(bool value) override { + return decoder_filter_->boolValue(value); + } + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + return decoder_filter_->byteValue(value); + } + ThriftFilters::FilterStatus int16Value(int16_t value) override { + return decoder_filter_->int16Value(value); + } + ThriftFilters::FilterStatus int32Value(int32_t value) override { + return decoder_filter_->int32Value(value); + } + ThriftFilters::FilterStatus int64Value(int64_t value) override { + return decoder_filter_->int64Value(value); + } + ThriftFilters::FilterStatus doubleValue(double value) override { + return decoder_filter_->doubleValue(value); + } + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + return decoder_filter_->stringValue(value); + } + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + return decoder_filter_->mapBegin(key_type, value_type, size); + } + ThriftFilters::FilterStatus mapEnd() override { return decoder_filter_->mapEnd(); } + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->listBegin(elem_type, size); + } + ThriftFilters::FilterStatus listEnd() override { return decoder_filter_->listEnd(); } + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->setBegin(elem_type, size); + } + ThriftFilters::FilterStatus setEnd() override { return decoder_filter_->setEnd(); } + + // ThriftFilters::DecoderFilterCallbacks + uint64_t streamId() const override { return stream_id_; } + const Network::Connection* connection() const override; + void continueDecoding() override; + Router::RouteConstSharedPtr route() override; + TransportType downstreamTransportType() const override { + return parent_.decoder_->transportType(); + } + ProtocolType downstreamProtocolType() const override { + return parent_.decoder_->protocolType(); + } + void sendLocalReply(ThriftFilters::DirectResponsePtr&& response) override; + void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + bool upstreamData(Buffer::Instance& buffer) override; + void resetDownstreamConnection() override; + + // Thrift::FilterChainFactoryCallbacks + void addDecoderFilter(ThriftFilters::DecoderFilterSharedPtr filter) override { + // TODO(zuercher): support multiple filters + filter->setDecoderFilterCallbacks(*this); + decoder_filter_ = filter; + } + + void createFilterChain(); + void onReset(); + void onError(const std::string& what); + + ConnectionManager& parent_; + Stats::TimespanPtr request_timer_; + uint64_t stream_id_; + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ResponseDecoderPtr response_decoder_; + absl::optional cached_route_; + absl::optional call_; + Buffer::OwnedImpl response_buffer_; + }; + + typedef std::unique_ptr ActiveRpcPtr; + + void continueDecoding(); + void dispatch(); + void doDeferredRpcDestroy(ActiveRpc& rpc); + void resetAllRpcs(); + + Config& config_; + ThriftFilterStats& stats_; + + Network::ReadFilterCallbacks* read_callbacks_{}; + + DecoderPtr decoder_; + std::list rpcs_; + Buffer::OwnedImpl request_buffer_; + uint64_t stream_id_{1}; + bool stopped_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index d2cc0d898dcef..39aa8af6a2836 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -13,69 +13,72 @@ namespace NetworkFilters { namespace ThriftProxy { // MessageBegin -> StructBegin -ProtocolState DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { std::string message_name; MessageType msg_type; int32_t seq_id; if (!proto_.readMessageBegin(buffer, message_name, msg_type, seq_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.clear(); stack_.emplace_back(Frame(ProtocolState::MessageEnd)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, + filter_.messageBegin(absl::string_view(message_name), msg_type, seq_id)); } // MessageEnd -> Done -ProtocolState DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { if (!proto_.readMessageEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::Done; + return DecoderStatus(ProtocolState::Done, filter_.messageEnd()); } // StructBegin -> FieldBegin -ProtocolState DecoderStateMachine::structBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structBegin(Buffer::Instance& buffer) { std::string name; if (!proto_.readStructBegin(buffer, name)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.structBegin(absl::string_view(name))); } // StructEnd -> stack's return state -ProtocolState DecoderStateMachine::structEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structEnd(Buffer::Instance& buffer) { if (!proto_.readStructEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.structEnd()); } // FieldBegin -> FieldValue, or // FieldBegin -> StructEnd (stop field) -ProtocolState DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { std::string name; FieldType field_type; int16_t field_id; if (!proto_.readFieldBegin(buffer, name, field_type, field_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } if (field_type == FieldType::Stop) { - return ProtocolState::StructEnd; + return DecoderStatus(ProtocolState::StructEnd, ThriftFilters::FilterStatus::Continue); } stack_.emplace_back(Frame(ProtocolState::FieldEnd, field_type)); - return ProtocolState::FieldValue; + return DecoderStatus(ProtocolState::FieldValue, + filter_.fieldBegin(absl::string_view(name), field_type, field_id)); } // FieldValue -> FieldEnd (via stack return state) -ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); @@ -83,36 +86,36 @@ ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { } // FieldEnd -> FieldBegin -ProtocolState DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { if (!proto_.readFieldEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } popReturnState(); - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.fieldEnd()); } // ListBegin -> ListValue -ProtocolState DecoderStateMachine::listBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readListBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::ListEnd, elem_type, size)); - return ProtocolState::ListValue; + return DecoderStatus(ProtocolState::ListValue, filter_.listBegin(elem_type, size)); } // ListValue -> ListValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // ListValue -> ListEnd -ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -120,34 +123,35 @@ ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { } // ListEnd -> stack's return state -ProtocolState DecoderStateMachine::listEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listEnd(Buffer::Instance& buffer) { if (!proto_.readListEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.listEnd()); } // MapBegin -> MapKey -ProtocolState DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { FieldType key_type, value_type; uint32_t size; if (!proto_.readMapBegin(buffer, key_type, value_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::MapEnd, key_type, value_type, size)); - return ProtocolState::MapKey; + return DecoderStatus(ProtocolState::MapKey, filter_.mapBegin(key_type, value_type, size)); } // MapKey -> MapValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on key type), or // MapKey -> MapEnd -ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapKey(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } return handleValue(buffer, frame.elem_type_, ProtocolState::MapValue); @@ -155,7 +159,7 @@ ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { // MapValue -> MapKey, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // MapValue -> MapKey -ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); ASSERT(frame.remaining_ != 0); @@ -165,34 +169,35 @@ ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { } // MapEnd -> stack's return state -ProtocolState DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { if (!proto_.readMapEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.mapEnd()); } // SetBegin -> SetValue -ProtocolState DecoderStateMachine::setBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readSetBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::SetEnd, elem_type, size)); - return ProtocolState::SetValue; + return DecoderStatus(ProtocolState::SetValue, filter_.setBegin(elem_type, size)); } // SetValue -> SetValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // SetValue -> SetEnd -ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -200,85 +205,88 @@ ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { } // SetEnd -> stack's return state -ProtocolState DecoderStateMachine::setEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setEnd(Buffer::Instance& buffer) { if (!proto_.readSetEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.setEnd()); } -ProtocolState DecoderStateMachine::handleValue(Buffer::Instance& buffer, FieldType elem_type, - ProtocolState return_state) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleValue(Buffer::Instance& buffer, + FieldType elem_type, + ProtocolState return_state) { switch (elem_type) { - case FieldType::Bool: - bool value; - if (!proto_.readBool(buffer, value)) { - return ProtocolState::WaitForData; + case FieldType::Bool: { + bool value{}; + if (proto_.readBool(buffer, value)) { + return DecoderStatus(return_state, filter_.boolValue(value)); } break; + } case FieldType::Byte: { - uint8_t value; - if (!proto_.readByte(buffer, value)) { - return ProtocolState::WaitForData; + uint8_t value{}; + if (proto_.readByte(buffer, value)) { + return DecoderStatus(return_state, filter_.byteValue(value)); } break; } case FieldType::I16: { - int16_t value; - if (!proto_.readInt16(buffer, value)) { - return ProtocolState::WaitForData; + int16_t value{}; + if (proto_.readInt16(buffer, value)) { + return DecoderStatus(return_state, filter_.int16Value(value)); } break; } case FieldType::I32: { - int32_t value; - if (!proto_.readInt32(buffer, value)) { - return ProtocolState::WaitForData; + int32_t value{}; + if (proto_.readInt32(buffer, value)) { + return DecoderStatus(return_state, filter_.int32Value(value)); } break; } case FieldType::I64: { - int64_t value; - if (!proto_.readInt64(buffer, value)) { - return ProtocolState::WaitForData; + int64_t value{}; + if (proto_.readInt64(buffer, value)) { + return DecoderStatus(return_state, filter_.int64Value(value)); } break; } case FieldType::Double: { - double value; - if (!proto_.readDouble(buffer, value)) { - return ProtocolState::WaitForData; + double value{}; + if (proto_.readDouble(buffer, value)) { + return DecoderStatus(return_state, filter_.doubleValue(value)); } break; } case FieldType::String: { std::string value; - if (!proto_.readString(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readString(buffer, value)) { + return DecoderStatus(return_state, filter_.stringValue(value)); } break; } case FieldType::Struct: stack_.emplace_back(Frame(return_state)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Map: stack_.emplace_back(Frame(return_state)); - return ProtocolState::MapBegin; + return DecoderStatus(ProtocolState::MapBegin, ThriftFilters::FilterStatus::Continue); case FieldType::List: stack_.emplace_back(Frame(return_state)); - return ProtocolState::ListBegin; + return DecoderStatus(ProtocolState::ListBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Set: stack_.emplace_back(Frame(return_state)); - return ProtocolState::SetBegin; + return DecoderStatus(ProtocolState::SetBegin, ThriftFilters::FilterStatus::Continue); default: throw EnvoyException(fmt::format("unknown field type {}", static_cast(elem_type))); } - return return_state; + return DecoderStatus(ProtocolState::WaitForData); } -ProtocolState DecoderStateMachine::handleState(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleState(Buffer::Instance& buffer) { switch (state_) { case ProtocolState::MessageBegin: return messageBegin(buffer); @@ -315,7 +323,7 @@ ProtocolState DecoderStateMachine::handleState(Buffer::Instance& buffer) { case ProtocolState::MessageEnd: return messageEnd(buffer); default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -328,61 +336,97 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { - ProtocolState s = handleState(buffer); - if (s == ProtocolState::WaitForData) { - return s; + DecoderStatus s = handleState(buffer); + if (s.next_state_ == ProtocolState::WaitForData) { + return ProtocolState::WaitForData; } - state_ = s; + state_ = s.next_state_; + + ASSERT(s.filter_status_.has_value()); + if (s.filter_status_.value() == ThriftFilters::FilterStatus::StopIteration) { + return ProtocolState::StopIteration; + } } return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol) - : transport_(std::move(transport)), protocol_(std::move(protocol)), state_machine_{}, - frame_started_(false) {} +Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) + : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} + +void Decoder::complete() { + request_.reset(); + state_machine_ = nullptr; + frame_started_ = false; + frame_ended_ = false; +} -void Decoder::onData(Buffer::Instance& data) { +ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ENVOY_LOG(debug, "thrift: {} bytes available", data.length()); + buffer_underflow = false; - while (true) { - if (!frame_started_) { - // Look for start of next frame. - if (!transport_->decodeFrameStart(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - - frame_started_ = true; - state_machine_ = std::make_unique(*protocol_); - } + if (frame_ended_) { + // Continuation after filter stopped iteration on transportComplete callback. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; + } - ASSERT(state_machine_ != nullptr); + if (!frame_started_) { + // Look for start of next frame. + absl::optional size{}; + if (!transport_->decodeFrameStart(data, size)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), - ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); + request_ = std::make_unique(callbacks_.newDecoderFilter()); + frame_started_ = true; + state_machine_ = std::make_unique(*protocol_, request_->filter_); - ProtocolState rv = state_machine_->run(data); - if (rv == ProtocolState::WaitForData) { - ENVOY_LOG(debug, "thrift: wait for data"); - return; + if (request_->filter_.transportBegin(size) == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; } + } - ASSERT(rv == ProtocolState::Done); + ASSERT(state_machine_ != nullptr); - // Message complete, get decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); - // Reset for next frame. - state_machine_ = nullptr; - frame_started_ = false; + ProtocolState rv = state_machine_->run(data); + if (rv == ProtocolState::WaitForData) { + ENVOY_LOG(debug, "thrift: wait for data"); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } else if (rv == ProtocolState::StopIteration) { + ENVOY_LOG(debug, "thrift: wait for continuation"); + return ThriftFilters::FilterStatus::StopIteration; } + + ASSERT(rv == ProtocolState::Done); + + // Message complete, decode end of frame. + if (!transport_->decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + + frame_ended_ = true; + + ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + if (request_->filter_.transportEnd() == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; + } + + // Reset for next frame. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; } } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index a05d70cee30c8..26068858beb6d 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -6,6 +6,7 @@ #include "common/common/assert.h" #include "common/common/logger.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" #include "extensions/filters/network/thrift_proxy/transport.h" @@ -15,6 +16,7 @@ namespace NetworkFilters { namespace ThriftProxy { #define ALL_PROTOCOL_STATES(FUNCTION) \ + FUNCTION(StopIteration) \ FUNCTION(WaitForData) \ FUNCTION(MessageBegin) \ FUNCTION(MessageEnd) \ @@ -61,7 +63,8 @@ class ProtocolStateNameValues { */ class DecoderStateMachine { public: - DecoderStateMachine(Protocol& proto) : proto_(proto), state_(ProtocolState::MessageBegin) {} + DecoderStateMachine(Protocol& proto, ThriftFilters::DecoderFilter& filter) + : proto_(proto), filter_(filter), state_(ProtocolState::MessageBegin) {} /** * Consumes as much data from the configured Buffer as possible and executes the decoding state @@ -114,72 +117,107 @@ class DecoderStateMachine { uint32_t remaining_; }; + struct DecoderStatus { + DecoderStatus(ProtocolState next_state) : next_state_(next_state), filter_status_{} {}; + DecoderStatus(ProtocolState next_state, ThriftFilters::FilterStatus filter_status) + : next_state_(next_state), filter_status_(filter_status){}; + + ProtocolState next_state_; + absl::optional filter_status_; + }; + // These functions map directly to the matching ProtocolState values. Each returns the next state // or ProtocolState::WaitForData if more data is required. - ProtocolState messageBegin(Buffer::Instance& buffer); - ProtocolState messageEnd(Buffer::Instance& buffer); - ProtocolState structBegin(Buffer::Instance& buffer); - ProtocolState structEnd(Buffer::Instance& buffer); - ProtocolState fieldBegin(Buffer::Instance& buffer); - ProtocolState fieldValue(Buffer::Instance& buffer); - ProtocolState fieldEnd(Buffer::Instance& buffer); - ProtocolState listBegin(Buffer::Instance& buffer); - ProtocolState listValue(Buffer::Instance& buffer); - ProtocolState listEnd(Buffer::Instance& buffer); - ProtocolState mapBegin(Buffer::Instance& buffer); - ProtocolState mapKey(Buffer::Instance& buffer); - ProtocolState mapValue(Buffer::Instance& buffer); - ProtocolState mapEnd(Buffer::Instance& buffer); - ProtocolState setBegin(Buffer::Instance& buffer); - ProtocolState setValue(Buffer::Instance& buffer); - ProtocolState setEnd(Buffer::Instance& buffer); + DecoderStatus messageBegin(Buffer::Instance& buffer); + DecoderStatus messageEnd(Buffer::Instance& buffer); + DecoderStatus structBegin(Buffer::Instance& buffer); + DecoderStatus structEnd(Buffer::Instance& buffer); + DecoderStatus fieldBegin(Buffer::Instance& buffer); + DecoderStatus fieldValue(Buffer::Instance& buffer); + DecoderStatus fieldEnd(Buffer::Instance& buffer); + DecoderStatus listBegin(Buffer::Instance& buffer); + DecoderStatus listValue(Buffer::Instance& buffer); + DecoderStatus listEnd(Buffer::Instance& buffer); + DecoderStatus mapBegin(Buffer::Instance& buffer); + DecoderStatus mapKey(Buffer::Instance& buffer); + DecoderStatus mapValue(Buffer::Instance& buffer); + DecoderStatus mapEnd(Buffer::Instance& buffer); + DecoderStatus setBegin(Buffer::Instance& buffer); + DecoderStatus setValue(Buffer::Instance& buffer); + DecoderStatus setEnd(Buffer::Instance& buffer); // handleValue represents the generic Value state from the state machine documentation. It // returns either ProtocolState::WaitForData if more data is required or the next state. For // structs, lists, maps, or sets the return_state is pushed onto the stack and the next state is // based on elem_type. For primitive value types, return_state is returned as the next state // (unless WaitForData is returned). - ProtocolState handleValue(Buffer::Instance& buffer, FieldType elem_type, + DecoderStatus handleValue(Buffer::Instance& buffer, FieldType elem_type, ProtocolState return_state); // handleState delegates to the appropriate method based on state_. - ProtocolState handleState(Buffer::Instance& buffer); + DecoderStatus handleState(Buffer::Instance& buffer); // Helper method to retrieve the current frame's return state and remove the frame from the // stack. ProtocolState popReturnState(); Protocol& proto_; + ThriftFilters::DecoderFilter& filter_; ProtocolState state_; std::vector stack_; }; typedef std::unique_ptr DecoderStateMachinePtr; +class DecoderCallbacks { +public: + virtual ~DecoderCallbacks() {} + + /** + * @return DecoderFilter& a new DecoderFilter for a message. + */ + virtual ThriftFilters::DecoderFilter& newDecoderFilter() PURE; +}; + /** * Decoder encapsulates a configured TransportPtr and ProtocolPtr. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol); + Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); + Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. A new - * DecoderStateMachine is instantiated for each message. + * Drains data from the given buffer while executing a DecoderStateMachine over the data. * * @param data a Buffer containing Thrift protocol data + * @param buffer_underflow bool set to true if more data is required to continue decoding + * @return ThriftFilters::FilterStatus::StopIteration when waiting for filter continuation, + * Continue otherwise. * @throw EnvoyException on Thrift protocol errors */ - void onData(Buffer::Instance& data); + ThriftFilters::FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - const Transport& transport() { return *transport_; } - const Protocol& protocol() { return *protocol_; } + TransportType transportType() { return transport_->type(); } + ProtocolType protocolType() { return protocol_->type(); } private: + struct ActiveRequest { + ActiveRequest(ThriftFilters::DecoderFilter& filter) : filter_(filter) {} + + ThriftFilters::DecoderFilter& filter_; + }; + typedef std::unique_ptr ActiveRequestPtr; + + void complete(); + TransportPtr transport_; ProtocolPtr protocol_; + DecoderCallbacks& callbacks_; + ActiveRequestPtr request_; DecoderStateMachinePtr state_machine_; - bool frame_started_; + bool frame_started_{false}; + bool frame_ended_{false}; }; typedef std::unique_ptr DecoderPtr; diff --git a/source/extensions/filters/network/thrift_proxy/filter.cc b/source/extensions/filters/network/thrift_proxy/filter.cc deleted file mode 100644 index 7c5f561d5dd95..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.cc +++ /dev/null @@ -1,310 +0,0 @@ -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "envoy/common/exception.h" - -#include "common/common/assert.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/protocol.h" -#include "extensions/filters/network/thrift_proxy/transport.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -Filter::Filter(const std::string& stat_prefix, Stats::Scope& scope) - : req_callbacks_(*this), resp_callbacks_(*this), stats_(generateStats(stat_prefix, scope)) {} - -Filter::~Filter() {} - -void Filter::onEvent(Network::ConnectionEvent event) { - if (active_call_map_.empty() && req_ == nullptr && resp_ == nullptr) { - return; - } - - if (event == Network::ConnectionEvent::RemoteClose) { - stats_.cx_destroy_local_with_active_rq_.inc(); - } - - if (event == Network::ConnectionEvent::LocalClose) { - stats_.cx_destroy_remote_with_active_rq_.inc(); - } -} - -Network::FilterStatus Filter::onData(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (req_buffer_.length() > 0) { - // Stopped sniffing during response (in onWrite). Make sure leftover req_buffer_ contents are - // at the start of data or the upstream will see a corrupted request. - req_buffer_.move(data); - data.move(req_buffer_); - ASSERT(req_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (req_decoder_ == nullptr) { - req_decoder_ = std::make_unique(std::make_unique(req_callbacks_), - std::make_unique(req_callbacks_)); - } - - ENVOY_LOG(trace, "thrift: read {} bytes", data.length()); - req_buffer_.move(data); - - try { - BufferWrapper wrapped(req_buffer_); - - req_decoder_->onData(wrapped); - - // Move consumed portion of request back to data for the upstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(req_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - req_decoder_.reset(); - data.move(req_buffer_); - stats_.request_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -Network::FilterStatus Filter::onWrite(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (resp_buffer_.length() > 0) { - // Stopped sniffing during request (in onData). Make sure resp_buffer_ contents are at the - // start of data or the downstream will see a corrupted response. - resp_buffer_.move(data); - data.move(resp_buffer_); - ASSERT(resp_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (resp_decoder_ == nullptr) { - resp_decoder_ = std::make_unique(std::make_unique(resp_callbacks_), - std::make_unique(resp_callbacks_)); - } - - ENVOY_LOG(trace, "thrift wrote {} bytes", data.length()); - resp_buffer_.move(data); - - try { - BufferWrapper wrapped(resp_buffer_); - - resp_decoder_->onData(wrapped); - - // Move consumed portion of response back to data for the downstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(resp_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - resp_decoder_.reset(); - data.move(resp_buffer_); - - stats_.response_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -void Filter::chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id) { - if (req_ != nullptr) { - throw EnvoyException("unexpected request messageStart callback"); - } - - if (active_call_map_.size() >= 64) { - throw EnvoyException("too many pending calls (64), disabling sniffing"); - } - - req_ = std::make_unique(*this, msg_type, seq_id); - - stats_.request_.inc(); - switch (msg_type) { - case MessageType::Call: - stats_.request_call_.inc(); - break; - case MessageType::Oneway: - stats_.request_oneway_.inc(); - break; - default: - stats_.request_invalid_type_.inc(); - break; - } -} - -void Filter::chargeDownstreamRequestComplete() { - if (req_ == nullptr) { - throw EnvoyException("unexpected request messageComplete callback"); - } - - // One-way messages do not receive responses. - if (req_->msg_type_ == MessageType::Oneway) { - req_.reset(); - return; - } - - int32_t seq_id = req_->seq_id_; - active_call_map_.emplace(seq_id, std::move(req_)); -} - -void Filter::chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id) { - if (resp_ != nullptr) { - throw EnvoyException("unexpected response messageStart callback"); - } - - auto i = active_call_map_.find(seq_id); - if (i == active_call_map_.end()) { - throw EnvoyException(fmt::format("unknown reply seq_id {}", seq_id)); - } - - resp_ = std::move(i->second); - resp_->response_msg_type_ = msg_type; - active_call_map_.erase(i); -} - -void Filter::chargeUpstreamResponseField(FieldType field_type, int16_t field_id) { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageField callback"); - } - - if (resp_->response_msg_type_ != MessageType::Reply) { - // If this is not a reply, we'll count an exception instead of an error, so leave - // resp_->success_ unset. - return; - } - - if (resp_->success_.has_value()) { - // If resp->success_ is already set, leave the existing value. - return; - } - - // Successful replies have a single field, with field_id 0 that contains the response value. - // IDL-level exceptions are encoded as a single field with field_id >= 1. - resp_->success_ = field_id == 0 && field_type != FieldType::Stop; -} - -void Filter::chargeUpstreamResponseComplete() { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageComplete callback"); - } - - stats_.response_.inc(); - switch (resp_->response_msg_type_) { - case MessageType::Reply: - stats_.response_reply_.inc(); - break; - case MessageType::Exception: - stats_.response_exception_.inc(); - break; - default: - stats_.response_invalid_type_.inc(); - break; - } - - if (resp_->success_.has_value()) { - if (resp_->success_.value()) { - stats_.response_success_.inc(); - } else { - stats_.response_error_.inc(); - } - } - - resp_.reset(); -} - -void Filter::RequestCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift request: started {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift request: ended {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift request: started {} message {}: {}", - parent_.req_decoder_->protocol().name(), name, seq_id); - parent_.chargeDownstreamRequestStart(msg_type, seq_id); -} - -void Filter::RequestCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); -} - -void Filter::RequestCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift request: ended {} message", parent_.req_decoder_->protocol().name()); - parent_.chargeDownstreamRequestComplete(); -} - -void Filter::ResponseCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift response: started {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift response: ended {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift response: started {} message {}: {}", - parent_.resp_decoder_->protocol().name(), name, seq_id); - parent_.chargeUpstreamResponseStart(msg_type, seq_id); -} - -void Filter::ResponseCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} struct", parent_.req_decoder_->protocol().name()); - depth_++; -} - -void Filter::ResponseCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); - - if (depth_ == 1) { - // Only care about the outermost struct, which corresponds to the success or failure of the - // request. - parent_.chargeUpstreamResponseField(field_type, field_id); - } -} - -void Filter::ResponseCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); - depth_--; -} - -void Filter::ResponseCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift response: ended {} message", parent_.resp_decoder_->protocol().name()); - parent_.chargeUpstreamResponseComplete(); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filter.h b/source/extensions/filters/network/thrift_proxy/filter.h deleted file mode 100644 index a0f028ed0221d..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include - -#include "envoy/network/connection.h" -#include "envoy/network/filter.h" -#include "envoy/stats/stats.h" -#include "envoy/stats/stats_macros.h" -#include "envoy/stats/timespan.h" - -#include "common/buffer/buffer_impl.h" -#include "common/common/logger.h" - -#include "extensions/filters/network/thrift_proxy/decoder.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -/** - * All thrift filter stats. @see stats_macros.h - */ -// clang-format off -#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ - COUNTER(request) \ - COUNTER(request_call) \ - COUNTER(request_oneway) \ - COUNTER(request_invalid_type) \ - GAUGE(request_active) \ - COUNTER(request_decoding_error) \ - HISTOGRAM(request_time_ms) \ - COUNTER(response) \ - COUNTER(response_reply) \ - COUNTER(response_success) \ - COUNTER(response_error) \ - COUNTER(response_exception) \ - COUNTER(response_invalid_type) \ - COUNTER(response_decoding_error) \ - COUNTER(cx_destroy_local_with_active_rq) \ - COUNTER(cx_destroy_remote_with_active_rq) -// clang-format on - -/** - * Struct definition for all mongo proxy stats. @see stats_macros.h - */ -struct ThriftFilterStats { - ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) -}; - -/** - * A sniffing filter for thrift traffic. - */ -class Filter : public Network::Filter, - public Network::ConnectionCallbacks, - Logger::Loggable { -public: - Filter(const std::string& stat_prefix, Stats::Scope& scope); - ~Filter(); - - // Network::ReadFilter - Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; - Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } - void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override {} - - // Network::WriteFilter - Network::FilterStatus onWrite(Buffer::Instance& data, bool end_stream) override; - - // Network::ConnectionCallbacks - void onEvent(Network::ConnectionEvent) override; - void onAboveWriteBufferHighWatermark() override {} - void onBelowWriteBufferLowWatermark() override {} - -private: - // RequestCallbacks handles callbacks related to decoding downstream requests. - class RequestCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - RequestCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - }; - - // ResponseCallbacks handles callbacks related to decoding upstream responses. - class ResponseCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - ResponseCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - int depth_{0}; - }; - - // ActiveMessage tracks downstream requests for which no response has been received. - struct ActiveMessage { - ActiveMessage(Filter& parent, MessageType msg_type, int32_t seq_id) - : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), - msg_type_(msg_type), seq_id_(seq_id) { - parent_.stats_.request_active_.inc(); - } - ~ActiveMessage() { - request_timer_->complete(); - parent_.stats_.request_active_.dec(); - } - - Filter& parent_; - Stats::TimespanPtr request_timer_; - const MessageType msg_type_; - const int32_t seq_id_; - MessageType response_msg_type_{}; - absl::optional success_{}; - }; - typedef std::unique_ptr ActiveMessagePtr; - - ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { - return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), - POOL_GAUGE_PREFIX(scope, prefix), - POOL_HISTOGRAM_PREFIX(scope, prefix))}; - } - - void chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id); - void chargeDownstreamRequestComplete(); - void chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id); - void chargeUpstreamResponseField(FieldType field_type, int16_t field_id); - void chargeUpstreamResponseComplete(); - - // Downstream request decoder, callbacks, and buffer. - DecoderPtr req_decoder_{}; - RequestCallbacks req_callbacks_; - Buffer::OwnedImpl req_buffer_; - // Request currently being decoded. - ActiveMessagePtr req_; - - // Upstream response decoder, callbacks, and buffer. - DecoderPtr resp_decoder_{}; - ResponseCallbacks resp_callbacks_; - Buffer::OwnedImpl resp_buffer_; - // Response currently being decoded. - ActiveMessagePtr resp_; - - // List of active request messages. - std::unordered_map active_call_map_; - - bool sniffing_{true}; - ThriftFilterStats stats_; -}; - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/BUILD b/source/extensions/filters/network/thrift_proxy/filters/BUILD new file mode 100644 index 0000000000000..7f374b1fe5fd1 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/BUILD @@ -0,0 +1,50 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "filter_config_interface", + hdrs = ["filter_config.h"], + deps = [ + ":filter_interface", + "//include/envoy/server:filter_config_interface", + "//source/common/common:macros", + "//source/common/protobuf:cc_wkt_protos", + ], +) + +envoy_cc_library( + name = "factory_base_lib", + hdrs = ["factory_base.h"], + deps = [ + ":filter_config_interface", + "//source/common/protobuf:utility_lib", + ], +) + +envoy_cc_library( + name = "filter_interface", + hdrs = ["filter.h"], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/network:connection_interface", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + ], +) + +envoy_cc_library( + name = "well_known_names", + hdrs = ["well_known_names.h"], + deps = [ + "//source/common/singleton:const_singleton", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/filters/factory_base.h b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h new file mode 100644 index 0000000000000..bf2bb292f043b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common/protobuf/utility.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +template class FactoryBase : public NamedThriftFilterConfigFactory { +public: + FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) override { + return createFilterFactoryFromProtoTyped( + MessageUtil::downcastAndValidate(proto_config), stats_prefix, context); + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return name_; } + +protected: + FactoryBase(const std::string& name) : name_(name) {} + +private: + virtual FilterFactoryCb + createFilterFactoryFromProtoTyped(const ConfigProto& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) PURE; + + const std::string name_; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h new file mode 100644 index 0000000000000..969ffcadfc46c --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/network/connection.h" + +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr DirectResponsePtr; + +/** + * Decoder filter callbacks add additional callbacks. + */ +class DecoderFilterCallbacks { +public: + virtual ~DecoderFilterCallbacks() {} + + /** + * @return uint64_t the ID of the originating stream for logging purposes. + */ + virtual uint64_t streamId() const PURE; + + /** + * @return const Network::Connection* the originating connection, or nullptr if there is none. + */ + virtual const Network::Connection* connection() const PURE; + + /** + * Continue iterating through the filter chain with buffered data. This routine can only be + * called if the filter has previously returned StopIteration from one of the DecoderFilter + * methods. The connection manager will callbacks to the next filter in the chain. Further note + * that if the request is not complete, the calling filter may receive further callbacks and must + * return an appropriate status code depending on what the filter needs to do. + */ + virtual void continueDecoding() PURE; + + /** + * @return RouteConstSharedPtr the route for the current request. + */ + virtual Router::RouteConstSharedPtr route() PURE; + + /** + * @return TransportType the originating transport. + */ + virtual TransportType downstreamTransportType() const PURE; + + /** + * @return ProtocolType the originating protocol. + */ + virtual ProtocolType downstreamProtocolType() const PURE; + + /** + * Create a locally generated response using the provided response object. + * @param response DirectResponsePtr the response to send to the downstream client + */ + virtual void sendLocalReply(DirectResponsePtr&& response) PURE; + + /** + * Indicates the start of an upstream response. May only be called once. + * @param transport_type TransportType the upstream is using + * @param protocol_type ProtocolType the upstream is using + */ + virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + + /** + * Called with upstream response data. + * @param data supplies the upstream's data + * @return true if the upstream response is complete; false if more data is expected + */ + virtual bool upstreamData(Buffer::Instance& data) PURE; + + /** + * Reset the downstream connection. + */ + virtual void resetDownstreamConnection() PURE; +}; + +enum class FilterStatus { + // Continue filter chain iteration. + Continue, + + // Stop iterating over filters in the filter chain. Iteration must be explicitly restarted via + // continueDecoding(). + StopIteration +}; + +/** + * Decoder filter interface. + */ +class DecoderFilter { +public: + virtual ~DecoderFilter() {} + + /** + * This routine is called prior to a filter being destroyed. This may happen after normal stream + * finish (both downstream and upstream) or due to reset. Every filter is responsible for making + * sure that any async events are cleaned up in the context of this routine. This includes timers, + * network calls, etc. The reason there is an onDestroy() method vs. doing this type of cleanup + * in the destructor is due to the deferred deletion model that Envoy uses to avoid stack unwind + * complications. Filters must not invoke either encoder or decoder filter callbacks after having + * onDestroy() invoked. + */ + virtual void onDestroy() PURE; + + /** + * Called by the connection manager once to initialize the filter decoder callbacks that the + * filter should use. Callbacks will not be invoked by the filter after onDestroy() is called. + */ + virtual void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) PURE; + + /** + * Resets the upstream connection. + */ + virtual void resetUpstreamConnection() PURE; + + /** + * Indicates the start of a Thrift transport frame was detected. Unframed transports generate + * simulated start messages. + * @param size the size of the message, if available to the transport + */ + virtual FilterStatus transportBegin(absl::optional size) PURE; + + /** + * Indicates the end of a Thrift transport frame was detected. Unframed transport generate + * simulated complete messages. + */ + virtual FilterStatus transportEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol message was detected. + * @param name the name of the message, if available + * @param msg_type the type of the message + * @param seq_id the message sequence id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) PURE; + + /** + * Indicates that the end of a Thrift protocol message was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol struct was detected. + * @param name the name of the struct, if available + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structBegin(absl::string_view name) PURE; + + /** + * Indicates that the end of a Thrift protocol struct was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structEnd() PURE; + + /** + * Indicates that the start of Thrift protocol struct field was detected. + * @param name the name of the field, if available + * @param field_type the type of the field + * @param field_id the field id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) PURE; + + /** + * Indicates that the end of a Thrift protocol struct field was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldEnd() PURE; + + /** + * A struct field, map key, map value, list element or set element was detected. + * @param value type value of the field + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus boolValue(bool value) PURE; + virtual FilterStatus byteValue(uint8_t value) PURE; + virtual FilterStatus int16Value(int16_t value) PURE; + virtual FilterStatus int32Value(int32_t value) PURE; + virtual FilterStatus int64Value(int64_t value) PURE; + virtual FilterStatus doubleValue(double value) PURE; + virtual FilterStatus stringValue(absl::string_view value) PURE; + + /** + * Indicates the start of a Thrift protocol map was detected. + * @param key_type the map key type + * @param value_type the map value type + * @param size the number of key-value pairs + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol map was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapEnd() PURE; + + /** + * Indicates the start of a Thrift protocol list was detected. + * @param elem_type the list value type + * @param size the number of values in the list + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol list was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listEnd() PURE; + + /** + * Indicates the start of a Thrift protocol set was detected. + * @param elem_type the set value type + * @param size the number of values in the set + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol set was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setEnd() PURE; +}; + +typedef std::shared_ptr DecoderFilterSharedPtr; + +/** + * These callbacks are provided by the connection manager to the factory so that the factory can + * build the filter chain in an application specific way. + */ +class FilterChainFactoryCallbacks { +public: + virtual ~FilterChainFactoryCallbacks() {} + + /** + * Add a decoder filter that is used when reading connection data. + * @param filter supplies the filter to add. + */ + virtual void addDecoderFilter(DecoderFilterSharedPtr filter) PURE; +}; + +/** + * This function is used to wrap the creation of a Thrift filter chain for new connections as they + * come in. Filter factories create the function at configuration initialization time, and then + * they are used at runtime. + * @param callbacks supplies the callbacks for the stream to install filters to. Typically the + * function will install a single filter, but it's technically possibly to install more than one + * if desired. + */ +typedef std::function FilterFactoryCb; + +/** + * A FilterChainFactory is used by a connection manager to create a Thrift level filter chain when + * a new connection is created. Typically it would be implemented by a configuration engine that + * would install a set of filters that are able to process an application scenario on top of a + * stream of Thrift requests. + */ +class FilterChainFactory { +public: + virtual ~FilterChainFactory() {} + + /** + * Called when a new Thrift stream is created on the connection. + * @param callbacks supplies the "sink" that is used for actually creating the filter chain. @see + * FilterChainFactoryCallbacks. + */ + virtual void createFilterChain(FilterChainFactoryCallbacks& callbacks) PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter_config.h b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h new file mode 100644 index 0000000000000..86f4b7730517b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h @@ -0,0 +1,55 @@ +#pragma once + +#include "envoy/server/filter_config.h" + +#include "common/common/macros.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Implemented by each Thrift filter and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedThriftFilterConfigFactory { +public: + virtual ~NamedThriftFilterConfigFactory() {} + + /** + * Create a particular thrift filter factory implementation. If the implementation is unable to + * produce a factory with the provided parameters, it should throw an EnvoyException in the case + * of general error. The returned callback should always be initialized. + * @param config supplies the configuration for the filter + * @param stat_prefix prefix for stat logging + * @param context supplies the filter's context. + * @return FilterFactoryCb the factory creation function. + */ + virtual FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& config, const std::string& stat_prefix, + Server::Configuration::FactoryContext& context) PURE; + + /** + * @return ProtobufTypes::MessagePtr create empty config proto message for v2. The filter + * config, which arrives in an opaque google.protobuf.Struct message, will be converted to + * JSON and then parsed into this empty proto. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular implementation of a thrift filter + * produced by the factory. + */ + virtual std::string name() PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h new file mode 100644 index 0000000000000..41abac8d54753 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h @@ -0,0 +1,25 @@ +#pragma once + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Well-known http filter names. + * NOTE: New filters should use the well known name: envoy.filters.thrift.name. + */ +class ThriftFilterNameValues { +public: + // Router filter + const std::string ROUTER = "envoy.filters.thrift.router"; +}; + +typedef ConstSingleton ThriftFilterNames; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc new file mode 100644 index 0000000000000..a45861a349e4b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc @@ -0,0 +1,59 @@ +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" + +#include "envoy/common/exception.h" + +#include "extensions/filters/network/thrift_proxy/buffer_helper.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, + absl::optional& size) { + if (buffer.length() < 4) { + return false; + } + + int32_t thrift_size = BufferHelper::peekI32(buffer); + + if (thrift_size <= 0 || thrift_size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", thrift_size)); + } + + buffer.drain(4); + + size = static_cast(thrift_size); + return true; +} + +bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { return true; } + +void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { + uint64_t size = message.length(); + if (size == 0 || size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", size)); + } + + int32_t thrift_size = static_cast(size); + + BufferHelper::writeI32(buffer, thrift_size); + buffer.move(message); +} + +class FramedTransportConfigFactory : public TransportFactoryBase { +public: + FramedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().FRAMED) {} +}; + +/** + * Static registration for the framed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h new file mode 100644 index 0000000000000..4c7569487ea38 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * FramedTransportImpl implements the Thrift Framed transport. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +class FramedTransportImpl : public Transport { +public: + FramedTransportImpl() {} + + // Transport + const std::string& name() const override { return TransportNames::get().FRAMED; } + TransportType type() const override { return TransportType::Framed; } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; + bool decodeFrameEnd(Buffer::Instance& buffer) override; + void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; + + static const int32_t MaxFrameSize = 0xFA0000; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 442bfc576d977..02f2808427e08 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -1,23 +1,33 @@ #pragma once #include -#include #include -#include -#include #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" +#include "envoy/registry/registry.h" -#include "common/common/fmt.h" -#include "common/common/macros.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "absl/strings/string_view.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class ProtocolType { + Binary, + LaxBinary, + Compact, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST PROTOCOL TYPE + LastProtocolType = Auto, +}; + /** * Names of available Protocol implementations. */ @@ -32,11 +42,23 @@ class ProtocolNameValues { // Compact protocol const std::string COMPACT = "compact"; - // JSON protocol - const std::string JSON = "json"; - // Auto-detection protocol const std::string AUTO = "auto"; + + const std::string& fromType(ProtocolType type) const { + switch (type) { + case ProtocolType::Binary: + return BINARY; + case ProtocolType::LaxBinary: + return LAX_BINARY; + case ProtocolType::Compact: + return COMPACT; + case ProtocolType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } }; typedef ConstSingleton ProtocolNames; @@ -78,48 +100,6 @@ enum class FieldType { LastFieldType = List, }; -/** - * ProtocolCallbacks are Thrift protocol-level callbacks. - */ -class ProtocolCallbacks { -public: - virtual ~ProtocolCallbacks() {} - - /** - * Indicates that the start of a Thrift protocol message was detected. - * @param name the name of the message, if available - * @param msg_type the type of the message - * @param seq_id the message sequence id - */ - virtual void messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) PURE; - - /** - * Indicates that the start of a Thrift protocol struct was detected. - * @param name the name of the struct, if available - */ - virtual void structBegin(const absl::string_view name) PURE; - - /** - * Indicates that the start of Thrift protocol struct field was detected. - * @param name the name of the field, if available - * @param field_type the type of the field - * @param field_id the field id - */ - virtual void structField(const absl::string_view name, FieldType field_type, - int16_t field_id) PURE; - - /** - * Indicates that the end of a Thrift protocol struct was detected. - */ - virtual void structEnd() PURE; - - /** - * Indicates that the end of a Thrift protocol message was detected. - */ - virtual void messageComplete() PURE; -}; - /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -128,8 +108,16 @@ class Protocol { public: virtual ~Protocol() {} + /** + * @return const std::string& the human-readable name of the protocol + */ virtual const std::string& name() const PURE; + /** + * @return ProtocolType the protocol type + */ + virtual ProtocolType type() const PURE; + /** * Reads the start of a Thrift protocol message from the buffer and updates the name, msg_type, * and seq_id parameters with values from the message header. If successful, the message header @@ -339,104 +327,199 @@ class Protocol { * @throw EnvoyException if the data is not a valid set footer */ virtual bool readBinary(Buffer::Instance& buffer, std::string& value) PURE; + + /** + * Writes the start of a Thrift protocol message to the buffer. + * @param buffer Buffer::Instance to modify + * @param name the message name + * @param msg_type the message's MessageType + * @param seq_id the message sequende ID + */ + virtual void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, + MessageType msg_type, int32_t seq_id) PURE; + + /** + * Writes the end of a Thrift protocol message to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeMessageEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes the start of a Thrift struct to the buffer. + * @param buffer Buffer::Instance to modify + * @param name the struct name, if known + */ + virtual void writeStructBegin(Buffer::Instance& buffer, const std::string& name) PURE; + + /** + * Writes the end of a Thrift struct to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeStructEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes the start of a Thrift struct field to the buffer + * @param buffer Buffer::Instance to modify + * @param name the field name, if known + * @param field_type the field's FieldType + * @param field_id the field ID + */ + virtual void writeFieldBegin(Buffer::Instance& buffer, const std::string& name, + FieldType field_type, int16_t field_id) PURE; + + /** + * Writes the end of a Thrift struct field to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeFieldEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes the start of a Thrift map to the buffer. + * @param buffer Buffer::Instance to modify + * @param key_type the map key FieldType + * @param value_type the map value FieldType + * @param size the number of key-value pairs in the map + */ + virtual void writeMapBegin(Buffer::Instance& buffer, FieldType key_type, FieldType value_type, + uint32_t size) PURE; + + /** + * Writes the end of a Thrift map to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeMapEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes the start of a Thrift list to the buffer. + * @param buffer Buffer::Instance to modify + * @param elem_type the list element FieldType + * @param size the number of list members + */ + virtual void writeListBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) PURE; + + /** + * Writes the end of a Thrift list to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeListEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes the start of a Thrift set to the buffer. + * @param buffer Buffer::Instance to modify + * @param elem_type the set element FieldType + * @param size the number of set members + */ + virtual void writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) PURE; + + /** + * Writes the end of a Thrift set to the buffer. + * @param buffer Buffer::Instance to modify + */ + virtual void writeSetEnd(Buffer::Instance& buffer) PURE; + + /** + * Writes a boolean value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value bool to write + */ + virtual void writeBool(Buffer::Instance& buffer, bool value) PURE; + + /** + * Writes a byte value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value uint8_t to write + */ + virtual void writeByte(Buffer::Instance& buffer, uint8_t value) PURE; + + /** + * Writes a int16_t value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value int16_t to write + */ + virtual void writeInt16(Buffer::Instance& buffer, int16_t value) PURE; + + /** + * Writes a int32_t value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value int32_t to write + */ + virtual void writeInt32(Buffer::Instance& buffer, int32_t value) PURE; + + /** + * Writes a int64_t value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value int64_t to write + */ + virtual void writeInt64(Buffer::Instance& buffer, int64_t value) PURE; + + /** + * Writes a double value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value double to write + */ + virtual void writeDouble(Buffer::Instance& buffer, double value) PURE; + + /** + * Writes a string value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value std::string to write + */ + virtual void writeString(Buffer::Instance& buffer, const std::string& value) PURE; + + /** + * Writes a binary value to the buffer. + * @param buffer Buffer::Instance to modify + * @param value std::string to write + */ + virtual void writeBinary(Buffer::Instance& buffer, const std::string& value) PURE; }; typedef std::unique_ptr ProtocolPtr; -/* - * ProtocolImplBase provides a base class for Protocol implementations. +/** + * Implemented by each Thrift protocol and registered via Registry::registerFactory or the + * convenience class RegisterFactory. */ -class ProtocolImplBase : public virtual Protocol { +class NamedProtocolConfigFactory { public: - ProtocolImplBase(ProtocolCallbacks& callbacks) : callbacks_(callbacks) {} + virtual ~NamedProtocolConfigFactory() {} -protected: - void onMessageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) const { - callbacks_.messageStart(name, msg_type, seq_id); - } - void onStructBegin(const absl::string_view name) const { callbacks_.structBegin(name); } - void onStructField(const absl::string_view name, FieldType field_type, int16_t field_id) const { - callbacks_.structField(name, field_type, field_id); - } - void onStructEnd() const { callbacks_.structEnd(); } - void onMessageComplete() const { callbacks_.messageComplete(); } + /** + * Create a particular Thrift protocol + * @return ProtocolFactoryCb the protocol + */ + virtual ProtocolPtr createProtocol() PURE; - ProtocolCallbacks& callbacks_; + /** + * @return std::string the identifying name for a particular implementation of thrift protocol + * produced by the factory. + */ + virtual std::string name() PURE; + + /** + * Convenience method to lookup a factory by type. + * @param ProtocolType the protocol type + * @return NamedProtocolConfigFactory& for the ProtocolType + */ + static NamedProtocolConfigFactory& getFactory(ProtocolType type) { + const std::string& name = ProtocolNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); + } }; /** - * AutoProtocolImpl attempts to distinguish between the Thrift binary (strict mode only) and - * compact protocols and then delegates subsequent decoding operations to the appropriate Protocol - * implementation. + * ProtocolFactoryBase provides a template for a trivial NamedProtocolConfigFactory. */ -class AutoProtocolImpl : public ProtocolImplBase { -public: - AutoProtocolImpl(ProtocolCallbacks& callbacks) - : ProtocolImplBase(callbacks), name_(ProtocolNames::get().AUTO) {} - - // Protocol - const std::string& name() const override { return name_; } - bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) override; - bool readMessageEnd(Buffer::Instance& buffer) override; - bool readStructBegin(Buffer::Instance& buffer, std::string& name) override { - return protocol_->readStructBegin(buffer, name); - } - bool readStructEnd(Buffer::Instance& buffer) override { return protocol_->readStructEnd(buffer); } - bool readFieldBegin(Buffer::Instance& buffer, std::string& name, FieldType& field_type, - int16_t& field_id) override { - return protocol_->readFieldBegin(buffer, name, field_type, field_id); - } - bool readFieldEnd(Buffer::Instance& buffer) override { return protocol_->readFieldEnd(buffer); } - bool readMapBegin(Buffer::Instance& buffer, FieldType& key_type, FieldType& value_type, - uint32_t& size) override { - return protocol_->readMapBegin(buffer, key_type, value_type, size); - } - bool readMapEnd(Buffer::Instance& buffer) override { return protocol_->readMapEnd(buffer); } - bool readListBegin(Buffer::Instance& buffer, FieldType& elem_type, uint32_t& size) override { - return protocol_->readListBegin(buffer, elem_type, size); - } - bool readListEnd(Buffer::Instance& buffer) override { return protocol_->readListEnd(buffer); } - bool readSetBegin(Buffer::Instance& buffer, FieldType& elem_type, uint32_t& size) override { - return protocol_->readSetBegin(buffer, elem_type, size); - } - bool readSetEnd(Buffer::Instance& buffer) override { return protocol_->readSetEnd(buffer); } - bool readBool(Buffer::Instance& buffer, bool& value) override { - return protocol_->readBool(buffer, value); - } - bool readByte(Buffer::Instance& buffer, uint8_t& value) override { - return protocol_->readByte(buffer, value); - } - bool readInt16(Buffer::Instance& buffer, int16_t& value) override { - return protocol_->readInt16(buffer, value); - } - bool readInt32(Buffer::Instance& buffer, int32_t& value) override { - return protocol_->readInt32(buffer, value); - } - bool readInt64(Buffer::Instance& buffer, int64_t& value) override { - return protocol_->readInt64(buffer, value); - } - bool readDouble(Buffer::Instance& buffer, double& value) override { - return protocol_->readDouble(buffer, value); - } - bool readString(Buffer::Instance& buffer, std::string& value) override { - return protocol_->readString(buffer, value); - } - bool readBinary(Buffer::Instance& buffer, std::string& value) override { - return protocol_->readBinary(buffer, value); - } +template class ProtocolFactoryBase : public NamedProtocolConfigFactory { + ProtocolPtr createProtocol() override { return std::move(std::make_unique()); } - /* - * Explicitly set the protocol. Public to simplify testing. - */ - void setProtocol(ProtocolPtr&& proto) { - protocol_ = std::move(proto); - name_ = fmt::format("{}({})", protocol_->name(), ProtocolNames::get().AUTO); - } + std::string name() override { return name_; } + +protected: + ProtocolFactoryBase(const std::string& name) : name_(name) {} private: - ProtocolPtr protocol_{}; - std::string name_; + const std::string name_; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h new file mode 100644 index 0000000000000..af7b6dfa2af3d --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -0,0 +1,144 @@ +#pragma once + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ProtocolConverter is an abstract class that implements protocol-related methods on + * ThriftFilters::DecoderFilter in terms of converting the decoded messages into a different + * protocol. + */ +class ProtocolConverter : public ThriftFilters::DecoderFilter { +public: + ProtocolConverter() {} + ~ProtocolConverter() {} + + void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { + proto_ = std::move(proto); + buffer_ = &buffer; + } + + // ThiftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + proto_->writeMessageBegin(*buffer_, std::string(name), msg_type, seq_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus messageEnd() override { + proto_->writeMessageEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + proto_->writeStructBegin(*buffer_, std::string(name)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structEnd() override { + proto_->writeFieldBegin(*buffer_, "", FieldType::Stop, 0); + proto_->writeStructEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + proto_->writeFieldBegin(*buffer_, std::string(name), field_type, field_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldEnd() override { + proto_->writeFieldEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus boolValue(bool value) override { + proto_->writeBool(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + proto_->writeByte(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int16Value(int16_t value) override { + proto_->writeInt16(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int32Value(int32_t value) override { + proto_->writeInt32(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int64Value(int64_t value) override { + proto_->writeInt64(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus doubleValue(double value) override { + proto_->writeDouble(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + proto_->writeString(*buffer_, std::string(value)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + proto_->writeMapBegin(*buffer_, key_type, value_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapEnd() override { + proto_->writeMapEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + proto_->writeListBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listEnd() override { + proto_->writeListEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + proto_->writeSetBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setEnd() override { + proto_->writeSetEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + +protected: + ProtocolType protocolType() const { return proto_->type(); } + +private: + ProtocolPtr proto_; + Buffer::Instance* buffer_{}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol.cc b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc similarity index 68% rename from source/extensions/filters/network/thrift_proxy/protocol.cc rename to source/extensions/filters/network/thrift_proxy/protocol_impl.cc index cf8f03c551197..46636099f5430 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.cc +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc @@ -1,4 +1,4 @@ -#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" #include @@ -8,9 +8,9 @@ #include "common/common/byte_order.h" #include "common/common/macros.h" -#include "extensions/filters/network/thrift_proxy/binary_protocol.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/compact_protocol.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" namespace Envoy { namespace Extensions { @@ -26,9 +26,9 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& n uint16_t version = BufferHelper::peekU16(buffer); if (BinaryProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else if (CompactProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else { throw EnvoyException( fmt::format("unknown thrift auto protocol message start {:04x}", version)); @@ -41,10 +41,20 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& n } bool AutoProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { - RELEASE_ASSERT(protocol_ != nullptr); + RELEASE_ASSERT(protocol_ != nullptr, ""); return protocol_->readMessageEnd(buffer); } +class AutoProtocolConfigFactory : public ProtocolFactoryBase { +public: + AutoProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().AUTO) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.h b/source/extensions/filters/network/thrift_proxy/protocol_impl.h new file mode 100644 index 0000000000000..0eb374e9f9869 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.h @@ -0,0 +1,151 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" + +#include "common/common/fmt.h" + +#include "extensions/filters/network/thrift_proxy/protocol.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * AutoProtocolImpl attempts to distinguish between the Thrift binary (strict mode only) and + * compact protocols and then delegates subsequent decoding operations to the appropriate Protocol + * implementation. + */ +class AutoProtocolImpl : public Protocol { +public: + AutoProtocolImpl() : name_(ProtocolNames::get().AUTO) {} + + // Protocol + const std::string& name() const override { return name_; } + ProtocolType type() const override { + if (protocol_ != nullptr) { + return protocol_->type(); + } + return ProtocolType::Auto; + } + + bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, + int32_t& seq_id) override; + bool readMessageEnd(Buffer::Instance& buffer) override; + bool readStructBegin(Buffer::Instance& buffer, std::string& name) override { + return protocol_->readStructBegin(buffer, name); + } + bool readStructEnd(Buffer::Instance& buffer) override { return protocol_->readStructEnd(buffer); } + bool readFieldBegin(Buffer::Instance& buffer, std::string& name, FieldType& field_type, + int16_t& field_id) override { + return protocol_->readFieldBegin(buffer, name, field_type, field_id); + } + bool readFieldEnd(Buffer::Instance& buffer) override { return protocol_->readFieldEnd(buffer); } + bool readMapBegin(Buffer::Instance& buffer, FieldType& key_type, FieldType& value_type, + uint32_t& size) override { + return protocol_->readMapBegin(buffer, key_type, value_type, size); + } + bool readMapEnd(Buffer::Instance& buffer) override { return protocol_->readMapEnd(buffer); } + bool readListBegin(Buffer::Instance& buffer, FieldType& elem_type, uint32_t& size) override { + return protocol_->readListBegin(buffer, elem_type, size); + } + bool readListEnd(Buffer::Instance& buffer) override { return protocol_->readListEnd(buffer); } + bool readSetBegin(Buffer::Instance& buffer, FieldType& elem_type, uint32_t& size) override { + return protocol_->readSetBegin(buffer, elem_type, size); + } + bool readSetEnd(Buffer::Instance& buffer) override { return protocol_->readSetEnd(buffer); } + bool readBool(Buffer::Instance& buffer, bool& value) override { + return protocol_->readBool(buffer, value); + } + bool readByte(Buffer::Instance& buffer, uint8_t& value) override { + return protocol_->readByte(buffer, value); + } + bool readInt16(Buffer::Instance& buffer, int16_t& value) override { + return protocol_->readInt16(buffer, value); + } + bool readInt32(Buffer::Instance& buffer, int32_t& value) override { + return protocol_->readInt32(buffer, value); + } + bool readInt64(Buffer::Instance& buffer, int64_t& value) override { + return protocol_->readInt64(buffer, value); + } + bool readDouble(Buffer::Instance& buffer, double& value) override { + return protocol_->readDouble(buffer, value); + } + bool readString(Buffer::Instance& buffer, std::string& value) override { + return protocol_->readString(buffer, value); + } + bool readBinary(Buffer::Instance& buffer, std::string& value) override { + return protocol_->readBinary(buffer, value); + } + void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, + int32_t seq_id) override { + protocol_->writeMessageBegin(buffer, name, msg_type, seq_id); + } + void writeMessageEnd(Buffer::Instance& buffer) override { protocol_->writeMessageEnd(buffer); } + void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override { + protocol_->writeStructBegin(buffer, name); + } + void writeStructEnd(Buffer::Instance& buffer) override { protocol_->writeStructEnd(buffer); } + void writeFieldBegin(Buffer::Instance& buffer, const std::string& name, FieldType field_type, + int16_t field_id) override { + protocol_->writeFieldBegin(buffer, name, field_type, field_id); + } + void writeFieldEnd(Buffer::Instance& buffer) override { protocol_->writeFieldEnd(buffer); } + void writeMapBegin(Buffer::Instance& buffer, FieldType key_type, FieldType value_type, + uint32_t size) override { + protocol_->writeMapBegin(buffer, key_type, value_type, size); + } + void writeMapEnd(Buffer::Instance& buffer) override { protocol_->writeMapEnd(buffer); } + void writeListBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override { + protocol_->writeListBegin(buffer, elem_type, size); + } + void writeListEnd(Buffer::Instance& buffer) override { protocol_->writeListEnd(buffer); } + void writeSetBegin(Buffer::Instance& buffer, FieldType elem_type, uint32_t size) override { + protocol_->writeSetBegin(buffer, elem_type, size); + } + void writeSetEnd(Buffer::Instance& buffer) override { protocol_->writeSetEnd(buffer); } + void writeBool(Buffer::Instance& buffer, bool value) override { + protocol_->writeBool(buffer, value); + } + void writeByte(Buffer::Instance& buffer, uint8_t value) override { + protocol_->writeByte(buffer, value); + } + void writeInt16(Buffer::Instance& buffer, int16_t value) override { + protocol_->writeInt16(buffer, value); + } + void writeInt32(Buffer::Instance& buffer, int32_t value) override { + protocol_->writeInt32(buffer, value); + } + void writeInt64(Buffer::Instance& buffer, int64_t value) override { + protocol_->writeInt64(buffer, value); + } + void writeDouble(Buffer::Instance& buffer, double value) override { + protocol_->writeDouble(buffer, value); + } + void writeString(Buffer::Instance& buffer, const std::string& value) override { + protocol_->writeString(buffer, value); + } + void writeBinary(Buffer::Instance& buffer, const std::string& value) override { + protocol_->writeBinary(buffer, value); + } + + /* + * Explicitly set the protocol. Public to simplify testing. + */ + void setProtocol(ProtocolPtr&& proto) { + protocol_ = std::move(proto); + name_ = fmt::format("{}({})", protocol_->name(), ProtocolNames::get().AUTO); + } + +private: + ProtocolPtr protocol_{}; + std::string name_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD new file mode 100644 index 0000000000000..b06fc47112a9d --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -0,0 +1,51 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + ":router_lib", + "//include/envoy/registry", + "//source/extensions/filters/network/thrift_proxy/filters:factory_base_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1/router:router_cc", + ], +) + +envoy_cc_library( + name = "router_interface", + hdrs = ["router.h"], + external_deps = ["abseil_optional"], + deps = [], +) + +envoy_cc_library( + name = "router_lib", + srcs = ["router_impl.cc"], + hdrs = ["router_impl.h"], + deps = [ + ":router_interface", + "//include/envoy/tcp:conn_pool_interface", + "//include/envoy/upstream:cluster_manager_interface", + "//include/envoy/upstream:load_balancer_interface", + "//include/envoy/upstream:thread_local_cluster_interface", + "//source/common/common:logger_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", + "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1:thrift_proxy_cc", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/router/config.cc b/source/extensions/filters/network/thrift_proxy/router/config.cc new file mode 100644 index 0000000000000..8f6bddcf6fbdb --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/router/config.h" + +#include "envoy/registry/registry.h" + +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +ThriftFilters::FilterFactoryCb RouterFilterConfig::createFilterFactoryFromProtoTyped( + const envoy::config::filter::network::thrift_proxy::v2alpha1::router::Router& proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) { + UNREFERENCED_PARAMETER(proto_config); + UNREFERENCED_PARAMETER(stat_prefix); + + return [&context](ThriftFilters::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addDecoderFilter(std::make_shared(context.clusterManager())); + }; +} + +/** + * Static registration for the router filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/config.h b/source/extensions/filters/network/thrift_proxy/router/config.h new file mode 100644 index 0000000000000..937847d78a849 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.h @@ -0,0 +1,31 @@ +#pragma once + +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/router/router.pb.h" +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/router/router.pb.validate.h" + +#include "extensions/filters/network/thrift_proxy/filters/factory_base.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouterFilterConfig + : public ThriftFilters::FactoryBase< + envoy::config::filter::network::thrift_proxy::v2alpha1::router::Router> { +public: + RouterFilterConfig() : FactoryBase(ThriftFilters::ThriftFilterNames::get().ROUTER) {} + +private: + ThriftFilters::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::config::filter::network::thrift_proxy::v2alpha1::router::Router& proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) override; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h new file mode 100644 index 0000000000000..32d717de52351 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +/** + * RouteEntry is an individual resolved route entry. + */ +class RouteEntry { +public: + virtual ~RouteEntry() {} + + /** + * @return const std::string& the upstream cluster that owns the route. + */ + virtual const std::string& clusterName() const PURE; +}; + +/** + * Route holds the RouteEntry for a request. + */ +class Route { +public: + virtual ~Route() {} + + /** + * @return the route entry or nullptr if there is no matching route for the request. + */ + virtual const RouteEntry* routeEntry() const PURE; +}; + +typedef std::shared_ptr RouteConstSharedPtr; + +/** + * The router configuration. + */ +class Config { +public: + virtual ~Config() {} + + /** + * Based on the incoming Thrift request transport and/or protocol data, determine the target + * route for the request. + * @param method supplies the thrift method name + * @return the route or nullptr if there is no matching route for the request. + */ + virtual RouteConstSharedPtr route(const std::string& method) const PURE; +}; + +typedef std::shared_ptr ConfigConstSharedPtr; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc new file mode 100644 index 0000000000000..09d1ac142c2d5 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -0,0 +1,302 @@ +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/upstream/cluster_manager.h" +#include "envoy/upstream/thread_local_cluster.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +RouteEntryImplBase::RouteEntryImplBase( + const envoy::config::filter::network::thrift_proxy::v2alpha1::Route& route) + : cluster_name_(route.route().cluster()) {} + +const std::string& RouteEntryImplBase::clusterName() const { return cluster_name_; } + +const RouteEntry* RouteEntryImplBase::routeEntry() const { return this; } + +RouteConstSharedPtr RouteEntryImplBase::clusterEntry() const { return shared_from_this(); } + +MethodNameRouteEntryImpl::MethodNameRouteEntryImpl( + const envoy::config::filter::network::thrift_proxy::v2alpha1::Route& route) + : RouteEntryImplBase(route), method_name_(route.match().method()) {} + +RouteConstSharedPtr MethodNameRouteEntryImpl::matches(const std::string& method_name) const { + if (method_name_.empty() || method_name_ == method_name) { + return clusterEntry(); + } + + return nullptr; +} + +RouteMatcher::RouteMatcher( + const envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration& config) { + for (const auto& route : config.routes()) { + routes_.emplace_back(new MethodNameRouteEntryImpl(route)); + } +} + +RouteConstSharedPtr RouteMatcher::route(const std::string& method_name) const { + for (const auto& route : routes_) { + RouteConstSharedPtr route_entry = route->matches(method_name); + if (nullptr != route_entry) { + return route_entry; + } + } + + return nullptr; +} + +void Router::onDestroy() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } + cleanup(); +} + +void Router::setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) { + callbacks_ = &callbacks; + + // TODO(zuercher): handle buffer limits +} + +void Router::resetUpstreamConnection() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } +} + +ThriftFilters::FilterStatus Router::transportBegin(absl::optional size) { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::transportEnd() { + if (upstream_request_->msg_type_ == MessageType::Oneway) { + // No response expected + upstream_request_->onResponseComplete(); + cleanup(); + } + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) { + // TODO(zuercher): route stats (e.g., no_route, no_cluster, upstream_rq_maintenance_mode, no + // healtthy upstream) + + route_ = callbacks_->route(); + if (!route_) { + ENVOY_STREAM_LOG(debug, "no cluster match for method '{}'", *callbacks_, name); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::UnknownMethod, + fmt::format("no route for method '{}'", name))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + route_entry_ = route_->routeEntry(); + + Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(route_entry_->clusterName()); + if (!cluster) { + ENVOY_STREAM_LOG(debug, "unknown cluster '{}'", *callbacks_, route_entry_->clusterName()); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::InternalError, + fmt::format("unknown cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + cluster_ = cluster->info(); + ENVOY_STREAM_LOG(debug, "cluster '{}' match for method '{}'", *callbacks_, + route_entry_->clusterName(), name); + + if (cluster_->maintenanceMode()) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("maintenance mode for cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( + route_entry_->clusterName(), Upstream::ResourcePriority::Default, this); + if (!conn_pool) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("no healthy upstream for '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + ENVOY_STREAM_LOG(debug, "router decoding request", *callbacks_); + + upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, name, msg_type, seq_id)); + upstream_request_->start(); + return ThriftFilters::FilterStatus::StopIteration; +} + +ThriftFilters::FilterStatus Router::messageEnd() { + ProtocolConverter::messageEnd(); + + Buffer::OwnedImpl transport_buffer; + upstream_request_->transport_->encodeFrame(transport_buffer, upstream_request_buffer_); + upstream_request_->conn_data_->connection().write(transport_buffer, false); + upstream_request_->onRequestComplete(); + return ThriftFilters::FilterStatus::Continue; +} + +void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { + ASSERT(!upstream_request_->response_complete_); + + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(upstream_request_->transport_->type(), protocolType()); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } + + if (end_stream) { + // Response is incomplete, but no more data is coming. + upstream_request_->onResponseComplete(); + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + cleanup(); + } +} + +void Router::onEvent(Network::ConnectionEvent event) { + if (!upstream_request_ || upstream_request_->response_complete_) { + // Client closed connection after completing response. + return; + } + + switch (event) { + case Network::ConnectionEvent::RemoteClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + break; + case Network::ConnectionEvent::LocalClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure); + break; + default: + // Connected is consumed by the connection pool. + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +const Network::Connection* Router::downstreamConnection() const { + if (callbacks_ != nullptr) { + return callbacks_->connection(); + } + + return nullptr; +} + +void Router::convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id) { + ProtocolConverter::messageBegin(absl::string_view(name), msg_type, seq_id); +} + +void Router::cleanup() { upstream_request_.reset(); } + +Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, + int32_t seq_id) + : parent_(parent), conn_pool_(pool), method_name_(std::string(method_name)), + msg_type_(msg_type), seq_id_(seq_id), request_complete_(false), response_started_(false), + response_complete_(false) {} + +Router::UpstreamRequest::~UpstreamRequest() {} + +void Router::UpstreamRequest::start() { + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(*this); + if (handle) { + conn_pool_handle_ = handle; + } +} + +void Router::UpstreamRequest::resetStream() { + if (conn_data_ != nullptr) { + conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); + conn_data_.reset(); + } +} + +void Router::UpstreamRequest::onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) { + // Mimic an upstream reset. + onUpstreamHostSelected(host); + onResetStream(reason); +} + +void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data, + Upstream::HostDescriptionConstSharedPtr host) { + onUpstreamHostSelected(host); + conn_data_ = std::move(conn_data); + conn_data_->addUpstreamCallbacks(parent_); + + conn_pool_handle_ = nullptr; + + // TODO(zuercher): let cluster specify a specific transport and protocol + transport_ = + NamedTransportConfigFactory::getFactory(parent_.callbacks_->downstreamTransportType()) + .createTransport(); + + parent_.initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.callbacks_->downstreamProtocolType()) + .createProtocol(), + parent_.upstream_request_buffer_); + + // TODO(zuercher): need to use an upstream-connection-specific sequence id + parent_.convertMessageBegin(method_name_, msg_type_, seq_id_); + + parent_.callbacks_->continueDecoding(); +} + +void Router::UpstreamRequest::onRequestComplete() { request_complete_ = true; } + +void Router::UpstreamRequest::onResponseComplete() { + response_complete_ = true; + conn_data_.reset(); +} + +void Router::UpstreamRequest::onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host) { + upstream_host_ = host; +} + +void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReason reason) { + switch (reason) { + case Tcp::ConnectionPool::PoolFailureReason::Overflow: + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))}); + break; + case Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::Timeout: + // TODO(zuercher): distinguish between these cases where appropriate (particularly timeout) + if (!response_started_) { + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("connection failure '{}'", upstream_host_->address()->asString()))}); + return; + } + + parent_.callbacks_->resetDownstreamConnection(); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h new file mode 100644 index 0000000000000..d298d38b354cd --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include + +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/load_balancer.h" + +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouteEntryImplBase : public RouteEntry, + public Route, + public std::enable_shared_from_this { +public: + RouteEntryImplBase(const envoy::config::filter::network::thrift_proxy::v2alpha1::Route& route); + + // Router::RouteEntry + const std::string& clusterName() const override; + + // Router::Route + const RouteEntry* routeEntry() const override; + + virtual RouteConstSharedPtr matches(const std::string& method_name) const PURE; + +protected: + RouteConstSharedPtr clusterEntry() const; + +private: + const std::string cluster_name_; +}; + +typedef std::shared_ptr RouteEntryImplBaseConstSharedPtr; + +class MethodNameRouteEntryImpl : public RouteEntryImplBase { +public: + MethodNameRouteEntryImpl( + const envoy::config::filter::network::thrift_proxy::v2alpha1::Route& route); + + const std::string& methodName() const { return method_name_; } + + // RoutEntryImplBase + RouteConstSharedPtr matches(const std::string& method_name) const override; + +private: + const std::string method_name_; +}; + +class RouteMatcher { +public: + RouteMatcher(const envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration&); + + RouteConstSharedPtr route(const std::string& method_name) const; + +private: + std::vector routes_; +}; + +class Router : public Tcp::ConnectionPool::UpstreamCallbacks, + public Upstream::LoadBalancerContext, + public ProtocolConverter, + Logger::Loggable { +public: + Router(Upstream::ClusterManager& cluster_manager) : cluster_manager_(cluster_manager) {} + + ~Router() {} + + // ProtocolConverter + void onDestroy() override; + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) override; + void resetUpstreamConnection() override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override; + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus messageEnd() override; + + // Upstream::LoadBalancerContext + absl::optional computeHashKey() override { return {}; } + const Envoy::Router::MetadataMatchCriteria* metadataMatchCriteria() override { return nullptr; } + const Network::Connection* downstreamConnection() const override; + const Http::HeaderMap* downstreamHeaders() const override { return nullptr; } + + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override; + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + +private: + struct UpstreamRequest : public Tcp::ConnectionPool::Callbacks { + UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, int32_t seq_id); + ~UpstreamRequest(); + + void start(); + void resetStream(); + + // Tcp::ConnectionPool::Callbacks + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) override; + void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn, + Upstream::HostDescriptionConstSharedPtr host) override; + + void onRequestComplete(); + void onResponseComplete(); + void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); + void onResetStream(Tcp::ConnectionPool::PoolFailureReason reason); + + Router& parent_; + Tcp::ConnectionPool::Instance& conn_pool_; + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + + Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; + Tcp::ConnectionPool::ConnectionDataPtr conn_data_; + Upstream::HostDescriptionConstSharedPtr upstream_host_; + TransportPtr transport_; + ProtocolType proto_type_{ProtocolType::Auto}; + + bool request_complete_ : 1; + bool response_started_ : 1; + bool response_complete_ : 1; + }; + + void convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id); + void cleanup(); + + Upstream::ClusterManager& cluster_manager_; + + ThriftFilters::DecoderFilterCallbacks* callbacks_{}; + RouteConstSharedPtr route_{}; + const RouteEntry* route_entry_{}; + Upstream::ClusterInfoConstSharedPtr cluster_; + + std::unique_ptr upstream_request_; + Buffer::OwnedImpl upstream_request_buffer_; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/stats.h b/source/extensions/filters/network/thrift_proxy/stats.h new file mode 100644 index 0000000000000..a630fc735d034 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/stats.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include "envoy/stats/stats.h" +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * All thrift filter stats. @see stats_macros.h + */ +// clang-format off +#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ + COUNTER(request) \ + COUNTER(request_call) \ + COUNTER(request_oneway) \ + COUNTER(request_invalid_type) \ + GAUGE(request_active) \ + COUNTER(request_decoding_error) \ + HISTOGRAM(request_time_ms) \ + COUNTER(response) \ + COUNTER(response_reply) \ + COUNTER(response_success) \ + COUNTER(response_error) \ + COUNTER(response_exception) \ + COUNTER(response_invalid_type) \ + COUNTER(response_decoding_error) \ + COUNTER(cx_destroy_local_with_active_rq) \ + COUNTER(cx_destroy_remote_with_active_rq) +// clang-format on + +/** + * Struct definition for all mongo proxy stats. @see stats_macros.h + */ +struct ThriftFilterStats { + ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) + + static ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { + return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), + POOL_GAUGE_PREFIX(scope, prefix), + POOL_HISTOGRAM_PREFIX(scope, prefix))}; + } +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/transport.h b/source/extensions/filters/network/thrift_proxy/transport.h index 809de8a6ece83..1bda083e6bc3e 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.h +++ b/source/extensions/filters/network/thrift_proxy/transport.h @@ -4,8 +4,10 @@ #include #include "envoy/buffer/buffer.h" +#include "envoy/registry/registry.h" -#include "common/common/fmt.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" #include "absl/types/optional.h" @@ -15,6 +17,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class TransportType { + Framed, + Unframed, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST TRANSPORT TYPE + LastTransportType = Auto, + +}; + /** * Names of available Transport implementations. */ @@ -28,29 +40,23 @@ class TransportNameValues { // Auto-detection transport const std::string AUTO = "auto"; + + const std::string& fromType(TransportType type) const { + switch (type) { + case TransportType::Framed: + return FRAMED; + case TransportType::Unframed: + return UNFRAMED; + case TransportType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } }; typedef ConstSingleton TransportNames; -/** - * TransportCallbacks are Thrift transport-level callbacks. - */ -class TransportCallbacks { -public: - virtual ~TransportCallbacks() {} - - /** - * Indicates the start of a Thrift transport frame was detected. - * @param size the size of the message, if available to the transport - */ - virtual void transportFrameStart(absl::optional size) PURE; - - /** - * Indicates the end of a Thrift transport frame was detected. - */ - virtual void transportFrameComplete() PURE; -}; - /** * Transport represents a Thrift transport. The Thrift transport is nominally a generic, * bi-directional byte stream. In Envoy we assume it always represents a network byte stream and @@ -67,20 +73,27 @@ class Transport { */ virtual const std::string& name() const PURE; + /** + * @return TransportType the transport type + */ + virtual TransportType type() const PURE; + /* - * decodeFrameStart decodes the start of a transport message, potentially invoking callbacks. - * If successful, the start of the frame is removed from the buffer. + * Decodes the start of a transport message. If successful, the start of the frame is removed + * from the buffer. * * @param buffer the currently buffered thrift data. + * @param size updated with the frame size on success. If frame size is not encoded, the size + * is cleared on success. * @return bool true if a complete frame header was successfully consumed, false if more data * is required. * @throws EnvoyException if the data is not valid for this transport. */ - virtual bool decodeFrameStart(Buffer::Instance& buffer) PURE; + virtual bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) PURE; /* - * decodeFrameEnd decodes the end of a transport message, potentially invoking callbacks. - * If successful, the end of the frame is removed from the buffer. + * Decodes the end of a transport message. If successful, the end of the frame is removed from + * the buffer. * * @param buffer the currently buffered thrift data. * @return bool true if a complete frame trailer was successfully consumed, false if more data @@ -88,83 +101,63 @@ class Transport { * @throws EnvoyException if the data is not valid for this transport. */ virtual bool decodeFrameEnd(Buffer::Instance& buffer) PURE; + + /** + * Wraps the given message buffer with the transport's header and trailer (if any). After + * encoding, message will be empty. + * @param buffer is the output buffer + * @param message a protocol-encoded message + * @throws EnvoyException if the message is too large for the transport + */ + virtual void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) PURE; }; typedef std::unique_ptr TransportPtr; -/* - * TransportImplBase provides a base class for Transport implementations. - */ -class TransportImplBase : public virtual Transport { -public: - TransportImplBase(TransportCallbacks& callbacks) : callbacks_(callbacks) {} - -protected: - void onFrameStart(absl::optional size) const { callbacks_.transportFrameStart(size); } - void onFrameComplete() const { callbacks_.transportFrameComplete(); } - - TransportCallbacks& callbacks_; -}; - /** - * FramedTransportImpl implements the Thrift Framed transport. - * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + * Implemented by each Thrift transport and registered via Registry::registerFactory or the + * convenience class RegisterFactory. */ -class FramedTransportImpl : public TransportImplBase { +class NamedTransportConfigFactory { public: - FramedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + virtual ~NamedTransportConfigFactory() {} - // Transport - const std::string& name() const override { return TransportNames::get().FRAMED; } - bool decodeFrameStart(Buffer::Instance& buffer) override; - bool decodeFrameEnd(Buffer::Instance& buffer) override; - - static const int32_t MaxFrameSize = 0xFA0000; -}; + /** + * Create a particular Thrift transport. + * @return TransportPtr the transport + */ + virtual TransportPtr createTransport() PURE; -/** - * UnframedTransportImpl implements the Thrift Unframed transport. - * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md - */ -class UnframedTransportImpl : public TransportImplBase { -public: - UnframedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + /** + * @return std::string the identifying name for a particular implementation of thrift transport + * produced by the factory. + */ + virtual std::string name() PURE; - // Transport - const std::string& name() const override { return TransportNames::get().UNFRAMED; } - bool decodeFrameStart(Buffer::Instance&) override { - onFrameStart(absl::optional()); - return true; - } - bool decodeFrameEnd(Buffer::Instance&) override { - onFrameComplete(); - return true; + /** + * Convenience method to lookup a factory by type. + * @param TransportType the transport type + * @return NamedTransportConfigFactory& for the TransportType + */ + static NamedTransportConfigFactory& getFactory(TransportType type) { + const std::string& name = TransportNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); } }; /** - * AutoTransportImpl implements Transport and attempts to distinguish between the Thrift framed and - * unframed transports. Once the transport is detected, subsequent operations are delegated to the - * appropriate implementation. + * TransportFactoryBase provides a template for a trivial NamedTransportConfigFactory. */ -class AutoTransportImpl : public TransportImplBase { -public: - AutoTransportImpl(TransportCallbacks& callbacks) - : TransportImplBase(callbacks), name_(TransportNames::get().AUTO){}; +template class TransportFactoryBase : public NamedTransportConfigFactory { + TransportPtr createTransport() override { return std::move(std::make_unique()); } - // Transport - const std::string& name() const override { return name_; } - bool decodeFrameStart(Buffer::Instance& buffer) override; - bool decodeFrameEnd(Buffer::Instance& buffer) override; + std::string name() override { return name_; } -private: - void setTransport(TransportPtr&& transport) { - transport_ = std::move(transport); - name_ = fmt::format("{}({})", transport_->name(), TransportNames::get().AUTO); - } +protected: + TransportFactoryBase(const std::string& name) : name_(name) {} - TransportPtr transport_{}; - std::string name_; +private: + const std::string name_; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/transport.cc b/source/extensions/filters/network/thrift_proxy/transport_impl.cc similarity index 65% rename from source/extensions/filters/network/thrift_proxy/transport.cc rename to source/extensions/filters/network/thrift_proxy/transport_impl.cc index c2531379bb2a5..ff177884d94c5 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.cc +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.cc @@ -1,48 +1,21 @@ -#include "extensions/filters/network/thrift_proxy/transport.h" - -#include - -#include -#include +#include "extensions/filters/network/thrift_proxy/transport_impl.h" #include "envoy/common/exception.h" #include "common/common/assert.h" -#include "common/common/byte_order.h" -#include "common/common/utility.h" -#include "extensions/filters/network/thrift_proxy/binary_protocol.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/compact_protocol.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { - if (buffer.length() < 4) { - return false; - } - - int32_t size = BufferHelper::peekI32(buffer); - - if (size <= 0 || size > MaxFrameSize) { - throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", size)); - } - - onFrameStart(absl::optional(static_cast(size))); - - buffer.drain(4); - return true; -} - -bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { - onFrameComplete(); - return true; -} - -bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { +bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) { if (transport_ == nullptr) { // Not enough data to select a transport. if (buffer.length() < 8) { @@ -57,13 +30,13 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { // is configurable, but defaults to 256 MB (0x1000000). THeaderTransport will take up to ~1GB // (0x3FFFFFFF) when it falls back to framed mode. if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } else { // Check for sane unframed protocol. proto_start = static_cast((size >> 16) & 0xFFFF); if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } @@ -78,14 +51,29 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { } } - return transport_->decodeFrameStart(buffer); + return transport_->decodeFrameStart(buffer, size); } bool AutoTransportImpl::decodeFrameEnd(Buffer::Instance& buffer) { - RELEASE_ASSERT(transport_ != nullptr); + RELEASE_ASSERT(transport_ != nullptr, ""); return transport_->decodeFrameEnd(buffer); } +void AutoTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { + RELEASE_ASSERT(transport_ != nullptr, ""); + transport_->encodeFrame(buffer, message); +} + +class AutoTransportConfigFactory : public TransportFactoryBase { +public: + AutoTransportConfigFactory() : TransportFactoryBase(TransportNames::get().AUTO) {} +}; + +/** + * Static registration for the auto transport. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.h b/source/extensions/filters/network/thrift_proxy/transport_impl.h new file mode 100644 index 0000000000000..08281c53c6d16 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.h @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" + +#include "common/common/fmt.h" + +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * AutoTransportImpl implements Transport and attempts to distinguish between the Thrift framed and + * unframed transports. Once the transport is detected, subsequent operations are delegated to the + * appropriate implementation. + */ +class AutoTransportImpl : public Transport { +public: + AutoTransportImpl() : name_(TransportNames::get().AUTO){}; + + // Transport + const std::string& name() const override { return name_; } + TransportType type() const override { + if (transport_ != nullptr) { + return transport_->type(); + } + + return TransportType::Auto; + } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; + bool decodeFrameEnd(Buffer::Instance& buffer) override; + void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; + + /* + * Explicitly set the transport. Public to simplify testing. + */ + void setTransport(TransportPtr&& transport) { + transport_ = std::move(transport); + name_ = fmt::format("{}({})", transport_->name(), TransportNames::get().AUTO); + } + +private: + TransportPtr transport_{}; + std::string name_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc new file mode 100644 index 0000000000000..d3a2744540c96 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc @@ -0,0 +1,22 @@ +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class UnframedTransportConfigFactory : public TransportFactoryBase { +public: + UnframedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().UNFRAMED) {} +}; + +/** + * Static registration for the unframed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h new file mode 100644 index 0000000000000..29992dfc0812f --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * UnframedTransportImpl implements the Thrift Unframed transport. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +class UnframedTransportImpl : public Transport { +public: + UnframedTransportImpl() {} + + // Transport + const std::string& name() const override { return TransportNames::get().UNFRAMED; } + TransportType type() const override { return TransportType::Unframed; } + bool decodeFrameStart(Buffer::Instance&, absl::optional& size) override { + size.reset(); + return true; + } + bool decodeFrameEnd(Buffer::Instance&) override { return true; } + void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override { + buffer.move(message); + } +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/well_known_names.h b/source/extensions/filters/network/well_known_names.h index deffa01fb759c..f9014eb0de7c7 100644 --- a/source/extensions/filters/network/well_known_names.h +++ b/source/extensions/filters/network/well_known_names.h @@ -13,31 +13,31 @@ namespace NetworkFilters { class NetworkFilterNameValues { public: // Client ssl auth filter - const std::string CLIENT_SSL_AUTH = "envoy.client_ssl_auth"; + const std::string ClientSslAuth = "envoy.client_ssl_auth"; // Echo filter - const std::string ECHO = "envoy.echo"; + const std::string Echo = "envoy.echo"; // HTTP connection manager filter - const std::string HTTP_CONNECTION_MANAGER = "envoy.http_connection_manager"; + const std::string HttpConnectionManager = "envoy.http_connection_manager"; // Mongo proxy filter - const std::string MONGO_PROXY = "envoy.mongo_proxy"; + const std::string MongoProxy = "envoy.mongo_proxy"; // Rate limit filter - const std::string RATE_LIMIT = "envoy.ratelimit"; + const std::string RateLimit = "envoy.ratelimit"; // Redis proxy filter - const std::string REDIS_PROXY = "envoy.redis_proxy"; + const std::string RedisProxy = "envoy.redis_proxy"; // IP tagging filter - const std::string TCP_PROXY = "envoy.tcp_proxy"; + const std::string TcpProxy = "envoy.tcp_proxy"; // Authorization filter - const std::string EXT_AUTHORIZATION = "envoy.ext_authz"; + const std::string ExtAuthorization = "envoy.ext_authz"; // Thrift proxy filter - const std::string THRIFT_PROXY = "envoy.filters.network.thrift_proxy"; + const std::string ThriftProxy = "envoy.filters.network.thrift_proxy"; // Converts names from v1 to v2 const Config::V1Converter v1_converter_; // NOTE: Do not add any new filters to this list. All future filters are v2 only. NetworkFilterNameValues() - : v1_converter_({CLIENT_SSL_AUTH, ECHO, HTTP_CONNECTION_MANAGER, MONGO_PROXY, RATE_LIMIT, - REDIS_PROXY, TCP_PROXY, EXT_AUTHORIZATION}) {} + : v1_converter_({ClientSslAuth, Echo, HttpConnectionManager, MongoProxy, RateLimit, + RedisProxy, TcpProxy, ExtAuthorization}) {} }; typedef ConstSingleton NetworkFilterNames; diff --git a/source/extensions/grpc_credentials/example/BUILD b/source/extensions/grpc_credentials/example/BUILD index 00b4477d770a2..a1bdd47df7b82 100644 --- a/source/extensions/grpc_credentials/example/BUILD +++ b/source/extensions/grpc_credentials/example/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Example gRPC Credentials load( diff --git a/source/extensions/grpc_credentials/example/config.h b/source/extensions/grpc_credentials/example/config.h index 1313dbd9ac027..053b79335dc38 100644 --- a/source/extensions/grpc_credentials/example/config.h +++ b/source/extensions/grpc_credentials/example/config.h @@ -30,7 +30,7 @@ class AccessTokenExampleGrpcCredentialsFactory : public Grpc::GoogleGrpcCredenti virtual std::shared_ptr getChannelCredentials(const envoy::api::v2::core::GrpcService& grpc_service_config) override; - std::string name() const override { return GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; } + std::string name() const override { return GrpcCredentialsNames::get().AccessTokenExample; } }; /* diff --git a/source/extensions/grpc_credentials/file_based_metadata/BUILD b/source/extensions/grpc_credentials/file_based_metadata/BUILD index ae5050dc0cfab..30e63efddd0bd 100644 --- a/source/extensions/grpc_credentials/file_based_metadata/BUILD +++ b/source/extensions/grpc_credentials/file_based_metadata/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # File Based Metadata gRPC Credentials load( diff --git a/source/extensions/grpc_credentials/file_based_metadata/config.cc b/source/extensions/grpc_credentials/file_based_metadata/config.cc index 987827a637de7..951f11c6c6ee9 100644 --- a/source/extensions/grpc_credentials/file_based_metadata/config.cc +++ b/source/extensions/grpc_credentials/file_based_metadata/config.cc @@ -25,7 +25,7 @@ FileBasedMetadataGrpcCredentialsFactory::getChannelCredentials( for (const auto& credential : google_grpc.call_credentials()) { switch (credential.credential_specifier_case()) { case envoy::api::v2::core::GrpcService::GoogleGrpc::CallCredentials::kFromPlugin: { - if (credential.from_plugin().name() == GrpcCredentialsNames::get().FILE_BASED_METADATA) { + if (credential.from_plugin().name() == GrpcCredentialsNames::get().FileBasedMetadata) { FileBasedMetadataGrpcCredentialsFactory file_based_metadata_credentials_factory; const Envoy::ProtobufTypes::MessagePtr file_based_metadata_config_message = Envoy::Config::Utility::translateToFactoryConfig( diff --git a/source/extensions/grpc_credentials/file_based_metadata/config.h b/source/extensions/grpc_credentials/file_based_metadata/config.h index 1880d62c9ffe8..9325e0b7d3d0b 100644 --- a/source/extensions/grpc_credentials/file_based_metadata/config.h +++ b/source/extensions/grpc_credentials/file_based_metadata/config.h @@ -30,7 +30,7 @@ class FileBasedMetadataGrpcCredentialsFactory : public Grpc::GoogleGrpcCredentia return std::make_unique(); } - std::string name() const override { return GrpcCredentialsNames::get().FILE_BASED_METADATA; } + std::string name() const override { return GrpcCredentialsNames::get().FileBasedMetadata; } }; class FileBasedMetadataAuthenticator : public grpc::MetadataCredentialsPlugin { diff --git a/source/extensions/grpc_credentials/well_known_names.h b/source/extensions/grpc_credentials/well_known_names.h index 95678f1edbdc7..81ee6d22ef0b1 100644 --- a/source/extensions/grpc_credentials/well_known_names.h +++ b/source/extensions/grpc_credentials/well_known_names.h @@ -13,9 +13,9 @@ namespace GrpcCredentials { class GrpcCredentialsNameValues { public: // Access Token Example. - const std::string ACCESS_TOKEN_EXAMPLE = "envoy.grpc_credentials.access_token_example"; + const std::string AccessTokenExample = "envoy.grpc_credentials.access_token_example"; // File Based Metadata credentials - const std::string FILE_BASED_METADATA = "envoy.grpc_credentials.file_based_metadata"; + const std::string FileBasedMetadata = "envoy.grpc_credentials.file_based_metadata"; }; typedef ConstSingleton GrpcCredentialsNames; diff --git a/source/extensions/health_checkers/redis/BUILD b/source/extensions/health_checkers/redis/BUILD index a23cfe122a09f..ba83c6899e962 100644 --- a/source/extensions/health_checkers/redis/BUILD +++ b/source/extensions/health_checkers/redis/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Redis custom health checker. load( diff --git a/source/extensions/health_checkers/redis/config.cc b/source/extensions/health_checkers/redis/config.cc index 2b4ebd66f08b4..922e91940b5aa 100644 --- a/source/extensions/health_checkers/redis/config.cc +++ b/source/extensions/health_checkers/redis/config.cc @@ -14,10 +14,9 @@ namespace RedisHealthChecker { Upstream::HealthCheckerSharedPtr RedisHealthCheckerFactory::createCustomHealthChecker( const envoy::api::v2::core::HealthCheck& config, Server::Configuration::HealthCheckerFactoryContext& context) { - return std::make_shared( context.cluster(), config, getRedisHealthCheckConfig(config), context.dispatcher(), - context.runtime(), context.random(), + context.runtime(), context.random(), context.eventLogger(), NetworkFilters::RedisProxy::ConnPool::ClientFactoryImpl::instance_); }; @@ -31,4 +30,4 @@ static Registry::RegisterFactorymakeRequest(pingHealthCheckRequest(), *this); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -73,7 +74,7 @@ void RedisHealthChecker::RedisActiveHealthCheckSession::onResponse( value->asInteger() == 0) { handleSuccess(); } else { - handleFailure(FailureType::Active); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::ACTIVE); } break; case Type::Ping: @@ -81,11 +82,11 @@ void RedisHealthChecker::RedisActiveHealthCheckSession::onResponse( value->asString() == "PONG") { handleSuccess(); } else { - handleFailure(FailureType::Active); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::ACTIVE); } break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } if (!parent_.reuse_connection_) { @@ -95,7 +96,7 @@ void RedisHealthChecker::RedisActiveHealthCheckSession::onResponse( void RedisHealthChecker::RedisActiveHealthCheckSession::onFailure() { current_request_ = nullptr; - handleFailure(FailureType::Network); + handleFailure(envoy::data::core::v2alpha::HealthCheckFailureType::NETWORK); } void RedisHealthChecker::RedisActiveHealthCheckSession::onTimeout() { @@ -125,4 +126,4 @@ RedisHealthChecker::HealthCheckRequest::HealthCheckRequest() { } // namespace RedisHealthChecker } // namespace HealthCheckers } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/source/extensions/health_checkers/redis/redis.h b/source/extensions/health_checkers/redis/redis.h index 07e5567fa935d..5f3c42770d3dc 100644 --- a/source/extensions/health_checkers/redis/redis.h +++ b/source/extensions/health_checkers/redis/redis.h @@ -20,6 +20,7 @@ class RedisHealthChecker : public Upstream::HealthCheckerImplBase { const Upstream::Cluster& cluster, const envoy::api::v2::core::HealthCheck& config, const envoy::config::health_checker::redis::v2::Redis& redis_config, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, Runtime::RandomGenerator& random, + Upstream::HealthCheckEventLoggerPtr&& event_logger, Extensions::NetworkFilters::RedisProxy::ConnPool::ClientFactory& client_factory); static const Extensions::NetworkFilters::RedisProxy::RespValue& pingHealthCheckRequest() { @@ -33,6 +34,11 @@ class RedisHealthChecker : public Upstream::HealthCheckerImplBase { return request->request_; } +protected: + envoy::data::core::v2alpha::HealthCheckerType healthCheckerType() const override { + return envoy::data::core::v2alpha::HealthCheckerType::REDIS; + } + private: struct RedisActiveHealthCheckSession : public ActiveHealthCheckSession, @@ -90,4 +96,4 @@ class RedisHealthChecker : public Upstream::HealthCheckerImplBase { } // namespace RedisHealthChecker } // namespace HealthCheckers } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/source/extensions/health_checkers/redis/utility.h b/source/extensions/health_checkers/redis/utility.h index 52b95ca4899a2..f41b95de8f2e9 100644 --- a/source/extensions/health_checkers/redis/utility.h +++ b/source/extensions/health_checkers/redis/utility.h @@ -10,20 +10,8 @@ namespace RedisHealthChecker { namespace { -static const envoy::config::health_checker::redis::v2::Redis translateFromRedisHealthCheck( - const envoy::api::v2::core::HealthCheck::RedisHealthCheck& deprecated_redis_config) { - envoy::config::health_checker::redis::v2::Redis config; - config.set_key(deprecated_redis_config.key()); - return config; -} - static const envoy::config::health_checker::redis::v2::Redis getRedisHealthCheckConfig(const envoy::api::v2::core::HealthCheck& hc_config) { - // TODO(dio): redis_health_check is deprecated. - if (hc_config.has_redis_health_check()) { - return translateFromRedisHealthCheck(hc_config.redis_health_check()); - } - ProtobufTypes::MessagePtr config = ProtobufTypes::MessagePtr{new envoy::config::health_checker::redis::v2::Redis()}; MessageUtil::jsonConvert(hc_config.custom_health_check().config(), *config); diff --git a/source/extensions/health_checkers/well_known_names.h b/source/extensions/health_checkers/well_known_names.h index 9120d9cfd6b8b..26271e859b6a2 100644 --- a/source/extensions/health_checkers/well_known_names.h +++ b/source/extensions/health_checkers/well_known_names.h @@ -13,7 +13,7 @@ namespace HealthCheckers { class HealthCheckerNameValues { public: // Redis health checker. - const std::string REDIS_HEALTH_CHECKER = "envoy.health_checkers.redis"; + const std::string RedisHealthChecker = "envoy.health_checkers.redis"; }; typedef ConstSingleton HealthCheckerNames; diff --git a/source/extensions/resource_monitors/BUILD b/source/extensions/resource_monitors/BUILD new file mode 100644 index 0000000000000..6156949edef64 --- /dev/null +++ b/source/extensions/resource_monitors/BUILD @@ -0,0 +1,17 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "well_known_names", + hdrs = ["well_known_names.h"], + deps = [ + "//source/common/singleton:const_singleton", + ], +) diff --git a/source/extensions/resource_monitors/common/BUILD b/source/extensions/resource_monitors/common/BUILD new file mode 100644 index 0000000000000..ff6773aaa8d13 --- /dev/null +++ b/source/extensions/resource_monitors/common/BUILD @@ -0,0 +1,18 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "factory_base_lib", + hdrs = ["factory_base.h"], + deps = [ + "//include/envoy/server:resource_monitor_config_interface", + "//source/common/protobuf:utility_lib", + ], +) diff --git a/source/extensions/resource_monitors/common/factory_base.h b/source/extensions/resource_monitors/common/factory_base.h new file mode 100644 index 0000000000000..776f4c486bd7e --- /dev/null +++ b/source/extensions/resource_monitors/common/factory_base.h @@ -0,0 +1,42 @@ +#pragma once + +#include "envoy/server/resource_monitor_config.h" + +#include "common/protobuf/utility.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace Common { + +template +class FactoryBase : public Server::Configuration::ResourceMonitorFactory { +public: + Server::ResourceMonitorPtr + createResourceMonitor(const Protobuf::Message& config, + Server::Configuration::ResourceMonitorFactoryContext& context) override { + return createResourceMonitorFromProtoTyped( + MessageUtil::downcastAndValidate(config), context); + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return name_; } + +protected: + FactoryBase(const std::string& name) : name_(name) {} + +private: + virtual Server::ResourceMonitorPtr createResourceMonitorFromProtoTyped( + const ConfigProto& config, + Server::Configuration::ResourceMonitorFactoryContext& context) PURE; + + const std::string name_; +}; + +} // namespace Common +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/resource_monitors/fixed_heap/BUILD b/source/extensions/resource_monitors/fixed_heap/BUILD new file mode 100644 index 0000000000000..f9042c54305ee --- /dev/null +++ b/source/extensions/resource_monitors/fixed_heap/BUILD @@ -0,0 +1,34 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "fixed_heap_monitor", + srcs = ["fixed_heap_monitor.cc"], + hdrs = ["fixed_heap_monitor.h"], + deps = [ + "//include/envoy/server:resource_monitor_config_interface", + "//source/common/common:assert_lib", + "//source/common/memory:stats_lib", + "@envoy_api//envoy/config/resource_monitor/fixed_heap/v2alpha:fixed_heap_cc", + ], +) + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + ":fixed_heap_monitor", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/extensions/resource_monitors:well_known_names", + "//source/extensions/resource_monitors/common:factory_base_lib", + ], +) diff --git a/source/extensions/resource_monitors/fixed_heap/config.cc b/source/extensions/resource_monitors/fixed_heap/config.cc new file mode 100644 index 0000000000000..d0313789ff014 --- /dev/null +++ b/source/extensions/resource_monitors/fixed_heap/config.cc @@ -0,0 +1,30 @@ +#include "extensions/resource_monitors/fixed_heap/config.h" + +#include "envoy/registry/registry.h" + +#include "common/protobuf/utility.h" + +#include "extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +Server::ResourceMonitorPtr FixedHeapMonitorFactory::createResourceMonitorFromProtoTyped( + const envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig& config, + Server::Configuration::ResourceMonitorFactoryContext& /*unused_context*/) { + return std::make_unique(config); +} + +/** + * Static registration for the fixed heap resource monitor factory. @see RegistryFactory. + */ +static Registry::RegisterFactory + registered_; + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/resource_monitors/fixed_heap/config.h b/source/extensions/resource_monitors/fixed_heap/config.h new file mode 100644 index 0000000000000..b429812e6bcfb --- /dev/null +++ b/source/extensions/resource_monitors/fixed_heap/config.h @@ -0,0 +1,29 @@ +#pragma once + +#include "envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.pb.validate.h" +#include "envoy/server/resource_monitor_config.h" + +#include "extensions/resource_monitors/common/factory_base.h" +#include "extensions/resource_monitors/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +class FixedHeapMonitorFactory + : public Common::FactoryBase< + envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig> { +public: + FixedHeapMonitorFactory() : FactoryBase(ResourceMonitorNames::get().FixedHeap) {} + +private: + Server::ResourceMonitorPtr createResourceMonitorFromProtoTyped( + const envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig& config, + Server::Configuration::ResourceMonitorFactoryContext& context) override; +}; + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.cc b/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.cc new file mode 100644 index 0000000000000..a968856aa04db --- /dev/null +++ b/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.cc @@ -0,0 +1,37 @@ +#include "extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h" + +#include "common/common/assert.h" +#include "common/memory/stats.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +uint64_t MemoryStatsReader::reservedHeapBytes() { return Memory::Stats::totalCurrentlyReserved(); } + +uint64_t MemoryStatsReader::unmappedHeapBytes() { return Memory::Stats::totalPageHeapUnmapped(); } + +FixedHeapMonitor::FixedHeapMonitor( + const envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig& config, + std::unique_ptr stats) + : max_heap_(config.max_heap_size_bytes()), stats_(std::move(stats)) { + ASSERT(max_heap_ > 0); +} + +void FixedHeapMonitor::updateResourceUsage(Server::ResourceMonitor::Callbacks& callbacks) { + const size_t physical = stats_->reservedHeapBytes(); + const size_t unmapped = stats_->unmappedHeapBytes(); + ASSERT(physical >= unmapped); + const size_t used = physical - unmapped; + + Server::ResourceUsage usage; + usage.resource_pressure_ = used / static_cast(max_heap_); + + callbacks.onSuccess(usage); +} + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h b/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h new file mode 100644 index 0000000000000..9bc4bb2697c0a --- /dev/null +++ b/source/extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h @@ -0,0 +1,44 @@ +#pragma once + +#include "envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.pb.validate.h" +#include "envoy/server/resource_monitor.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +/** + * Helper class for getting memory heap stats. + */ +class MemoryStatsReader { +public: + MemoryStatsReader() {} + virtual ~MemoryStatsReader() {} + + // Memory reserved for the process by the heap. + virtual uint64_t reservedHeapBytes(); + // Memory in free, unmapped pages in the page heap. + virtual uint64_t unmappedHeapBytes(); +}; + +/** + * Heap memory monitor with a statically configured maximum. + */ +class FixedHeapMonitor : public Server::ResourceMonitor { +public: + FixedHeapMonitor( + const envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig& config, + std::unique_ptr stats = std::make_unique()); + + void updateResourceUsage(Server::ResourceMonitor::Callbacks& callbacks) override; + +private: + const uint64_t max_heap_; + std::unique_ptr stats_; +}; + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/resource_monitors/well_known_names.h b/source/extensions/resource_monitors/well_known_names.h new file mode 100644 index 0000000000000..4bbd575768f18 --- /dev/null +++ b/source/extensions/resource_monitors/well_known_names.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common/singleton/const_singleton.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { + +/** + * Well-known resource monior names. + * NOTE: New resource monitors should use the well known name: envoy.resource_monitors.name. + */ +class ResourceMonitorNameValues { +public: + // Heap monitor with statically configured max. + const std::string FixedHeap = "envoy.resource_monitors.fixed_heap"; +}; + +typedef ConstSingleton ResourceMonitorNames; + +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/common/statsd/statsd.cc b/source/extensions/stat_sinks/common/statsd/statsd.cc index c802a67da673c..a026f8ab90c06 100644 --- a/source/extensions/stat_sinks/common/statsd/statsd.cc +++ b/source/extensions/stat_sinks/common/statsd/statsd.cc @@ -23,13 +23,13 @@ Writer::Writer(Network::Address::InstanceConstSharedPtr address) { fd_ = address->socket(Network::Address::SocketType::Datagram); ASSERT(fd_ != -1); - int rc = address->connect(fd_); - ASSERT(rc != -1); + const Api::SysCallResult result = address->connect(fd_); + ASSERT(result.rc_ != -1); } Writer::~Writer() { if (fd_ != -1) { - RELEASE_ASSERT(close(fd_) == 0); + RELEASE_ASSERT(close(fd_) == 0, ""); } } diff --git a/source/extensions/stat_sinks/dog_statsd/BUILD b/source/extensions/stat_sinks/dog_statsd/BUILD index 6a018243b183a..83273b1afe426 100644 --- a/source/extensions/stat_sinks/dog_statsd/BUILD +++ b/source/extensions/stat_sinks/dog_statsd/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Stats sink for the DataDog (https://www.datadoghq.com/) variant of the statsd protocol # (https://docs.datadoghq.com/developers/dogstatsd/). diff --git a/source/extensions/stat_sinks/dog_statsd/config.cc b/source/extensions/stat_sinks/dog_statsd/config.cc index 4e8a4b7d32bbe..9ae4436ba5eae 100644 --- a/source/extensions/stat_sinks/dog_statsd/config.cc +++ b/source/extensions/stat_sinks/dog_statsd/config.cc @@ -30,7 +30,7 @@ ProtobufTypes::MessagePtr DogStatsdSinkFactory::createEmptyConfigProto() { new envoy::config::metrics::v2::DogStatsdSink()); } -std::string DogStatsdSinkFactory::name() { return StatsSinkNames::get().DOG_STATSD; } +std::string DogStatsdSinkFactory::name() { return StatsSinkNames::get().DogStatsd; } /** * Static registration for the this sink factory. @see RegisterFactory. diff --git a/source/extensions/stat_sinks/hystrix/BUILD b/source/extensions/stat_sinks/hystrix/BUILD new file mode 100644 index 0000000000000..334121f54e079 --- /dev/null +++ b/source/extensions/stat_sinks/hystrix/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2 + +# Stats sink for the basic version of the hystrix protocol (https://github.com/b/hystrix_spec). + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + "//include/envoy/registry", + "//source/common/network:address_lib", + "//source/common/network:resolver_lib", + "//source/extensions/stat_sinks:well_known_names", + "//source/extensions/stat_sinks/hystrix:hystrix_lib", + "//source/server:configuration_lib", + "@envoy_api//envoy/config/metrics/v2:stats_cc", + ], +) + +envoy_cc_library( + name = "hystrix_lib", + srcs = ["hystrix.cc"], + hdrs = ["hystrix.h"], + deps = [ + "//include/envoy/server:admin_interface", + "//include/envoy/server:instance_interface", + "//include/envoy/stats:stats_interface", + "//source/common/buffer:buffer_lib", + "//source/common/common:logger_lib", + "//source/common/http:headers_lib", + ], +) diff --git a/source/extensions/stat_sinks/hystrix/config.cc b/source/extensions/stat_sinks/hystrix/config.cc new file mode 100644 index 0000000000000..59abdb0d29c36 --- /dev/null +++ b/source/extensions/stat_sinks/hystrix/config.cc @@ -0,0 +1,40 @@ +#include "extensions/stat_sinks/hystrix/config.h" + +#include "envoy/config/metrics/v2/stats.pb.h" +#include "envoy/config/metrics/v2/stats.pb.validate.h" +#include "envoy/registry/registry.h" + +#include "common/network/resolver_impl.h" + +#include "extensions/stat_sinks/hystrix/hystrix.h" +#include "extensions/stat_sinks/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +Stats::SinkPtr HystrixSinkFactory::createStatsSink(const Protobuf::Message& config, + Server::Instance& server) { + const auto& hystrix_sink = + MessageUtil::downcastAndValidate(config); + return std::make_unique(server, hystrix_sink.num_buckets()); +} + +ProtobufTypes::MessagePtr HystrixSinkFactory::createEmptyConfigProto() { + return std::unique_ptr( + new envoy::config::metrics::v2::HystrixSink()); +} + +std::string HystrixSinkFactory::name() { return StatsSinkNames::get().Hystrix; } + +/** + * Static registration for the statsd sink factory. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/hystrix/config.h b/source/extensions/stat_sinks/hystrix/config.h new file mode 100644 index 0000000000000..2f3f7c37f8783 --- /dev/null +++ b/source/extensions/stat_sinks/hystrix/config.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "envoy/server/instance.h" + +#include "server/configuration_impl.h" + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +class HystrixSinkFactory : Logger::Loggable, + public Server::Configuration::StatsSinkFactory { +public: + // StatsSinkFactory + Stats::SinkPtr createStatsSink(const Protobuf::Message& config, + Server::Instance& server) override; + + ProtobufTypes::MessagePtr createEmptyConfigProto() override; + + std::string name() override; +}; + +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/hystrix/hystrix.cc b/source/extensions/stat_sinks/hystrix/hystrix.cc new file mode 100644 index 0000000000000..2618133e0f50b --- /dev/null +++ b/source/extensions/stat_sinks/hystrix/hystrix.cc @@ -0,0 +1,375 @@ +#include "extensions/stat_sinks/hystrix/hystrix.h" + +#include +#include +#include +#include + +#include "common/buffer/buffer_impl.h" +#include "common/common/logger.h" +#include "common/http/headers.h" + +#include "absl/strings/str_cat.h" + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +const uint64_t HystrixSink::DEFAULT_NUM_BUCKETS; + +ClusterStatsCache::ClusterStatsCache(const std::string& cluster_name) + : cluster_name_(cluster_name) {} + +void ClusterStatsCache::printToStream(std::stringstream& out_str) { + const std::string cluster_name_prefix = absl::StrCat(cluster_name_, "."); + + printRollingWindow(absl::StrCat(cluster_name_prefix, "success"), success_, out_str); + printRollingWindow(absl::StrCat(cluster_name_prefix, "errors"), errors_, out_str); + printRollingWindow(absl::StrCat(cluster_name_prefix, "timeouts"), timeouts_, out_str); + printRollingWindow(absl::StrCat(cluster_name_prefix, "rejected"), rejected_, out_str); + printRollingWindow(absl::StrCat(cluster_name_prefix, "total"), total_, out_str); +} + +void ClusterStatsCache::printRollingWindow(absl::string_view name, RollingWindow rolling_window, + std::stringstream& out_str) { + out_str << name << " | "; + for (auto specific_stat_vec_itr = rolling_window.begin(); + specific_stat_vec_itr != rolling_window.end(); ++specific_stat_vec_itr) { + out_str << *specific_stat_vec_itr << " | "; + } + out_str << std::endl; +} + +// Add new value to rolling window, in place of oldest one. +void HystrixSink::pushNewValue(RollingWindow& rolling_window, uint64_t value) { + if (rolling_window.empty()) { + rolling_window.resize(window_size_, value); + } else { + rolling_window[current_index_] = value; + } +} + +uint64_t HystrixSink::getRollingValue(RollingWindow rolling_window) { + + if (rolling_window.empty()) { + return 0; + } + // If the counter was reset, the result is negative + // better return 0, will be back to normal once one rolling window passes. + if (rolling_window[current_index_] < rolling_window[(current_index_ + 1) % window_size_]) { + return 0; + } else { + return rolling_window[current_index_] - rolling_window[(current_index_ + 1) % window_size_]; + } +} + +void HystrixSink::updateRollingWindowMap(const Upstream::ClusterInfo& cluster_info, + ClusterStatsCache& cluster_stats_cache) { + const std::string cluster_name = cluster_info.name(); + Upstream::ClusterStats& cluster_stats = cluster_info.stats(); + Stats::Scope& cluster_stats_scope = cluster_info.statsScope(); + + // Combining timeouts+retries - retries are counted as separate requests + // (alternative: each request including the retries counted as 1). + uint64_t timeouts = cluster_stats.upstream_rq_timeout_.value() + + cluster_stats.upstream_rq_per_try_timeout_.value(); + + pushNewValue(cluster_stats_cache.timeouts_, timeouts); + + // Combining errors+retry errors - retries are counted as separate requests + // (alternative: each request including the retries counted as 1) + // since timeouts are 504 (or 408), deduce them from here ("-" sign). + // Timeout retries were not counted here anyway. + uint64_t errors = cluster_stats_scope.counter("upstream_rq_5xx").value() + + cluster_stats_scope.counter("retry.upstream_rq_5xx").value() + + cluster_stats_scope.counter("upstream_rq_4xx").value() + + cluster_stats_scope.counter("retry.upstream_rq_4xx").value() - + cluster_stats.upstream_rq_timeout_.value(); + + pushNewValue(cluster_stats_cache.errors_, errors); + + uint64_t success = cluster_stats_scope.counter("upstream_rq_2xx").value(); + pushNewValue(cluster_stats_cache.success_, success); + + uint64_t rejected = cluster_stats.upstream_rq_pending_overflow_.value(); + pushNewValue(cluster_stats_cache.rejected_, rejected); + + // should not take from upstream_rq_total since it is updated before its components, + // leading to wrong results such as error percentage higher than 100% + uint64_t total = errors + timeouts + success + rejected; + pushNewValue(cluster_stats_cache.total_, total); + + ENVOY_LOG(trace, "{}", printRollingWindows()); +} + +void HystrixSink::resetRollingWindow() { cluster_stats_cache_map_.clear(); } + +void HystrixSink::addStringToStream(absl::string_view key, absl::string_view value, + std::stringstream& info, bool is_first) { + std::string quoted_value = absl::StrCat("\"", value, "\""); + addInfoToStream(key, quoted_value, info, is_first); +} + +void HystrixSink::addIntToStream(absl::string_view key, uint64_t value, std::stringstream& info, + bool is_first) { + addInfoToStream(key, std::to_string(value), info, is_first); +} + +void HystrixSink::addInfoToStream(absl::string_view key, absl::string_view value, + std::stringstream& info, bool is_first) { + if (!is_first) { + info << ", "; + } + std::string added_info = absl::StrCat("\"", key, "\": ", value); + info << added_info; +} + +void HystrixSink::addHystrixCommand(ClusterStatsCache& cluster_stats_cache, + absl::string_view cluster_name, + uint64_t max_concurrent_requests, uint64_t reporting_hosts, + std::chrono::milliseconds rolling_window_ms, + std::stringstream& ss) { + + std::time_t currentTime = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + + ss << "data: {"; + addStringToStream("type", "HystrixCommand", ss, true); + addStringToStream("name", cluster_name, ss); + addStringToStream("group", "NA", ss); + addIntToStream("currentTime", static_cast(currentTime), ss); + addInfoToStream("isCircuitBreakerOpen", "false", ss); + + uint64_t errors = getRollingValue(cluster_stats_cache.errors_); + uint64_t timeouts = getRollingValue(cluster_stats_cache.timeouts_); + uint64_t rejected = getRollingValue(cluster_stats_cache.rejected_); + uint64_t total = getRollingValue(cluster_stats_cache.total_); + + uint64_t error_rate = total == 0 ? 0 : (100 * (errors + timeouts + rejected)) / total; + + addIntToStream("errorPercentage", error_rate, ss); + addIntToStream("errorCount", errors, ss); + addIntToStream("requestCount", total, ss); + addIntToStream("rollingCountCollapsedRequests", 0, ss); + addIntToStream("rollingCountExceptionsThrown", 0, ss); + addIntToStream("rollingCountFailure", errors, ss); + addIntToStream("rollingCountFallbackFailure", 0, ss); + addIntToStream("rollingCountFallbackRejection", 0, ss); + addIntToStream("rollingCountFallbackSuccess", 0, ss); + addIntToStream("rollingCountResponsesFromCache", 0, ss); + + // Envoy's "circuit breaker" has similar meaning to hystrix's isolation + // so we count upstream_rq_pending_overflow and present it as ss + addIntToStream("rollingCountSemaphoreRejected", rejected, ss); + + // Hystrix's short circuit is not similar to Envoy's since it is triggered by 503 responses + // there is no parallel counter in Envoy since as a result of errors (outlier detection) + // requests are not rejected, but rather the node is removed from load balancer healthy pool. + addIntToStream("rollingCountShortCircuited", 0, ss); + addIntToStream("rollingCountSuccess", getRollingValue(cluster_stats_cache.success_), ss); + addIntToStream("rollingCountThreadPoolRejected", 0, ss); + addIntToStream("rollingCountTimeout", timeouts, ss); + addIntToStream("rollingCountBadRequests", 0, ss); + addIntToStream("currentConcurrentExecutionCount", 0, ss); + addIntToStream("latencyExecute_mean", 0, ss); + + // TODO trabetti : add histogram information once available by PR #2932 + addInfoToStream( + "latencyExecute", + "{\"0\":0,\"25\":0,\"50\":0,\"75\":0,\"90\":0,\"95\":0,\"99\":0,\"99.5\":0,\"100\":0}", ss); + addIntToStream("propertyValue_circuitBreakerRequestVolumeThreshold", 0, ss); + addIntToStream("propertyValue_circuitBreakerSleepWindowInMilliseconds", 0, ss); + addIntToStream("propertyValue_circuitBreakerErrorThresholdPercentage", 0, ss); + addInfoToStream("propertyValue_circuitBreakerForceOpen", "false", ss); + addInfoToStream("propertyValue_circuitBreakerForceClosed", "true", ss); + addStringToStream("propertyValue_executionIsolationStrategy", "SEMAPHORE", ss); + addIntToStream("propertyValue_executionIsolationThreadTimeoutInMilliseconds", 0, ss); + addInfoToStream("propertyValue_executionIsolationThreadInterruptOnTimeout", "false", ss); + addIntToStream("propertyValue_executionIsolationSemaphoreMaxConcurrentRequests", + max_concurrent_requests, ss); + addIntToStream("propertyValue_fallbackIsolationSemaphoreMaxConcurrentRequests", 0, ss); + addInfoToStream("propertyValue_requestCacheEnabled", "false", ss); + addInfoToStream("propertyValue_requestLogEnabled", "true", ss); + addIntToStream("reportingHosts", reporting_hosts, ss); + addIntToStream("propertyValue_metricsRollingStatisticalWindowInMilliseconds", + rolling_window_ms.count(), ss); + + ss << "}" << std::endl << std::endl; +} + +void HystrixSink::addHystrixThreadPool(absl::string_view cluster_name, uint64_t queue_size, + uint64_t reporting_hosts, + std::chrono::milliseconds rolling_window_ms, + std::stringstream& ss) { + + ss << "data: {"; + addIntToStream("currentPoolSize", 0, ss, true); + addIntToStream("rollingMaxActiveThreads", 0, ss); + addIntToStream("currentActiveCount", 0, ss); + addIntToStream("currentCompletedTaskCount", 0, ss); + addIntToStream("propertyValue_queueSizeRejectionThreshold", queue_size, ss); + addStringToStream("type", "HystrixThreadPool", ss); + addIntToStream("reportingHosts", reporting_hosts, ss); + addIntToStream("propertyValue_metricsRollingStatisticalWindowInMilliseconds", + rolling_window_ms.count(), ss); + addStringToStream("name", cluster_name, ss); + addIntToStream("currentLargestPoolSize", 0, ss); + addIntToStream("currentCorePoolSize", 0, ss); + addIntToStream("currentQueueSize", 0, ss); + addIntToStream("currentTaskCount", 0, ss); + addIntToStream("rollingCountThreadsExecuted", 0, ss); + addIntToStream("currentMaximumPoolSize", 0, ss); + + ss << "}" << std::endl << std::endl; +} + +void HystrixSink::addClusterStatsToStream(ClusterStatsCache& cluster_stats_cache, + absl::string_view cluster_name, + uint64_t max_concurrent_requests, + uint64_t reporting_hosts, + std::chrono::milliseconds rolling_window_ms, + std::stringstream& ss) { + + addHystrixCommand(cluster_stats_cache, cluster_name, max_concurrent_requests, reporting_hosts, + rolling_window_ms, ss); + addHystrixThreadPool(cluster_name, max_concurrent_requests, reporting_hosts, rolling_window_ms, + ss); +} + +const std::string HystrixSink::printRollingWindows() { + std::stringstream out_str; + + for (auto& itr : cluster_stats_cache_map_) { + ClusterStatsCache& cluster_stats_cache = *(itr.second); + cluster_stats_cache.printToStream(out_str); + } + return out_str.str(); +} + +HystrixSink::HystrixSink(Server::Instance& server, const uint64_t num_buckets) + : server_(server), current_index_(num_buckets > 0 ? num_buckets : DEFAULT_NUM_BUCKETS), + window_size_(current_index_ + 1) { + Server::Admin& admin = server_.admin(); + ENVOY_LOG(debug, + "adding hystrix_event_stream endpoint to enable connection to hystrix dashboard"); + admin.addHandler("/hystrix_event_stream", "send hystrix event stream", + MAKE_ADMIN_HANDLER(handlerHystrixEventStream), false, false); +} + +Http::Code HystrixSink::handlerHystrixEventStream(absl::string_view, + Http::HeaderMap& response_headers, + Buffer::Instance&, + Server::AdminStream& admin_stream) { + + response_headers.insertContentType().value().setReference( + Http::Headers::get().ContentTypeValues.TextEventStream); + response_headers.insertCacheControl().value().setReference( + Http::Headers::get().CacheControlValues.NoCache); + response_headers.insertConnection().value().setReference( + Http::Headers::get().ConnectionValues.Close); + response_headers.insertAccessControlAllowHeaders().value().setReference( + AccessControlAllowHeadersValue.AllowHeadersHystrix); + response_headers.insertAccessControlAllowOrigin().value().setReference( + Http::Headers::get().AccessControlAllowOriginValue.All); + response_headers.insertNoChunks().value().setReference("0"); + + Http::StreamDecoderFilterCallbacks& stream_decoder_filter_callbacks = + admin_stream.getDecoderFilterCallbacks(); + + registerConnection(&stream_decoder_filter_callbacks); + + admin_stream.setEndStreamOnComplete(false); // set streaming + + // Separated out just so it's easier to understand + auto on_destroy_callback = [this, &stream_decoder_filter_callbacks]() { + ENVOY_LOG(debug, "stopped sending data to hystrix dashboard on port {}", + stream_decoder_filter_callbacks.connection()->remoteAddress()->asString()); + + // Unregister the callbacks from the sink so data is no longer encoded through them. + unregisterConnection(&stream_decoder_filter_callbacks); + }; + + // Add the callback to the admin_filter list of callbacks + admin_stream.addOnDestroyCallback(std::move(on_destroy_callback)); + + ENVOY_LOG(debug, "started sending data to hystrix dashboard on port {}", + stream_decoder_filter_callbacks.connection()->remoteAddress()->asString()); + return Http::Code::OK; +} + +void HystrixSink::flush(Stats::Source&) { + if (callbacks_list_.empty()) { + return; + } + incCounter(); + std::stringstream ss; + Upstream::ClusterManager::ClusterInfoMap clusters = server_.clusterManager().clusters(); + for (auto& cluster : clusters) { + Upstream::ClusterInfoConstSharedPtr cluster_info = cluster.second.get().info(); + + std::unique_ptr& cluster_stats_cache_ptr = + cluster_stats_cache_map_[cluster_info->name()]; + if (cluster_stats_cache_ptr == nullptr) { + cluster_stats_cache_ptr = std::make_unique(cluster_info->name()); + } + + // update rolling window with cluster stats + updateRollingWindowMap(*cluster_info, *cluster_stats_cache_ptr); + + // append it to stream to be sent + addClusterStatsToStream( + *cluster_stats_cache_ptr, cluster_info->name(), + cluster_info->resourceManager(Upstream::ResourcePriority::Default).pendingRequests().max(), + cluster_info->statsScope().gauge("membership_total").value(), server_.statsFlushInterval(), + ss); + } + + Buffer::OwnedImpl data; + for (auto callbacks : callbacks_list_) { + data.add(ss.str()); + callbacks->encodeData(data, false); + } + + // send keep alive ping + // TODO (@trabetti) : is it ok to send together with data? + Buffer::OwnedImpl ping_data; + for (auto callbacks : callbacks_list_) { + ping_data.add(":\n\n"); + callbacks->encodeData(ping_data, false); + } + + // check if any clusters were removed, and remove from cache + if (clusters.size() < cluster_stats_cache_map_.size()) { + for (auto it = cluster_stats_cache_map_.begin(); it != cluster_stats_cache_map_.end();) { + if (clusters.find(it->first) == clusters.end()) { + it = cluster_stats_cache_map_.erase(it); + } else { + ++it; + } + } + } +} + +void HystrixSink::registerConnection(Http::StreamDecoderFilterCallbacks* callbacks_to_register) { + callbacks_list_.emplace_back(callbacks_to_register); +} + +void HystrixSink::unregisterConnection(Http::StreamDecoderFilterCallbacks* callbacks_to_remove) { + for (auto it = callbacks_list_.begin(); it != callbacks_list_.end(); ++it) { + if ((*it)->streamId() == callbacks_to_remove->streamId()) { + callbacks_list_.erase(it); + break; + } + } + // If there are no callbacks, clear the map to avoid stale values or having to keep updating the + // map. When a new callback is assigned, the rollingWindow is initialized with current statistics + // and within RollingWindow time, the results showed in the dashboard will be reliable + if (callbacks_list_.empty()) { + resetRollingWindow(); + } +} + +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/hystrix/hystrix.h b/source/extensions/stat_sinks/hystrix/hystrix.h new file mode 100644 index 0000000000000..65e4579e3bab1 --- /dev/null +++ b/source/extensions/stat_sinks/hystrix/hystrix.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include +#include + +#include "envoy/server/admin.h" +#include "envoy/server/instance.h" +#include "envoy/stats/stats.h" + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +typedef std::vector RollingWindow; +typedef std::map RollingStatsMap; + +struct { + const std::string AllowHeadersHystrix{"Accept, Cache-Control, X-Requested-With, Last-Event-ID"}; +} AccessControlAllowHeadersValue; + +struct ClusterStatsCache { + ClusterStatsCache(const std::string& cluster_name); + + void printToStream(std::stringstream& out_str); + void printRollingWindow(absl::string_view name, RollingWindow rolling_window, + std::stringstream& out_str); + std::string cluster_name_; + + // Rolling windows + RollingWindow errors_; + RollingWindow success_; + RollingWindow total_; + RollingWindow timeouts_; + RollingWindow rejected_; +}; + +typedef std::unique_ptr ClusterStatsCachePtr; + +class HystrixSink : public Stats::Sink, public Logger::Loggable { +public: + HystrixSink(Server::Instance& server, uint64_t num_buckets); + Http::Code handlerHystrixEventStream(absl::string_view, Http::HeaderMap& response_headers, + Buffer::Instance&, Server::AdminStream& admin_stream); + void flush(Stats::Source& source) override; + void onHistogramComplete(const Stats::Histogram&, uint64_t) override{}; + + /** + * Register a new connection. + */ + void registerConnection(Http::StreamDecoderFilterCallbacks* callbacks_to_register); + + /** + * Remove registered connection. + */ + void unregisterConnection(Http::StreamDecoderFilterCallbacks* callbacks_to_remove); + + /** + * Add new value to top of rolling window, pushing out the oldest value. + */ + void pushNewValue(RollingWindow& rolling_window, uint64_t value); + + /** + * Increment pointer of next value to add to rolling window. + */ + void incCounter() { current_index_ = (current_index_ + 1) % window_size_; } + + /** + * Generate the streams to be sent to hystrix dashboard. + */ + void addClusterStatsToStream(ClusterStatsCache& cluster_stats_cache, + absl::string_view cluster_name, uint64_t max_concurrent_requests, + uint64_t reporting_hosts, + std::chrono::milliseconds rolling_window_ms, std::stringstream& ss); + + /** + * Calculate values needed to create the stream and write into the map. + */ + void updateRollingWindowMap(const Upstream::ClusterInfo& cluster_info, + ClusterStatsCache& cluster_stats_cache); + /** + * Clear map. + */ + void resetRollingWindow(); + + /** + * Return string represnting current state of the map. for DEBUG. + */ + const std::string printRollingWindows(); + + /** + * Get the statistic's value change over the rolling window time frame. + */ + uint64_t getRollingValue(RollingWindow rolling_window); + +private: + /** + * Format the given key and absl::string_view value to "key"="value", and adding to the + * stringstream. + */ + void addStringToStream(absl::string_view key, absl::string_view value, std::stringstream& info, + bool is_first = false); + + /** + * Format the given key and uint64_t value to "key"=, and adding to the + * stringstream. + */ + void addIntToStream(absl::string_view key, uint64_t value, std::stringstream& info, + bool is_first = false); + + /** + * Format the given key and value to "key"=value, and adding to the stringstream. + */ + void addInfoToStream(absl::string_view key, absl::string_view value, std::stringstream& info, + bool is_first = false); + + /** + * Generate HystrixCommand event stream. + */ + void addHystrixCommand(ClusterStatsCache& cluster_stats_cache, absl::string_view cluster_name, + uint64_t max_concurrent_requests, uint64_t reporting_hosts, + std::chrono::milliseconds rolling_window_ms, std::stringstream& ss); + + /** + * Generate HystrixThreadPool event stream. + */ + void addHystrixThreadPool(absl::string_view cluster_name, uint64_t queue_size, + uint64_t reporting_hosts, std::chrono::milliseconds rolling_window_ms, + std::stringstream& ss); + + std::vector callbacks_list_; + Server::Instance& server_; + uint64_t current_index_; + const uint64_t window_size_; + static const uint64_t DEFAULT_NUM_BUCKETS = 10; + + // Map from cluster names to a struct of all of that cluster's stat windows. + std::unordered_map cluster_stats_cache_map_; +}; + +typedef std::unique_ptr HystrixSinkPtr; + +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/stat_sinks/metrics_service/BUILD b/source/extensions/stat_sinks/metrics_service/BUILD index 4f7e60d83b6c5..82a6110311b9e 100644 --- a/source/extensions/stat_sinks/metrics_service/BUILD +++ b/source/extensions/stat_sinks/metrics_service/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Stats sink for the gRPC metrics service: api/envoy/service/metrics/v2/metrics_service.proto load( diff --git a/source/extensions/stat_sinks/metrics_service/config.cc b/source/extensions/stat_sinks/metrics_service/config.cc index 4020c35f9b459..234b46b377406 100644 --- a/source/extensions/stat_sinks/metrics_service/config.cc +++ b/source/extensions/stat_sinks/metrics_service/config.cc @@ -37,7 +37,7 @@ ProtobufTypes::MessagePtr MetricsServiceSinkFactory::createEmptyConfigProto() { std::make_unique()); } -std::string MetricsServiceSinkFactory::name() { return StatsSinkNames::get().METRICS_SERVICE; } +std::string MetricsServiceSinkFactory::name() { return StatsSinkNames::get().MetricsService; } /** * Static registration for the this sink factory. @see RegisterFactory. diff --git a/source/extensions/stat_sinks/statsd/BUILD b/source/extensions/stat_sinks/statsd/BUILD index cad3c5ab08155..23ea5f1470ad0 100644 --- a/source/extensions/stat_sinks/statsd/BUILD +++ b/source/extensions/stat_sinks/statsd/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Stats sink for the basic version of the statsd protocol (https://github.com/b/statsd_spec). load( diff --git a/source/extensions/stat_sinks/statsd/config.cc b/source/extensions/stat_sinks/statsd/config.cc index eb7fe0c3f496c..ddb28d0d0d083 100644 --- a/source/extensions/stat_sinks/statsd/config.cc +++ b/source/extensions/stat_sinks/statsd/config.cc @@ -34,7 +34,7 @@ Stats::SinkPtr StatsdSinkFactory::createStatsSink(const Protobuf::Message& confi server.clusterManager(), server.stats(), statsd_sink.prefix()); default: // Verified by schema. - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } @@ -43,7 +43,7 @@ ProtobufTypes::MessagePtr StatsdSinkFactory::createEmptyConfigProto() { new envoy::config::metrics::v2::StatsdSink()); } -std::string StatsdSinkFactory::name() { return StatsSinkNames::get().STATSD; } +std::string StatsdSinkFactory::name() { return StatsSinkNames::get().Statsd; } /** * Static registration for the statsd sink factory. @see RegisterFactory. diff --git a/source/extensions/stat_sinks/well_known_names.h b/source/extensions/stat_sinks/well_known_names.h index f524ffa4cd181..8b4b6022c8677 100644 --- a/source/extensions/stat_sinks/well_known_names.h +++ b/source/extensions/stat_sinks/well_known_names.h @@ -13,11 +13,13 @@ namespace StatSinks { class StatsSinkNameValues { public: // Statsd sink - const std::string STATSD = "envoy.statsd"; + const std::string Statsd = "envoy.statsd"; // DogStatsD compatible stastsd sink - const std::string DOG_STATSD = "envoy.dog_statsd"; + const std::string DogStatsd = "envoy.dog_statsd"; // MetricsService sink - const std::string METRICS_SERVICE = "envoy.metrics_service"; + const std::string MetricsService = "envoy.metrics_service"; + // Hystrix sink + const std::string Hystrix = "envoy.stat_sinks.hystrix"; }; typedef ConstSingleton StatsSinkNames; diff --git a/source/extensions/tracers/common/ot/opentracing_driver_impl.cc b/source/extensions/tracers/common/ot/opentracing_driver_impl.cc index 6229ea7eb226b..48e5fa9881106 100644 --- a/source/extensions/tracers/common/ot/opentracing_driver_impl.cc +++ b/source/extensions/tracers/common/ot/opentracing_driver_impl.cc @@ -54,7 +54,7 @@ class OpenTracingHTTPHeadersReader : public opentracing::HTTPHeadersReader { case Http::HeaderMap::Lookup::NotSupported: return opentracing::make_unexpected(opentracing::lookup_key_not_supported_error); } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } opentracing::expected ForeachKey(OpenTracingCb f) const override { @@ -128,7 +128,7 @@ Tracing::SpanPtr OpenTracingSpan::spawnChild(const Tracing::Config&, const std:: SystemTime start_time) { std::unique_ptr ot_span = span_->tracer().StartSpan( name, {opentracing::ChildOf(&span_->context()), opentracing::StartTimestamp(start_time)}); - RELEASE_ASSERT(ot_span != nullptr); + RELEASE_ASSERT(ot_span != nullptr, ""); return Tracing::SpanPtr{new OpenTracingSpan{driver_, std::move(ot_span)}}; } @@ -183,7 +183,7 @@ Tracing::SpanPtr OpenTracingDriver::startSpan(const Tracing::Config& config, options.tags.emplace_back(opentracing::ext::sampling_priority, 0); } active_span = tracer.StartSpanWithOptions(operation_name, options); - RELEASE_ASSERT(active_span != nullptr); + RELEASE_ASSERT(active_span != nullptr, ""); active_span->SetTag(opentracing::ext::span_kind, config.operationName() == Tracing::OperationName::Egress ? opentracing::ext::span_kind_rpc_client diff --git a/source/extensions/tracers/dynamic_ot/BUILD b/source/extensions/tracers/dynamic_ot/BUILD index 9c27fe305fb8b..7bc36009a8980 100644 --- a/source/extensions/tracers/dynamic_ot/BUILD +++ b/source/extensions/tracers/dynamic_ot/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Trace driver for dynamically loadable C++ OpenTracing drivers (http://opentracing.io/). load( diff --git a/source/extensions/tracers/dynamic_ot/config.cc b/source/extensions/tracers/dynamic_ot/config.cc index 0c94a4be4f9bc..b82f3f8e8b8c1 100644 --- a/source/extensions/tracers/dynamic_ot/config.cc +++ b/source/extensions/tracers/dynamic_ot/config.cc @@ -23,7 +23,7 @@ DynamicOpenTracingTracerFactory::createHttpTracer(const Json::Object& json_confi return std::make_unique(std::move(dynamic_driver), server.localInfo()); } -std::string DynamicOpenTracingTracerFactory::name() { return TracerNames::get().DYNAMIC_OT; } +std::string DynamicOpenTracingTracerFactory::name() { return TracerNames::get().DynamicOt; } /** * Static registration for the dynamic opentracing tracer. @see RegisterFactory. diff --git a/source/extensions/tracers/dynamic_ot/dynamic_opentracing_driver_impl.cc b/source/extensions/tracers/dynamic_ot/dynamic_opentracing_driver_impl.cc index 3829a6e66ab52..aa3254e884d08 100644 --- a/source/extensions/tracers/dynamic_ot/dynamic_opentracing_driver_impl.cc +++ b/source/extensions/tracers/dynamic_ot/dynamic_opentracing_driver_impl.cc @@ -24,7 +24,7 @@ DynamicOpenTracingDriver::DynamicOpenTracingDriver(Stats::Store& stats, const st throw EnvoyException{formatErrorMessage(tracer_maybe.error(), error_message)}; } tracer_ = std::move(*tracer_maybe); - RELEASE_ASSERT(tracer_ != nullptr); + RELEASE_ASSERT(tracer_ != nullptr, ""); } std::string DynamicOpenTracingDriver::formatErrorMessage(std::error_code error_code, diff --git a/source/extensions/tracers/lightstep/BUILD b/source/extensions/tracers/lightstep/BUILD index f0f136d3d8c18..7a30babd54671 100644 --- a/source/extensions/tracers/lightstep/BUILD +++ b/source/extensions/tracers/lightstep/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Trace driver for LightStep (https://lightstep.com/) load( diff --git a/source/extensions/tracers/lightstep/config.cc b/source/extensions/tracers/lightstep/config.cc index 22a1b83428d4f..e1b1c5a947b35 100644 --- a/source/extensions/tracers/lightstep/config.cc +++ b/source/extensions/tracers/lightstep/config.cc @@ -31,7 +31,7 @@ Tracing::HttpTracerPtr LightstepTracerFactory::createHttpTracer(const Json::Obje return std::make_unique(std::move(lightstep_driver), server.localInfo()); } -std::string LightstepTracerFactory::name() { return TracerNames::get().LIGHTSTEP; } +std::string LightstepTracerFactory::name() { return TracerNames::get().Lightstep; } /** * Static registration for the lightstep tracer. @see RegisterFactory. diff --git a/source/extensions/tracers/well_known_names.h b/source/extensions/tracers/well_known_names.h index 630649a7d0b4c..d545eecedad97 100644 --- a/source/extensions/tracers/well_known_names.h +++ b/source/extensions/tracers/well_known_names.h @@ -13,11 +13,11 @@ namespace Tracers { class TracerNameValues { public: // Lightstep tracer - const std::string LIGHTSTEP = "envoy.lightstep"; + const std::string Lightstep = "envoy.lightstep"; // Zipkin tracer - const std::string ZIPKIN = "envoy.zipkin"; + const std::string Zipkin = "envoy.zipkin"; // Dynamic tracer - const std::string DYNAMIC_OT = "envoy.dynamic.ot"; + const std::string DynamicOt = "envoy.dynamic.ot"; }; typedef ConstSingleton TracerNames; diff --git a/source/extensions/tracers/zipkin/BUILD b/source/extensions/tracers/zipkin/BUILD index 204f1382fd1fe..24f30300c0a99 100644 --- a/source/extensions/tracers/zipkin/BUILD +++ b/source/extensions/tracers/zipkin/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Trace driver for Zipkin (https://zipkin.io/). load( diff --git a/source/extensions/tracers/zipkin/config.cc b/source/extensions/tracers/zipkin/config.cc index da0562478ffe6..278aebadfbc49 100644 --- a/source/extensions/tracers/zipkin/config.cc +++ b/source/extensions/tracers/zipkin/config.cc @@ -26,7 +26,7 @@ Tracing::HttpTracerPtr ZipkinTracerFactory::createHttpTracer(const Json::Object& new Tracing::HttpTracerImpl(std::move(zipkin_driver), server.localInfo())); } -std::string ZipkinTracerFactory::name() { return TracerNames::get().ZIPKIN; } +std::string ZipkinTracerFactory::name() { return TracerNames::get().Zipkin; } /** * Static registration for the lightstep tracer. @see RegisterFactory. diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index 28cc6960e7154..da086ff2d6e2f 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -24,6 +24,21 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "tsi_frame_protector", + srcs = [ + "tsi_frame_protector.cc", + ], + hdrs = [ + "tsi_frame_protector.h", + ], + repository = "@envoy", + deps = [ + ":grpc_tsi_wrapper", + "//source/common/buffer:buffer_lib", + ], +) + envoy_cc_library( name = "tsi_handshaker", srcs = [ diff --git a/source/extensions/transport_sockets/alts/tsi_frame_protector.cc b/source/extensions/transport_sockets/alts/tsi_frame_protector.cc new file mode 100644 index 0000000000000..1cb8cc22494b1 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_frame_protector.cc @@ -0,0 +1,77 @@ +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" + +#include "common/common/assert.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +// TODO(lizan): tune size later +static constexpr uint32_t BUFFER_SIZE = 16384; + +TsiFrameProtector::TsiFrameProtector(CFrameProtectorPtr&& frame_protector) + : frame_protector_(std::move(frame_protector)) {} + +tsi_result TsiFrameProtector::protect(Buffer::Instance& input, Buffer::Instance& output) { + ASSERT(frame_protector_); + + unsigned char protected_buffer[BUFFER_SIZE]; + while (input.length() > 0) { + auto* message_bytes = reinterpret_cast(input.linearize(input.length())); + size_t protected_buffer_size = BUFFER_SIZE; + size_t processed_message_size = input.length(); + tsi_result result = + tsi_frame_protector_protect(frame_protector_.get(), message_bytes, &processed_message_size, + protected_buffer, &protected_buffer_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(protected_buffer, protected_buffer_size); + input.drain(processed_message_size); + } + + // TSI may buffer some of the input internally. Flush its buffer to protected_buffer. + size_t still_pending_size; + do { + size_t protected_buffer_size = BUFFER_SIZE; + tsi_result result = tsi_frame_protector_protect_flush( + frame_protector_.get(), protected_buffer, &protected_buffer_size, &still_pending_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(protected_buffer, protected_buffer_size); + } while (still_pending_size > 0); + + return TSI_OK; +} + +tsi_result TsiFrameProtector::unprotect(Buffer::Instance& input, Buffer::Instance& output) { + ASSERT(frame_protector_); + + unsigned char unprotected_buffer[BUFFER_SIZE]; + + while (input.length() > 0) { + auto* message_bytes = reinterpret_cast(input.linearize(input.length())); + size_t unprotected_buffer_size = BUFFER_SIZE; + size_t processed_message_size = input.length(); + tsi_result result = tsi_frame_protector_unprotect(frame_protector_.get(), message_bytes, + &processed_message_size, unprotected_buffer, + &unprotected_buffer_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(unprotected_buffer, unprotected_buffer_size); + input.drain(processed_message_size); + } + + return TSI_OK; +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/alts/tsi_frame_protector.h b/source/extensions/transport_sockets/alts/tsi_frame_protector.h new file mode 100644 index 0000000000000..ac2fe1fc8f7f2 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_frame_protector.h @@ -0,0 +1,49 @@ +#pragma once + +#include "envoy/buffer/buffer.h" + +#include "extensions/transport_sockets/alts/grpc_tsi.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +/** + * A C++ wrapper for tsi_frame_protector interface. + * For detail of tsi_frame_protector, see + * https://github.com/grpc/grpc/blob/v1.10.0/src/core/tsi/transport_security_interface.h#L70 + * + * TODO(lizan): migrate to tsi_zero_copy_grpc_protector for further optimization + */ +class TsiFrameProtector final { +public: + explicit TsiFrameProtector(CFrameProtectorPtr&& frame_protector); + + /** + * Wrapper for tsi_frame_protector_protect + * @param input supplies the input data to protect, the method will drain it when it is processed. + * @param output supplies the buffer where the protected data will be stored. + * @return tsi_result the status. + */ + tsi_result protect(Buffer::Instance& input, Buffer::Instance& output); + + /** + * Wrapper for tsi_frame_protector_unprotect + * @param input supplies the input data to unprotect, the method will drain it when it is + * processed. + * @param output supplies the buffer where the unprotected data will be stored. + * @return tsi_result the status. + */ + tsi_result unprotect(Buffer::Instance& input, Buffer::Instance& output); + +private: + CFrameProtectorPtr frame_protector_; +}; + +typedef std::unique_ptr TsiFrameProtectorPtr; + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/capture/BUILD b/source/extensions/transport_sockets/capture/BUILD index dc78eeb942469..88490ea7077f0 100644 --- a/source/extensions/transport_sockets/capture/BUILD +++ b/source/extensions/transport_sockets/capture/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Capture wrapper around raw_buffer sockets. load( diff --git a/source/extensions/transport_sockets/capture/config.h b/source/extensions/transport_sockets/capture/config.h index c844430a3f359..02b7ae481b740 100644 --- a/source/extensions/transport_sockets/capture/config.h +++ b/source/extensions/transport_sockets/capture/config.h @@ -17,7 +17,7 @@ class CaptureSocketConfigFactory : public virtual Server::Configuration::TransportSocketConfigFactory { public: virtual ~CaptureSocketConfigFactory() {} - std::string name() const override { return TransportSocketNames::get().CAPTURE; } + std::string name() const override { return TransportSocketNames::get().Capture; } ProtobufTypes::MessagePtr createEmptyConfigProto() override; }; diff --git a/source/extensions/transport_sockets/raw_buffer/BUILD b/source/extensions/transport_sockets/raw_buffer/BUILD index 3290a7c35ed92..4af3823704759 100644 --- a/source/extensions/transport_sockets/raw_buffer/BUILD +++ b/source/extensions/transport_sockets/raw_buffer/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Built-in plaintext connection transport socket. load( diff --git a/source/extensions/transport_sockets/raw_buffer/config.h b/source/extensions/transport_sockets/raw_buffer/config.h index 20f5bc06f9f20..e4f909be49a29 100644 --- a/source/extensions/transport_sockets/raw_buffer/config.h +++ b/source/extensions/transport_sockets/raw_buffer/config.h @@ -16,7 +16,7 @@ namespace RawBuffer { class RawBufferSocketFactory : public virtual Server::Configuration::TransportSocketConfigFactory { public: virtual ~RawBufferSocketFactory() {} - std::string name() const override { return TransportSocketNames::get().RAW_BUFFER; } + std::string name() const override { return TransportSocketNames::get().RawBuffer; } ProtobufTypes::MessagePtr createEmptyConfigProto() override; }; diff --git a/source/extensions/transport_sockets/ssl/BUILD b/source/extensions/transport_sockets/ssl/BUILD index b7009f0cfd1a0..5e76f162fa2bd 100644 --- a/source/extensions/transport_sockets/ssl/BUILD +++ b/source/extensions/transport_sockets/ssl/BUILD @@ -1,4 +1,5 @@ licenses(["notice"]) # Apache 2 + # Built-in TLS connection transport socket. load( diff --git a/source/extensions/transport_sockets/ssl/config.h b/source/extensions/transport_sockets/ssl/config.h index 1b5a84dea64e4..5fba4de55f6d3 100644 --- a/source/extensions/transport_sockets/ssl/config.h +++ b/source/extensions/transport_sockets/ssl/config.h @@ -16,7 +16,7 @@ namespace SslTransport { class SslSocketConfigFactory : public virtual Server::Configuration::TransportSocketConfigFactory { public: virtual ~SslSocketConfigFactory() {} - std::string name() const override { return TransportSocketNames::get().TLS; } + std::string name() const override { return TransportSocketNames::get().Tls; } }; class UpstreamSslSocketFactory : public Server::Configuration::UpstreamTransportSocketConfigFactory, diff --git a/source/extensions/transport_sockets/well_known_names.h b/source/extensions/transport_sockets/well_known_names.h index 2fd92d7cea9d0..0cf096fe30e4c 100644 --- a/source/extensions/transport_sockets/well_known_names.h +++ b/source/extensions/transport_sockets/well_known_names.h @@ -12,9 +12,9 @@ namespace TransportSockets { */ class TransportSocketNameValues { public: - const std::string CAPTURE = "envoy.transport_sockets.capture"; - const std::string RAW_BUFFER = "raw_buffer"; - const std::string TLS = "tls"; + const std::string Capture = "envoy.transport_sockets.capture"; + const std::string RawBuffer = "raw_buffer"; + const std::string Tls = "tls"; }; typedef ConstSingleton TransportSocketNames; diff --git a/source/server/BUILD b/source/server/BUILD index 6ff07ca41dd86..dd46a38892f2e 100644 --- a/source/server/BUILD +++ b/source/server/BUILD @@ -152,12 +152,26 @@ envoy_cc_library( deps = [ "//include/envoy/network:address_interface", "//include/envoy/server:options_interface", + "//include/envoy/stats:stats_interface", "//source/common/common:macros", "//source/common/common:version_lib", "//source/common/stats:stats_lib", ], ) +envoy_cc_library( + name = "overload_manager_lib", + srcs = ["overload_manager_impl.cc"], + hdrs = ["overload_manager_impl.h"], + deps = [ + "//include/envoy/server:overload_manager_interface", + "//source/common/common:logger_lib", + "//source/common/config:utility_lib", + "//source/server:resource_monitor_config_lib", + "@envoy_api//envoy/config/overload/v2alpha:overload_cc", + ], +) + envoy_cc_library( name = "lds_api_lib", srcs = ["lds_api.cc"], @@ -208,6 +222,8 @@ envoy_cc_library( "//source/common/api:os_sys_calls_lib", "//source/common/common:empty_string", "//source/common/config:utility_lib", + "//source/common/network:cidr_range_lib", + "//source/common/network:lc_trie_lib", "//source/common/network:listen_socket_lib", "//source/common/network:resolver_lib", "//source/common/network:socket_option_factory_lib", @@ -241,6 +257,14 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "resource_monitor_config_lib", + hdrs = ["resource_monitor_config_impl.h"], + deps = [ + "//include/envoy/server:resource_monitor_config_interface", + ], +) + envoy_cc_library( name = "server_lib", srcs = ["server.cc"], diff --git a/source/server/config_validation/admin.cc b/source/server/config_validation/admin.cc index 1174c1462871e..bb86b52fd093f 100644 --- a/source/server/config_validation/admin.cc +++ b/source/server/config_validation/admin.cc @@ -9,13 +9,13 @@ bool ValidationAdmin::addHandler(const std::string&, const std::string&, Handler bool ValidationAdmin::removeHandler(const std::string&) { return false; }; -const Network::Socket& ValidationAdmin::socket() { NOT_IMPLEMENTED; }; +const Network::Socket& ValidationAdmin::socket() { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }; ConfigTracker& ValidationAdmin::getConfigTracker() { return config_tracker_; }; Http::Code ValidationAdmin::request(absl::string_view, const Http::Utility::QueryParams&, absl::string_view, Http::HeaderMap&, std::string&) { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } // namespace Server diff --git a/source/server/config_validation/dispatcher.cc b/source/server/config_validation/dispatcher.cc index 02220271ebf31..a31ca56d6dfcf 100644 --- a/source/server/config_validation/dispatcher.cc +++ b/source/server/config_validation/dispatcher.cc @@ -23,7 +23,7 @@ Network::DnsResolverSharedPtr ValidationDispatcher::createDnsResolver( Network::ListenerPtr ValidationDispatcher::createListener(Network::Socket&, Network::ListenerCallbacks&, bool, bool) { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } } // namespace Event diff --git a/source/server/config_validation/server.h b/source/server/config_validation/server.h index 23f4669872d01..e0e71373b5960 100644 --- a/source/server/config_validation/server.h +++ b/source/server/config_validation/server.h @@ -66,12 +66,12 @@ class ValidationInstance : Logger::Loggable, Network::DnsResolverSharedPtr dnsResolver() override { return dispatcher().createDnsResolver({}); } - void drainListeners() override { NOT_IMPLEMENTED; } - DrainManager& drainManager() override { NOT_IMPLEMENTED; } + void drainListeners() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + DrainManager& drainManager() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } AccessLog::AccessLogManager& accessLogManager() override { return access_log_manager_; } - void failHealthcheck(bool) override { NOT_IMPLEMENTED; } - void getParentStats(HotRestart::GetParentStatsInfo&) override { NOT_IMPLEMENTED; } - HotRestart& hotRestart() override { NOT_IMPLEMENTED; } + void failHealthcheck(bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void getParentStats(HotRestart::GetParentStatsInfo&) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + HotRestart& hotRestart() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Init::Manager& initManager() override { return init_manager_; } ListenerManager& listenerManager() override { return listener_manager_; } Secret::SecretManager& secretManager() override { return *secret_manager_; } @@ -82,12 +82,12 @@ class ValidationInstance : Logger::Loggable, } Runtime::Loader& runtime() override { return *runtime_loader_; } void shutdown() override; - void shutdownAdmin() override { NOT_IMPLEMENTED; } + void shutdownAdmin() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Singleton::Manager& singletonManager() override { return *singleton_manager_; } - bool healthCheckFailed() override { NOT_IMPLEMENTED; } + bool healthCheckFailed() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Options& options() override { return options_; } - time_t startTimeCurrentEpoch() override { NOT_IMPLEMENTED; } - time_t startTimeFirstEpoch() override { NOT_IMPLEMENTED; } + time_t startTimeCurrentEpoch() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + time_t startTimeFirstEpoch() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } Stats::Store& stats() override { return stats_store_; } Tracing::HttpTracer& httpTracer() override { return config_->httpTracer(); } ThreadLocal::Instance& threadLocal() override { return thread_local_; } diff --git a/source/server/connection_handler_impl.cc b/source/server/connection_handler_impl.cc index 854180cd4914d..f37255621d75b 100644 --- a/source/server/connection_handler_impl.cc +++ b/source/server/connection_handler_impl.cc @@ -153,7 +153,7 @@ void ConnectionHandlerImpl::ActiveSocket::continueFilterChain(bool success) { // Set default transport protocol if none of the listener filters did it. if (socket_->detectedTransportProtocol().empty()) { socket_->setDetectedTransportProtocol( - Extensions::TransportSockets::TransportSocketNames::get().RAW_BUFFER); + Extensions::TransportSockets::TransportSocketNames::get().RawBuffer); } // Create a new connection on this listener. listener_.newConnection(std::move(socket_)); diff --git a/source/server/hot_restart_impl.cc b/source/server/hot_restart_impl.cc index e841c91832448..80b4f25313075 100644 --- a/source/server/hot_restart_impl.cc +++ b/source/server/hot_restart_impl.cc @@ -40,7 +40,7 @@ static BlockMemoryHashSetOptions blockMemHashOptions(uint64_t max_stats) { SharedMemory& SharedMemory::initialize(uint64_t stats_set_size, Options& options) { Api::OsSysCalls& os_sys_calls = Api::OsSysCallsSingleton::get(); - const uint64_t entry_size = Stats::RawStatData::size(); + const uint64_t entry_size = Stats::RawStatData::structSizeWithOptions(options.statsOptions()); const uint64_t total_size = sizeof(SharedMemory) + stats_set_size; int flags = O_RDWR; @@ -60,13 +60,13 @@ SharedMemory& SharedMemory::initialize(uint64_t stats_set_size, Options& options if (options.restartEpoch() == 0) { int rc = os_sys_calls.ftruncate(shmem_fd, total_size); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } SharedMemory* shmem = reinterpret_cast( os_sys_calls.mmap(nullptr, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, shmem_fd, 0)); - RELEASE_ASSERT(shmem != MAP_FAILED); - RELEASE_ASSERT((reinterpret_cast(shmem) % alignof(decltype(shmem))) == 0); + RELEASE_ASSERT(shmem != MAP_FAILED, ""); + RELEASE_ASSERT((reinterpret_cast(shmem) % alignof(decltype(shmem))) == 0, ""); if (options.restartEpoch() == 0) { shmem->size_ = total_size; @@ -78,15 +78,15 @@ SharedMemory& SharedMemory::initialize(uint64_t stats_set_size, Options& options shmem->initializeMutex(shmem->stat_lock_); shmem->initializeMutex(shmem->init_lock_); } else { - RELEASE_ASSERT(shmem->size_ == total_size); - RELEASE_ASSERT(shmem->version_ == VERSION); - RELEASE_ASSERT(shmem->max_stats_ == options.maxStats()); - RELEASE_ASSERT(shmem->entry_size_ == entry_size); + RELEASE_ASSERT(shmem->size_ == total_size, ""); + RELEASE_ASSERT(shmem->version_ == VERSION, ""); + RELEASE_ASSERT(shmem->max_stats_ == options.maxStats(), ""); + RELEASE_ASSERT(shmem->entry_size_ == entry_size, ""); } // Stats::RawStatData must be naturally aligned for atomics to work properly. - RELEASE_ASSERT((reinterpret_cast(shmem->stats_set_data_) % alignof(RawStatDataSet)) == - 0); + RELEASE_ASSERT( + (reinterpret_cast(shmem->stats_set_data_) % alignof(RawStatDataSet)) == 0, ""); // Here we catch the case where a new Envoy starts up when the current Envoy has not yet fully // initialized. The startup logic is quite complicated, and it's not worth trying to handle this @@ -109,14 +109,16 @@ void SharedMemory::initializeMutex(pthread_mutex_t& mutex) { pthread_mutex_init(&mutex, &attribute); } -std::string SharedMemory::version(uint64_t max_num_stats, uint64_t max_stat_name_len) { +std::string SharedMemory::version(uint64_t max_num_stats, + const Stats::StatsOptions& stats_options) { return fmt::format("{}.{}.{}.{}", VERSION, sizeof(SharedMemory), max_num_stats, - max_stat_name_len); + stats_options.maxNameLength()); } HotRestartImpl::HotRestartImpl(Options& options) : options_(options), stats_set_options_(blockMemHashOptions(options.maxStats())), - shmem_(SharedMemory::initialize(RawStatDataSet::numBytes(stats_set_options_), options)), + shmem_(SharedMemory::initialize( + RawStatDataSet::numBytes(stats_set_options_, options_.statsOptions()), options_)), log_lock_(shmem_.log_lock_), access_log_lock_(shmem_.access_log_lock_), stat_lock_(shmem_.stat_lock_), init_lock_(shmem_.init_lock_) { { @@ -124,7 +126,7 @@ HotRestartImpl::HotRestartImpl(Options& options) // because it might be actively written to while we sanityCheck it. Thread::LockGuard lock(stat_lock_); stats_set_.reset(new RawStatDataSet(stats_set_options_, options.restartEpoch() == 0, - shmem_.stats_set_data_)); + shmem_.stats_set_data_, options_.statsOptions())); } my_domain_socket_ = bindDomainSocket(options.restartEpoch()); child_address_ = createDomainSocketAddress((options.restartEpoch() + 1)); @@ -136,17 +138,17 @@ HotRestartImpl::HotRestartImpl(Options& options) // If our parent ever goes away just terminate us so that we don't have to rely on ops/launching // logic killing the entire process tree. We should never exist without our parent. int rc = prctl(PR_SET_PDEATHSIG, SIGTERM); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } -Stats::RawStatData* HotRestartImpl::alloc(const std::string& name) { +Stats::RawStatData* HotRestartImpl::alloc(absl::string_view name) { // Try to find the existing slot in shared memory, otherwise allocate a new one. Thread::LockGuard lock(stat_lock_); - absl::string_view key = name; - if (key.size() > Stats::RawStatData::maxNameLength()) { - key.remove_suffix(key.size() - Stats::RawStatData::maxNameLength()); - } - auto value_created = stats_set_->insert(key); + // In production, the name is truncated in ThreadLocalStore before this + // is called. This is just a sanity check to make sure that actually happens; + // it is coded as an if/return-null to facilitate testing. + ASSERT(name.length() <= options_.statsOptions().maxNameLength()); + auto value_created = stats_set_->insert(name); Stats::RawStatData* data = value_created.first; if (data == nullptr) { return nullptr; @@ -169,7 +171,8 @@ void HotRestartImpl::free(Stats::RawStatData& data) { } bool key_removed = stats_set_->remove(data.key()); ASSERT(key_removed); - memset(static_cast(&data), 0, Stats::RawStatData::size()); + memset(static_cast(&data), 0, + Stats::RawStatData::structSizeWithOptions(options_.statsOptions())); } int HotRestartImpl::bindDomainSocket(uint64_t id) { @@ -280,7 +283,7 @@ HotRestartImpl::RpcBase* HotRestartImpl::receiveRpc(bool block) { // By default the domain socket is non blocking. If we need to block, make it blocking first. if (block) { int rc = fcntl(my_domain_socket_, F_SETFL, 0); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } iovec iov[1]; @@ -303,17 +306,17 @@ HotRestartImpl::RpcBase* HotRestartImpl::receiveRpc(bool block) { return nullptr; } - RELEASE_ASSERT(rc != -1); - RELEASE_ASSERT(message.msg_flags == 0); + RELEASE_ASSERT(rc != -1, ""); + RELEASE_ASSERT(message.msg_flags == 0, ""); // Turn non-blocking back on if we made it blocking. if (block) { int rc = fcntl(my_domain_socket_, F_SETFL, O_NONBLOCK); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } RpcBase* rpc = reinterpret_cast(&rpc_buffer_[0]); - RELEASE_ASSERT(static_cast(rc) == rpc->length_); + RELEASE_ASSERT(static_cast(rc) == rpc->length_, ""); // We should only get control data in a GetListenSocketReply. If that's the case, pull the // cloned fd out of the control data and stick it into the RPC so that higher level code does @@ -327,7 +330,7 @@ HotRestartImpl::RpcBase* HotRestartImpl::receiveRpc(bool block) { reinterpret_cast(rpc)->fd_ = *reinterpret_cast(CMSG_DATA(cmsg)); } else { - RELEASE_ASSERT(false); + RELEASE_ASSERT(false, ""); } } @@ -346,7 +349,7 @@ void HotRestartImpl::sendMessage(sockaddr_un& address, RpcBase& rpc) { message.msg_iov = iov; message.msg_iovlen = 1; int rc = sendmsg(my_domain_socket_, &message, 0); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } void HotRestartImpl::onGetListenSocket(RpcGetListenSocketRequest& rpc) { @@ -389,7 +392,7 @@ void HotRestartImpl::onGetListenSocket(RpcGetListenSocketRequest& rpc) { *reinterpret_cast(CMSG_DATA(control_message)) = reply.fd_; int rc = sendmsg(my_domain_socket_, &message, 0); - RELEASE_ASSERT(rc != -1); + RELEASE_ASSERT(rc != -1, ""); } } @@ -474,23 +477,30 @@ void HotRestartImpl::shutdown() { socket_event_.reset(); } std::string HotRestartImpl::version() { Thread::LockGuard lock(stat_lock_); - return versionHelper(shmem_.maxStats(), Stats::RawStatData::maxNameLength(), *stats_set_); + return versionHelper(shmem_.maxStats(), options_.statsOptions(), *stats_set_); } // Called from envoy --hot-restart-version -- needs to instantiate a RawStatDataSet so it // can generate the version string. std::string HotRestartImpl::hotRestartVersion(uint64_t max_num_stats, uint64_t max_stat_name_len) { - const BlockMemoryHashSetOptions options = blockMemHashOptions(max_num_stats); - const uint64_t bytes = RawStatDataSet::numBytes(options); + Stats::StatsOptionsImpl stats_options; + stats_options.max_obj_name_length_ = max_stat_name_len - stats_options.maxStatSuffixLength(); + + const BlockMemoryHashSetOptions hash_set_options = blockMemHashOptions(max_num_stats); + const uint64_t bytes = RawStatDataSet::numBytes(hash_set_options, stats_options); std::unique_ptr mem_buffer_for_dry_run_(new uint8_t[bytes]); - RawStatDataSet stats_set(options, true /* init */, mem_buffer_for_dry_run_.get()); - return versionHelper(max_num_stats, max_stat_name_len, stats_set); + RawStatDataSet stats_set(hash_set_options, true /* init */, mem_buffer_for_dry_run_.get(), + stats_options); + + return versionHelper(max_num_stats, stats_options, stats_set); } -std::string HotRestartImpl::versionHelper(uint64_t max_num_stats, uint64_t max_stat_name_len, +std::string HotRestartImpl::versionHelper(uint64_t max_num_stats, + const Stats::StatsOptions& stats_options, RawStatDataSet& stats_set) { - return SharedMemory::version(max_num_stats, max_stat_name_len) + "." + stats_set.version(); + return SharedMemory::version(max_num_stats, stats_options) + "." + + stats_set.version(stats_options); } } // namespace Server diff --git a/source/server/hot_restart_impl.h b/source/server/hot_restart_impl.h index 283840980e7d0..c24c07b109045 100644 --- a/source/server/hot_restart_impl.h +++ b/source/server/hot_restart_impl.h @@ -27,7 +27,7 @@ typedef BlockMemoryHashSet RawStatDataSet; class SharedMemory { public: static void configure(uint64_t max_num_stats, uint64_t max_stat_name_len); - static std::string version(uint64_t max_num_stats, uint64_t max_stat_name_len); + static std::string version(uint64_t max_num_stats, const Stats::StatsOptions& stats_options); // Made public for testing. static const uint64_t VERSION; @@ -139,7 +139,7 @@ class HotRestartImpl : public HotRestart, static std::string hotRestartVersion(uint64_t max_num_stats, uint64_t max_stat_name_len); // RawStatDataAllocator - Stats::RawStatData* alloc(const std::string& name) override; + Stats::RawStatData* alloc(absl::string_view name) override; void free(Stats::RawStatData& data) override; private: @@ -191,8 +191,8 @@ class HotRestartImpl : public HotRestart, template rpc_class* receiveTypedRpc() { RpcBase* base_message = receiveRpc(true); - RELEASE_ASSERT(base_message->length_ == sizeof(rpc_class)); - RELEASE_ASSERT(base_message->type_ == rpc_type); + RELEASE_ASSERT(base_message->length_ == sizeof(rpc_class), ""); + RELEASE_ASSERT(base_message->type_ == rpc_type, ""); return reinterpret_cast(base_message); } @@ -203,7 +203,7 @@ class HotRestartImpl : public HotRestart, void onSocketEvent(); RpcBase* receiveRpc(bool block); void sendMessage(sockaddr_un& address, RpcBase& rpc); - static std::string versionHelper(uint64_t max_num_stats, uint64_t max_stat_name_len, + static std::string versionHelper(uint64_t max_num_stats, const Stats::StatsOptions& stats_options, RawStatDataSet& stats_set); Options& options_; diff --git a/source/server/hot_restart_nop_impl.h b/source/server/hot_restart_nop_impl.h index 73568945fdb93..8444dea459a48 100644 --- a/source/server/hot_restart_nop_impl.h +++ b/source/server/hot_restart_nop_impl.h @@ -27,12 +27,12 @@ class HotRestartNopImpl : public Server::HotRestart { std::string version() override { return "disabled"; } Thread::BasicLockable& logLock() override { return log_lock_; } Thread::BasicLockable& accessLogLock() override { return access_log_lock_; } - Stats::RawStatDataAllocator& statsAllocator() override { return stats_allocator_; } + Stats::StatDataAllocator& statsAllocator() override { return stats_allocator_; } private: Thread::MutexBasicLockable log_lock_; Thread::MutexBasicLockable access_log_lock_; - Stats::HeapRawStatDataAllocator stats_allocator_; + Stats::HeapStatDataAllocator stats_allocator_; }; } // namespace Server diff --git a/source/server/http/BUILD b/source/server/http/BUILD index e6ca62a610d5c..375f9a96c58d4 100644 --- a/source/server/http/BUILD +++ b/source/server/http/BUILD @@ -49,11 +49,13 @@ envoy_cc_library( "//source/common/http/http1:codec_lib", "//source/common/network:listen_socket_lib", "//source/common/network:raw_buffer_socket_lib", + "//source/common/network:utility_lib", "//source/common/profiler:profiler_lib", "//source/common/router:config_lib", "//source/common/stats:stats_lib", "//source/common/upstream:host_utility_lib", "//source/extensions/access_loggers/file:file_access_log_lib", + "@envoy_api//envoy/admin/v2alpha:clusters_cc", "@envoy_api//envoy/admin/v2alpha:config_dump_cc", ], ) diff --git a/source/server/http/admin.cc b/source/server/http/admin.cc index 752ca8e9eac02..267e557ca0639 100644 --- a/source/server/http/admin.cc +++ b/source/server/http/admin.cc @@ -9,6 +9,7 @@ #include #include +#include "envoy/admin/v2alpha/clusters.pb.h" #include "envoy/admin/v2alpha/config_dump.pb.h" #include "envoy/filesystem/filesystem.h" #include "envoy/runtime/runtime.h" @@ -34,6 +35,7 @@ #include "common/http/http1/codec_impl.h" #include "common/json/json_loader.h" #include "common/network/listen_socket_impl.h" +#include "common/network/utility.h" #include "common/profiler/profiler.h" #include "common/router/config_impl.h" #include "common/stats/stats_impl.h" @@ -177,7 +179,7 @@ void AdminFilter::addOnDestroyCallback(std::function cb) { on_destroy_callbacks_.push_back(std::move(cb)); } -const Http::StreamDecoderFilterCallbacks& AdminFilter::getDecoderFilterCallbacks() const { +Http::StreamDecoderFilterCallbacks& AdminFilter::getDecoderFilterCallbacks() const { ASSERT(callbacks_ != nullptr); return *callbacks_; } @@ -258,8 +260,62 @@ void AdminImpl::addCircuitSettings(const std::string& cluster_name, const std::s resource_manager.retries().max())); } -Http::Code AdminImpl::handlerClusters(absl::string_view, Http::HeaderMap&, - Buffer::Instance& response, AdminStream&) { +void AdminImpl::writeClustersAsJson(Buffer::Instance& response) { + envoy::admin::v2alpha::Clusters clusters; + for (auto& cluster_pair : server_.clusterManager().clusters()) { + const Upstream::Cluster& cluster = cluster_pair.second.get(); + Upstream::ClusterInfoConstSharedPtr cluster_info = cluster.info(); + + envoy::admin::v2alpha::ClusterStatus& cluster_status = *clusters.add_cluster_statuses(); + cluster_status.set_name(cluster_info->name()); + + const Upstream::Outlier::Detector* outlier_detector = cluster.outlierDetector(); + if (outlier_detector != nullptr && outlier_detector->successRateEjectionThreshold() > 0.0) { + cluster_status.mutable_success_rate_ejection_threshold()->set_value( + outlier_detector->successRateEjectionThreshold()); + } + + cluster_status.set_added_via_api(cluster_info->addedViaApi()); + + for (auto& host_set : cluster.prioritySet().hostSetsPerPriority()) { + for (auto& host : host_set->hosts()) { + envoy::admin::v2alpha::HostStatus& host_status = *cluster_status.add_host_statuses(); + Network::Utility::addressToProtobufAddress(*host->address(), + *host_status.mutable_address()); + + for (const Stats::CounterSharedPtr& counter : host->counters()) { + auto& metric = (*host_status.mutable_stats())[counter->name()]; + metric.set_type(envoy::admin::v2alpha::SimpleMetric::COUNTER); + metric.set_value(counter->value()); + } + + for (const Stats::GaugeSharedPtr& gauge : host->gauges()) { + auto& metric = (*host_status.mutable_stats())[gauge->name()]; + metric.set_type(envoy::admin::v2alpha::SimpleMetric::GAUGE); + metric.set_value(gauge->value()); + } + + envoy::admin::v2alpha::HostHealthStatus& health_status = + *host_status.mutable_health_status(); + health_status.set_failed_active_health_check( + host->healthFlagGet(Upstream::Host::HealthFlag::FAILED_ACTIVE_HC)); + health_status.set_failed_outlier_check( + host->healthFlagGet(Upstream::Host::HealthFlag::FAILED_OUTLIER_CHECK)); + health_status.set_eds_health_status( + host->healthFlagGet(Upstream::Host::HealthFlag::FAILED_EDS_HEALTH) + ? envoy::api::v2::core::HealthStatus::UNHEALTHY + : envoy::api::v2::core::HealthStatus::HEALTHY); + double success_rate = host->outlierDetector().successRate(); + if (success_rate >= 0.0) { + host_status.mutable_success_rate()->set_value(success_rate); + } + } + } + } + response.add(MessageUtil::getJsonStringFromMessage(clusters, true)); // pretty-print +} + +void AdminImpl::writeClustersAsText(Buffer::Instance& response) { for (auto& cluster : server_.clusterManager().clusters()) { addOutlierInfo(cluster.second.get().info()->name(), cluster.second.get().outlierDetector(), response); @@ -309,6 +365,20 @@ Http::Code AdminImpl::handlerClusters(absl::string_view, Http::HeaderMap&, } } } +} + +Http::Code AdminImpl::handlerClusters(absl::string_view url, Http::HeaderMap& response_headers, + Buffer::Instance& response, AdminStream&) { + Http::Utility::QueryParams query_params = Http::Utility::parseQueryString(url); + auto it = query_params.find("format"); + + if (it != query_params.end() && it->second == "json") { + writeClustersAsJson(response); + response_headers.insertContentType().value().setReference( + Http::Headers::get().ContentTypeValues.Json); + } else { + writeClustersAsText(response); + } return Http::Code::OK; } @@ -320,7 +390,7 @@ Http::Code AdminImpl::handlerConfigDump(absl::string_view, Http::HeaderMap& resp auto& config_dump_map = *(dump.mutable_configs()); for (const auto& key_callback_pair : config_tracker_.getCallbacksMap()) { ProtobufTypes::MessagePtr message = key_callback_pair.second(); - RELEASE_ASSERT(message); + RELEASE_ASSERT(message, ""); ProtobufWkt::Any any_message; any_message.PackFrom(*message); config_dump_map[key_callback_pair.first] = any_message; @@ -494,7 +564,7 @@ std::string PrometheusStatsFormatter::sanitizeName(const std::string& name) { std::string PrometheusStatsFormatter::formattedTags(const std::vector& tags) { std::vector buf; for (const Stats::Tag& tag : tags) { - buf.push_back(fmt::format("{}=\"{}\"", sanitizeName(tag.name_), sanitizeName(tag.value_))); + buf.push_back(fmt::format("{}=\"{}\"", sanitizeName(tag.name_), tag.value_)); } return StringUtil::join(buf, ","); } @@ -772,7 +842,7 @@ void AdminFilter::onComplete() { Buffer::OwnedImpl response; Http::HeaderMapPtr header_map{new Http::HeaderMapImpl}; - RELEASE_ASSERT(request_headers_); + RELEASE_ASSERT(request_headers_, ""); Http::Code code = parent_.runCallback(path, *header_map, response, *this); populateFallbackResponseHeaders(code, *header_map); callbacks_->encodeHeaders(std::move(header_map), @@ -883,16 +953,19 @@ Http::Code AdminImpl::runCallback(absl::string_view path_and_query, for (const UrlHandler& handler : handlers_) { if (path_and_query.compare(0, query_index, handler.prefix_) == 0) { + found_handler = true; if (handler.mutates_server_state_) { const absl::string_view method = admin_stream.getRequestHeaders().Method()->value().getStringView(); if (method != Http::Headers::get().MethodValues.Post) { - ENVOY_LOG(warn, "admin path \"{}\" mutates state, method={} rather than POST", + ENVOY_LOG(error, "admin path \"{}\" mutates state, method={} rather than POST", handler.prefix_, method); + code = Http::Code::BadRequest; + response.add("Invalid request; POST required"); + break; } } code = handler.handler_(path_and_query, response_headers, response, admin_stream); - found_handler = true; break; } } diff --git a/source/server/http/admin.h b/source/server/http/admin.h index 328e1e89ef831..b16a58e2b2858 100644 --- a/source/server/http/admin.h +++ b/source/server/http/admin.h @@ -7,6 +7,7 @@ #include #include +#include "envoy/admin/v2alpha/clusters.pb.h" #include "envoy/http/filter.h" #include "envoy/network/filter.h" #include "envoy/network/listen_socket.h" @@ -88,7 +89,8 @@ class AdminImpl : public Admin, std::chrono::milliseconds drainTimeout() override { return std::chrono::milliseconds(100); } Http::FilterChainFactory& filterFactory() override { return *this; } bool generateRequestId() override { return false; } - const absl::optional& idleTimeout() override { return idle_timeout_; } + absl::optional idleTimeout() const override { return idle_timeout_; } + std::chrono::milliseconds streamIdleTimeout() const override { return {}; } Router::RouteConfigProvider& routeConfigProvider() override { return route_config_provider_; } const std::string& serverName() override { return Http::DefaultServerString::get(); } Http::ConnectionManagerStats& stats() override { return stats_; } @@ -149,11 +151,18 @@ class AdminImpl : public Admin, * @return TRUE if level change succeeded, FALSE otherwise. */ bool changeLogLevel(const Http::Utility::QueryParams& params); + + /** + * Helper methods for the /clusters url handler. + */ void addCircuitSettings(const std::string& cluster_name, const std::string& priority_str, Upstream::ResourceManager& resource_manager, Buffer::Instance& response); void addOutlierInfo(const std::string& cluster_name, const Upstream::Outlier::Detector* outlier_detector, Buffer::Instance& response); + void writeClustersAsJson(Buffer::Instance& response); + void writeClustersAsText(Buffer::Instance& response); + static std::string statsAsJson(const std::map& all_stats, const std::vector& all_histograms, bool show_all, bool pretty_print = false); @@ -297,7 +306,7 @@ class AdminFilter : public Http::StreamDecoderFilter, // AdminStream void setEndStreamOnComplete(bool end_stream) override { end_stream_on_complete_ = end_stream; } void addOnDestroyCallback(std::function cb) override; - const Http::StreamDecoderFilterCallbacks& getDecoderFilterCallbacks() const override; + Http::StreamDecoderFilterCallbacks& getDecoderFilterCallbacks() const override; const Http::HeaderMap& getRequestHeaders() const override; private: diff --git a/source/server/lds_api.cc b/source/server/lds_api.cc index 87c7a5dd321de..43b6f725fd00e 100644 --- a/source/server/lds_api.cc +++ b/source/server/lds_api.cc @@ -25,10 +25,10 @@ LdsApiImpl::LdsApiImpl(const envoy::api::v2::core::ConfigSource& lds_config, subscription_ = Envoy::Config::SubscriptionFactory::subscriptionFromConfigSource( lds_config, local_info.node(), dispatcher, cm, random, *scope_, - [this, &lds_config, &cm, &dispatcher, &random, - &local_info]() -> Config::Subscription* { + [this, &lds_config, &cm, &dispatcher, &random, &local_info, + &scope]() -> Config::Subscription* { return new LdsSubscription(Config::Utility::generateStats(*scope_), lds_config, cm, - dispatcher, random, local_info); + dispatcher, random, local_info, scope.statsOptions()); }, "envoy.api.v2.ListenerDiscoveryService.FetchListeners", "envoy.api.v2.ListenerDiscoveryService.StreamListeners"); diff --git a/source/server/lds_subscription.cc b/source/server/lds_subscription.cc index a825ddbff02b0..757d4c8e43f6c 100644 --- a/source/server/lds_subscription.cc +++ b/source/server/lds_subscription.cc @@ -16,10 +16,11 @@ LdsSubscription::LdsSubscription(Config::SubscriptionStats stats, const envoy::api::v2::core::ConfigSource& lds_config, Upstream::ClusterManager& cm, Event::Dispatcher& dispatcher, Runtime::RandomGenerator& random, - const LocalInfo::LocalInfo& local_info) + const LocalInfo::LocalInfo& local_info, + const Stats::StatsOptions& stats_options) : RestApiFetcher(cm, lds_config.api_config_source().cluster_names()[0], dispatcher, random, Config::Utility::apiConfigSourceRefreshDelay(lds_config.api_config_source())), - local_info_(local_info), stats_(stats) { + local_info_(local_info), stats_(stats), stats_options_(stats_options) { const auto& api_config_source = lds_config.api_config_source(); UNREFERENCED_PARAMETER(lds_config); // If we are building an LdsSubscription, the ConfigSource should be REST_LEGACY. @@ -48,7 +49,7 @@ void LdsSubscription::parseResponse(const Http::Message& response) { Protobuf::RepeatedPtrField resources; for (const Json::ObjectSharedPtr& json_listener : json_listeners) { - Config::LdsJson::translateListener(*json_listener, *resources.Add()); + Config::LdsJson::translateListener(*json_listener, *resources.Add(), stats_options_); } std::pair hash = diff --git a/source/server/lds_subscription.h b/source/server/lds_subscription.h index 2dd39d873c695..c840ed23b972d 100644 --- a/source/server/lds_subscription.h +++ b/source/server/lds_subscription.h @@ -23,7 +23,8 @@ class LdsSubscription : public Http::RestApiFetcher, LdsSubscription(Config::SubscriptionStats stats, const envoy::api::v2::core::ConfigSource& lds_config, Upstream::ClusterManager& cm, Event::Dispatcher& dispatcher, - Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info); + Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info, + const Stats::StatsOptions& stats_options); private: // Config::Subscription @@ -39,7 +40,7 @@ class LdsSubscription : public Http::RestApiFetcher, // We should never hit this at runtime, since this legacy adapter is only used by CdsApiImpl // that doesn't do dynamic modification of resources. UNREFERENCED_PARAMETER(resources); - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } // Http::RestApiFetcher @@ -51,6 +52,7 @@ class LdsSubscription : public Http::RestApiFetcher, const LocalInfo::LocalInfo& local_info_; Config::SubscriptionCallbacks* callbacks_ = nullptr; Config::SubscriptionStats stats_; + const Stats::StatsOptions& stats_options_; }; } // namespace Server diff --git a/source/server/listener_manager_impl.cc b/source/server/listener_manager_impl.cc index f27081dac3743..3e3d25bb4f185 100644 --- a/source/server/listener_manager_impl.cc +++ b/source/server/listener_manager_impl.cc @@ -154,7 +154,7 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(config, use_original_dst, false)) { auto& factory = Config::Utility::getAndCheckFactory( - Extensions::ListenerFilters::ListenerFilterNames::get().ORIGINAL_DST); + Extensions::ListenerFilters::ListenerFilterNames::get().OriginalDst); listener_filter_factories_.push_back( factory.createFilterFactoryFromProto(Envoy::ProtobufWkt::Empty(), *this)); } @@ -165,7 +165,7 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st if (PROTOBUF_GET_WRAPPED_OR_DEFAULT(config.filter_chains()[0], use_proxy_proto, false)) { auto& factory = Config::Utility::getAndCheckFactory( - Extensions::ListenerFilters::ListenerFilterNames::get().PROXY_PROTOCOL); + Extensions::ListenerFilters::ListenerFilterNames::get().ProxyProtocol); listener_filter_factories_.push_back( factory.createFilterFactoryFromProto(Envoy::ProtobufWkt::Empty(), *this)); } @@ -189,11 +189,11 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st auto transport_socket = filter_chain.transport_socket(); if (!filter_chain.has_transport_socket()) { if (filter_chain.has_tls_context()) { - transport_socket.set_name(Extensions::TransportSockets::TransportSocketNames::get().TLS); + transport_socket.set_name(Extensions::TransportSockets::TransportSocketNames::get().Tls); MessageUtil::jsonConvert(filter_chain.tls_context(), *transport_socket.mutable_config()); } else { transport_socket.set_name( - Extensions::TransportSockets::TransportSocketNames::get().RAW_BUFFER); + Extensions::TransportSockets::TransportSocketNames::get().RawBuffer); } } @@ -202,29 +202,22 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st ProtobufTypes::MessagePtr message = Config::Utility::translateToFactoryConfig(transport_socket, config_factory); - std::vector server_names; - if (!filter_chain_match.server_names().empty()) { - if (!filter_chain_match.sni_domains().empty()) { - throw EnvoyException( - fmt::format("error adding listener '{}': both \"server_names\" and the deprecated " - "\"sni_domains\" are used, please merge the list of expected server names " - "into \"server_names\" and remove \"sni_domains\"", - address_->asString())); - } - - server_names.assign(filter_chain_match.server_names().begin(), - filter_chain_match.server_names().end()); - } else if (!filter_chain_match.sni_domains().empty()) { - server_names.assign(filter_chain_match.sni_domains().begin(), - filter_chain_match.sni_domains().end()); + // Validate IP addresses. + std::vector destination_ips; + for (const auto& destination_ip : filter_chain_match.prefix_ranges()) { + const auto& cidr_range = Network::Address::CidrRange::create(destination_ip); + destination_ips.push_back(cidr_range.asString()); } + std::vector server_names(filter_chain_match.server_names().begin(), + filter_chain_match.server_names().end()); + // Reject partial wildcards, we don't match on them. for (const auto& server_name : server_names) { if (server_name.find('*') != std::string::npos && !isWildcardServerName(server_name)) { throw EnvoyException( fmt::format("error adding listener '{}': partial wildcards are not supported in " - "\"server_names\" (or the deprecated \"sni_domains\")", + "\"server_names\"", address_->asString())); } } @@ -233,7 +226,9 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st filter_chain_match.application_protocols().begin(), filter_chain_match.application_protocols().end()); - addFilterChain(server_names, filter_chain_match.transport_protocol(), application_protocols, + addFilterChain(PROTOBUF_GET_WRAPPED_OR_DEFAULT(filter_chain_match, destination_port, 0), + destination_ips, server_names, filter_chain_match.transport_protocol(), + application_protocols, config_factory.createTransportSocketFactory(*message, *this, server_names), parent_.factory_.createNetworkFilterFactoryList(filter_chain.filters(), *this)); @@ -242,10 +237,13 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st (!server_names.empty() || !application_protocols.empty())); } + // Convert DestinationIPsMap to DestinationIPsTrie for faster lookups. + convertDestinationIPsMapToTrie(); + // Automatically inject TLS Inspector if it wasn't configured explicitly and it's needed. if (need_tls_inspector) { for (const auto& filter : config.listener_filters()) { - if (filter.name() == Extensions::ListenerFilters::ListenerFilterNames::get().TLS_INSPECTOR) { + if (filter.name() == Extensions::ListenerFilters::ListenerFilterNames::get().TlsInspector) { need_tls_inspector = false; break; } @@ -260,7 +258,7 @@ ListenerImpl::ListenerImpl(const envoy::api::v2::Listener& config, const std::st auto& factory = Config::Utility::getAndCheckFactory( - Extensions::ListenerFilters::ListenerFilterNames::get().TLS_INSPECTOR); + Extensions::ListenerFilters::ListenerFilterNames::get().TlsInspector); listener_filter_factories_.push_back( factory.createFilterFactoryFromProto(Envoy::ProtobufWkt::Empty(), *this)); } @@ -274,33 +272,73 @@ ListenerImpl::~ListenerImpl() { // active. This is done here explicitly by setting a boolean and then clearing the factory // vector for clarity. initialize_canceled_ = true; - filter_chains_.clear(); + destination_ports_map_.clear(); } bool ListenerImpl::isWildcardServerName(const std::string& name) { return absl::StartsWith(name, "*."); } -void ListenerImpl::addFilterChain(const std::vector& server_names, +void ListenerImpl::addFilterChain(uint16_t destination_port, + const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, const std::vector& application_protocols, Network::TransportSocketFactoryPtr&& transport_socket_factory, std::vector filters_factory) { const auto filter_chain = std::make_shared(std::move(transport_socket_factory), std::move(filters_factory)); - // Save mappings. + addFilterChainForDestinationPorts(destination_ports_map_, destination_port, destination_ips, + server_names, transport_protocol, application_protocols, + filter_chain); +} + +void ListenerImpl::addFilterChainForDestinationPorts( + DestinationPortsMap& destination_ports_map, uint16_t destination_port, + const std::vector& destination_ips, const std::vector& server_names, + const std::string& transport_protocol, const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { + if (destination_ports_map.find(destination_port) == destination_ports_map.end()) { + destination_ports_map[destination_port] = + std::make_pair(DestinationIPsMap{}, nullptr); + } + addFilterChainForDestinationIPs(destination_ports_map[destination_port].first, destination_ips, + server_names, transport_protocol, application_protocols, + filter_chain); +} + +void ListenerImpl::addFilterChainForDestinationIPs( + DestinationIPsMap& destination_ips_map, const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { + if (destination_ips.empty()) { + addFilterChainForServerNames(destination_ips_map[EMPTY_STRING], server_names, + transport_protocol, application_protocols, filter_chain); + } else { + for (const auto& destination_ip : destination_ips) { + addFilterChainForServerNames(destination_ips_map[destination_ip], server_names, + transport_protocol, application_protocols, filter_chain); + } + } +} + +void ListenerImpl::addFilterChainForServerNames( + ServerNamesMap& server_names_map, const std::vector& server_names, + const std::string& transport_protocol, const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain) { if (server_names.empty()) { - addFilterChainForApplicationProtocols(filter_chains_[EMPTY_STRING][transport_protocol], + addFilterChainForApplicationProtocols(server_names_map[EMPTY_STRING][transport_protocol], application_protocols, filter_chain); } else { for (const auto& server_name : server_names) { if (isWildcardServerName(server_name)) { // Add mapping for the wildcard domain, i.e. ".example.com" for "*.example.com". addFilterChainForApplicationProtocols( - filter_chains_[server_name.substr(1)][transport_protocol], application_protocols, + server_names_map[server_name.substr(1)][transport_protocol], application_protocols, filter_chain); } else { - addFilterChainForApplicationProtocols(filter_chains_[server_name][transport_protocol], + addFilterChainForApplicationProtocols(server_names_map[server_name][transport_protocol], application_protocols, filter_chain); } } @@ -308,64 +346,131 @@ void ListenerImpl::addFilterChain(const std::vector& server_names, } void ListenerImpl::addFilterChainForApplicationProtocols( - std::unordered_map& transport_protocol_map, + ApplicationProtocolsMap& application_protocols_map, const std::vector& application_protocols, const Network::FilterChainSharedPtr& filter_chain) { if (application_protocols.empty()) { - transport_protocol_map[EMPTY_STRING] = filter_chain; + application_protocols_map[EMPTY_STRING] = filter_chain; } else { for (const auto& application_protocol : application_protocols) { - transport_protocol_map[application_protocol] = filter_chain; + application_protocols_map[application_protocol] = filter_chain; } } } +void ListenerImpl::convertDestinationIPsMapToTrie() { + for (auto& port : destination_ports_map_) { + auto& destination_ips_pair = port.second; + auto& destination_ips_map = destination_ips_pair.first; + std::vector>> list; + for (const auto& entry : destination_ips_map) { + std::vector subnets; + if (entry.first == EMPTY_STRING) { + if (Network::Address::ipFamilySupported(AF_INET)) { + subnets.push_back(Network::Address::CidrRange::create("0.0.0.0/0")); + } + if (Network::Address::ipFamilySupported(AF_INET6)) { + subnets.push_back(Network::Address::CidrRange::create("::/0")); + } + } else { + subnets.push_back(Network::Address::CidrRange::create(entry.first)); + } + list.push_back( + std::make_pair>( + std::make_shared(entry.second), + std::vector(subnets))); + } + destination_ips_pair.second = std::make_unique(list, true); + } +} + const Network::FilterChain* ListenerImpl::findFilterChain(const Network::ConnectionSocket& socket) const { + const auto& address = socket.localAddress(); + + // Match on destination port (only for IP addresses). + if (address->type() == Network::Address::Type::Ip) { + const auto port_match = destination_ports_map_.find(address->ip()->port()); + if (port_match != destination_ports_map_.end()) { + return findFilterChainForDestinationIP(*port_match->second.second, socket); + } + } + + // Match on catch-all port 0. + const auto port_match = destination_ports_map_.find(0); + if (port_match != destination_ports_map_.end()) { + return findFilterChainForDestinationIP(*port_match->second.second, socket); + } + + return nullptr; +} + +const Network::FilterChain* +ListenerImpl::findFilterChainForDestinationIP(const DestinationIPsTrie& destination_ips_trie, + const Network::ConnectionSocket& socket) const { + // Use invalid IP address (matching only filter chains without IP requirements) for UDS. + static const auto& fake_address = Network::Utility::parseInternetAddress("255.255.255.255"); + + auto address = socket.localAddress(); + if (address->type() != Network::Address::Type::Ip) { + address = fake_address; + } + + // Match on both: exact IP and wider CIDR ranges using LcTrie. + const auto& data = destination_ips_trie.getData(address); + if (!data.empty()) { + ASSERT(data.size() == 1); + return findFilterChainForServerName(*data.back(), socket); + } + + return nullptr; +} + +const Network::FilterChain* +ListenerImpl::findFilterChainForServerName(const ServerNamesMap& server_names_map, + const Network::ConnectionSocket& socket) const { const std::string server_name(socket.requestedServerName()); // Match on exact server name, i.e. "www.example.com" for "www.example.com". - const auto server_name_exact_match = filter_chains_.find(server_name); - if (server_name_exact_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_exact_match->second, socket); + const auto server_name_exact_match = server_names_map.find(server_name); + if (server_name_exact_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_exact_match->second, socket); } // Match on all wildcard domains, i.e. ".example.com" and ".com" for "www.example.com". size_t pos = server_name.find('.', 1); while (pos < server_name.size() - 1 && pos != std::string::npos) { const std::string wildcard = server_name.substr(pos); - const auto server_name_wildcard_match = filter_chains_.find(wildcard); - if (server_name_wildcard_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_wildcard_match->second, socket); + const auto server_name_wildcard_match = server_names_map.find(wildcard); + if (server_name_wildcard_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_wildcard_match->second, socket); } pos = server_name.find('.', pos + 1); } // Match on a filter chain without server name requirements. - const auto server_name_catchall_match = filter_chains_.find(EMPTY_STRING); - if (server_name_catchall_match != filter_chains_.end()) { - return findFilterChainForServerName(server_name_catchall_match->second, socket); + const auto server_name_catchall_match = server_names_map.find(EMPTY_STRING); + if (server_name_catchall_match != server_names_map.end()) { + return findFilterChainForTransportProtocol(server_name_catchall_match->second, socket); } return nullptr; } -const Network::FilterChain* ListenerImpl::findFilterChainForServerName( - const std::unordered_map>& - server_name_match, +const Network::FilterChain* ListenerImpl::findFilterChainForTransportProtocol( + const TransportProtocolsMap& transport_protocols_map, const Network::ConnectionSocket& socket) const { const std::string transport_protocol(socket.detectedTransportProtocol()); // Match on exact transport protocol, e.g. "tls". - const auto transport_protocol_match = server_name_match.find(transport_protocol); - if (transport_protocol_match != server_name_match.end()) { + const auto transport_protocol_match = transport_protocols_map.find(transport_protocol); + if (transport_protocol_match != transport_protocols_map.end()) { return findFilterChainForApplicationProtocols(transport_protocol_match->second, socket); } // Match on a filter chain without transport protocol requirements. - const auto any_protocol_match = server_name_match.find(EMPTY_STRING); - if (any_protocol_match != server_name_match.end()) { + const auto any_protocol_match = transport_protocols_map.find(EMPTY_STRING); + if (any_protocol_match != transport_protocols_map.end()) { return findFilterChainForApplicationProtocols(any_protocol_match->second, socket); } @@ -373,19 +478,19 @@ const Network::FilterChain* ListenerImpl::findFilterChainForServerName( } const Network::FilterChain* ListenerImpl::findFilterChainForApplicationProtocols( - const std::unordered_map& transport_protocol_match, + const ApplicationProtocolsMap& application_protocols_map, const Network::ConnectionSocket& socket) const { // Match on exact application protocol, e.g. "h2" or "http/1.1". for (const auto& application_protocol : socket.requestedApplicationProtocols()) { - const auto application_protocol_match = transport_protocol_match.find(application_protocol); - if (application_protocol_match != transport_protocol_match.end()) { + const auto application_protocol_match = application_protocols_map.find(application_protocol); + if (application_protocol_match != application_protocols_map.end()) { return application_protocol_match->second.get(); } } // Match on a filter chain without application protocol requirements. - const auto any_protocol_match = transport_protocol_match.find(EMPTY_STRING); - if (any_protocol_match != transport_protocol_match.end()) { + const auto any_protocol_match = application_protocols_map.find(EMPTY_STRING); + if (any_protocol_match != application_protocols_map.end()) { return any_protocol_match->second.get(); } diff --git a/source/server/listener_manager_impl.h b/source/server/listener_manager_impl.h index 5349e4c319ce0..14dfadf30ab58 100644 --- a/source/server/listener_manager_impl.h +++ b/source/server/listener_manager_impl.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "envoy/api/v2/listener/listener.pb.h" #include "envoy/network/filter.h" #include "envoy/server/filter_config.h" @@ -9,6 +11,8 @@ #include "envoy/server/worker.h" #include "common/common/logger.h" +#include "common/network/cidr_range.h" +#include "common/network/lc_trie.h" #include "server/init_manager_impl.h" #include "server/lds_api.h" @@ -305,36 +309,67 @@ class ListenerImpl : public Network::ListenerConfig, SystemTime last_updated_; private: - void addFilterChain(const std::vector& server_names, + typedef std::unordered_map ApplicationProtocolsMap; + typedef std::unordered_map TransportProtocolsMap; + // Both exact server names and wildcard domains are part of the same map, in which wildcard + // domains are prefixed with "." (i.e. ".example.com" for "*.example.com") to differentiate + // between exact and wildcard entries. + typedef std::unordered_map ServerNamesMap; + typedef std::unordered_map DestinationIPsMap; + typedef std::shared_ptr ServerNamesMapSharedPtr; + typedef Network::LcTrie::LcTrie DestinationIPsTrie; + typedef std::unique_ptr DestinationIPsTriePtr; + typedef std::unordered_map> + DestinationPortsMap; + + void addFilterChain(uint16_t destination_port, const std::vector& destination_ips, + const std::vector& server_names, const std::string& transport_protocol, const std::vector& application_protocols, Network::TransportSocketFactoryPtr&& transport_socket_factory, std::vector filters_factory); - void addFilterChainForApplicationProtocols( - std::unordered_map& transport_protocol_map, - const std::vector& application_protocols, - const Network::FilterChainSharedPtr& filter_chain); - const Network::FilterChain* findFilterChainForServerName( - const std::unordered_map>& - server_name_match, - const Network::ConnectionSocket& socket) const; - const Network::FilterChain* findFilterChainForApplicationProtocols( - const std::unordered_map& - transport_protocol_match, - const Network::ConnectionSocket& socket) const; + void addFilterChainForDestinationPorts(DestinationPortsMap& destination_ports_map, + uint16_t destination_port, + const std::vector& destination_ips, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForDestinationIPs(DestinationIPsMap& destination_ips_map, + const std::vector& destination_ips, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForServerNames(ServerNamesMap& server_names_map, + const std::vector& server_names, + const std::string& transport_protocol, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + void addFilterChainForApplicationProtocols(ApplicationProtocolsMap& application_protocol_map, + const std::vector& application_protocols, + const Network::FilterChainSharedPtr& filter_chain); + + void convertDestinationIPsMapToTrie(); + + const Network::FilterChain* + findFilterChainForDestinationIP(const DestinationIPsTrie& destination_ips_trie, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForServerName(const ServerNamesMap& server_names_map, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForTransportProtocol(const TransportProtocolsMap& transport_protocols_map, + const Network::ConnectionSocket& socket) const; + const Network::FilterChain* + findFilterChainForApplicationProtocols(const ApplicationProtocolsMap& application_protocols_map, + const Network::ConnectionSocket& socket) const; + static bool isWildcardServerName(const std::string& name); - // Mapping of FilterChain's configured server name and transport protocol, i.e. - // map[server_name][transport_protocol][application_protocol] => FilterChainSharedPtr - // - // For the server_name lookups, both exact server names and wildcard domains are part of the same - // map, in which wildcard domains are prefixed with "." (i.e. ".example.com" for "*.example.com") - // to differentiate between exact and wildcard entries. - std::unordered_map< - std::string, std::unordered_map< - std::string, std::unordered_map>> - filter_chains_; + // Mapping of FilterChain's configured destination ports, IPs, server names, transport protocols + // and application protocols, using structures defined above. + DestinationPortsMap destination_ports_map_; ListenerManagerImpl& parent_; Network::Address::InstanceConstSharedPtr address_; diff --git a/source/server/options_impl.cc b/source/server/options_impl.cc index 497dd31c21217..f052727ee7c74 100644 --- a/source/server/options_impl.cc +++ b/source/server/options_impl.cc @@ -57,7 +57,13 @@ OptionsImpl::OptionsImpl(int argc, const char* const* argv, TCLAP::ValueArg config_yaml( "", "config-yaml", "Inline YAML configuration, merges with the contents of --config-path", false, "", "string", cmd); - TCLAP::SwitchArg v2_config_only("", "v2-config-only", "parse config as v2 only", cmd, false); + + // Deprecated and unused. + TCLAP::SwitchArg v2_config_only("", "v2-config-only", "deprecated", cmd, true); + + TCLAP::SwitchArg allow_v1_config("", "allow-deprecated-v1-api", "allow use of legacy v1 config", + cmd, false); + TCLAP::ValueArg admin_address_path("", "admin-address-path", "Admin address path", false, "", "string", cmd); TCLAP::ValueArg local_address_ip_version("", "local-address-ip-version", @@ -177,7 +183,7 @@ OptionsImpl::OptionsImpl(int argc, const char* const* argv, concurrency_ = concurrency.getValue(); config_path_ = config_path.getValue(); config_yaml_ = config_yaml.getValue(); - v2_config_only_ = v2_config_only.getValue(); + v2_config_only_ = !allow_v1_config.getValue(); admin_address_path_ = admin_address_path.getValue(); log_path_ = log_path.getValue(); restart_epoch_ = restart_epoch.getValue(); @@ -188,13 +194,10 @@ OptionsImpl::OptionsImpl(int argc, const char* const* argv, drain_time_ = std::chrono::seconds(drain_time_s.getValue()); parent_shutdown_time_ = std::chrono::seconds(parent_shutdown_time_s.getValue()); max_stats_ = max_stats.getValue(); - max_obj_name_length_ = max_obj_name_len.getValue(); + stats_options_.max_obj_name_length_ = max_obj_name_len.getValue(); if (hot_restart_version_option.getValue()) { - Stats::RawStatData::configure(*this); - std::cerr << hot_restart_version_cb(max_stats.getValue(), - max_obj_name_len.getValue() + - Stats::RawStatData::maxStatSuffixLength(), + std::cerr << hot_restart_version_cb(max_stats.getValue(), stats_options_.maxNameLength(), !hot_restart_disabled_); throw NoServingException(); } diff --git a/source/server/options_impl.h b/source/server/options_impl.h index 201d3d8ff4c24..bbb97472ec8f0 100644 --- a/source/server/options_impl.h +++ b/source/server/options_impl.h @@ -7,6 +7,8 @@ #include "envoy/common/exception.h" #include "envoy/server/options.h" +#include "common/stats/stats_impl.h" + #include "spdlog/spdlog.h" namespace Envoy { @@ -60,9 +62,7 @@ class OptionsImpl : public Server::Options { void setServiceNodeName(const std::string& service_node) { service_node_ = service_node; } void setServiceZone(const std::string& service_zone) { service_zone_ = service_zone; } void setMaxStats(uint64_t max_stats) { max_stats_ = max_stats; } - void setMaxObjNameLength(uint64_t max_obj_name_length) { - max_obj_name_length_ = max_obj_name_length; - } + void setStatsOptions(Stats::StatsOptionsImpl stats_options) { stats_options_ = stats_options; } void setHotRestartDisabled(bool hot_restart_disabled) { hot_restart_disabled_ = hot_restart_disabled; } @@ -91,7 +91,7 @@ class OptionsImpl : public Server::Options { const std::string& serviceNodeName() const override { return service_node_; } const std::string& serviceZone() const override { return service_zone_; } uint64_t maxStats() const override { return max_stats_; } - uint64_t maxObjNameLength() const override { return max_obj_name_length_; } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } bool hotRestartDisabled() const override { return hot_restart_disabled_; } private: @@ -114,7 +114,7 @@ class OptionsImpl : public Server::Options { std::chrono::seconds parent_shutdown_time_; Server::Mode mode_; uint64_t max_stats_; - uint64_t max_obj_name_length_; + Stats::StatsOptionsImpl stats_options_; bool hot_restart_disabled_; }; diff --git a/source/server/overload_manager_impl.cc b/source/server/overload_manager_impl.cc new file mode 100644 index 0000000000000..1f9c78ddf7bf7 --- /dev/null +++ b/source/server/overload_manager_impl.cc @@ -0,0 +1,189 @@ +#include "server/overload_manager_impl.h" + +#include "common/common/fmt.h" +#include "common/config/utility.h" +#include "common/protobuf/utility.h" + +#include "server/resource_monitor_config_impl.h" + +namespace Envoy { +namespace Server { + +namespace { + +class ThresholdTriggerImpl : public OverloadAction::Trigger { +public: + ThresholdTriggerImpl(const envoy::config::overload::v2alpha::ThresholdTrigger& config) + : threshold_(config.value()) {} + + bool updateValue(double value) { + const bool fired = isFired(); + value_ = value; + return fired != isFired(); + } + + bool isFired() const { return value_.has_value() && value_ >= threshold_; } + +private: + const double threshold_; + absl::optional value_; +}; + +} // namespace + +OverloadAction::OverloadAction(const envoy::config::overload::v2alpha::OverloadAction& config) { + for (const auto& trigger_config : config.triggers()) { + TriggerPtr trigger; + + switch (trigger_config.trigger_oneof_case()) { + case envoy::config::overload::v2alpha::Trigger::kThreshold: + trigger = std::make_unique(trigger_config.threshold()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + + if (!triggers_.insert(std::make_pair(trigger_config.name(), std::move(trigger))).second) { + throw EnvoyException( + fmt::format("Duplicate trigger resource for overload action {}", config.name())); + } + } +} + +bool OverloadAction::updateResourcePressure(const std::string& name, double pressure) { + const bool active = isActive(); + + auto it = triggers_.find(name); + ASSERT(it != triggers_.end()); + if (it->second->updateValue(pressure)) { + if (it->second->isFired()) { + const auto result = fired_triggers_.insert(name); + ASSERT(result.second); + } else { + const auto result = fired_triggers_.erase(name); + ASSERT(result == 1); + } + } + + return active != isActive(); +} + +bool OverloadAction::isActive() const { return !fired_triggers_.empty(); } + +OverloadManagerImpl::OverloadManagerImpl( + Event::Dispatcher& dispatcher, const envoy::config::overload::v2alpha::OverloadManager& config) + : started_(false), dispatcher_(dispatcher), + refresh_interval_( + std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(config, refresh_interval, 1000))) { + Configuration::ResourceMonitorFactoryContextImpl context(dispatcher); + for (const auto& resource : config.resource_monitors()) { + const auto& name = resource.name(); + ENVOY_LOG(debug, "Adding resource monitor for {}", name); + auto& factory = + Config::Utility::getAndCheckFactory(name); + auto config = Config::Utility::translateToFactoryConfig(resource, factory); + auto monitor = factory.createResourceMonitor(*config, context); + + auto result = resources_.emplace(std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(name, std::move(monitor), *this)); + if (!result.second) { + throw EnvoyException(fmt::format("Duplicate resource monitor {}", name)); + } + } + + for (const auto& action : config.actions()) { + const auto& name = action.name(); + ENVOY_LOG(debug, "Adding overload action {}", name); + auto result = actions_.emplace(std::piecewise_construct, std::forward_as_tuple(name), + std::forward_as_tuple(action)); + if (!result.second) { + throw EnvoyException(fmt::format("Duplicate overload action {}", name)); + } + + for (const auto& trigger : action.triggers()) { + const std::string resource = trigger.name(); + + if (resources_.find(resource) == resources_.end()) { + throw EnvoyException( + fmt::format("Unknown trigger resource {} for overload action {}", resource, name)); + } + + resource_to_actions_.insert(std::make_pair(resource, name)); + } + } +} + +void OverloadManagerImpl::start() { + ASSERT(!started_); + started_ = true; + timer_ = dispatcher_.createTimer([this]() -> void { + for (auto& resource : resources_) { + resource.second.update(); + } + + timer_->enableTimer(refresh_interval_); + }); + timer_->enableTimer(refresh_interval_); +} + +void OverloadManagerImpl::registerForAction(const std::string& action, + Event::Dispatcher& dispatcher, + OverloadActionCb callback) { + ASSERT(!started_); + + if (actions_.find(action) == actions_.end()) { + ENVOY_LOG(debug, "No overload action configured for {}.", action); + return; + } + + action_to_callbacks_.emplace(std::piecewise_construct, std::forward_as_tuple(action), + std::forward_as_tuple(dispatcher, callback)); +} + +void OverloadManagerImpl::updateResourcePressure(const std::string& resource, double pressure) { + auto action_range = resource_to_actions_.equal_range(resource); + std::for_each(action_range.first, action_range.second, + [&](ResourceToActionMap::value_type& entry) { + const std::string& action = entry.second; + auto action_it = actions_.find(action); + ASSERT(action_it != actions_.end()); + if (action_it->second.updateResourcePressure(resource, pressure)) { + const bool is_active = action_it->second.isActive(); + const auto state = + is_active ? OverloadActionState::Active : OverloadActionState::Inactive; + ENVOY_LOG(info, "Overload action {} has become {}", action, + is_active ? "active" : "inactive"); + auto callback_range = action_to_callbacks_.equal_range(action); + std::for_each(callback_range.first, callback_range.second, + [&](ActionToCallbackMap::value_type& cb_entry) { + auto& cb = cb_entry.second; + cb.dispatcher_.post([&, state]() { cb.callback_(state); }); + }); + } + }); +} + +void OverloadManagerImpl::Resource::update() { + if (!pending_update_) { + pending_update_ = true; + monitor_->updateResourceUsage(*this); + return; + } + ENVOY_LOG(debug, "Skipping update for resource {} which has pending update", name_); + // TODO(eziskind) add stat +} + +void OverloadManagerImpl::Resource::onSuccess(const ResourceUsage& usage) { + pending_update_ = false; + manager_.updateResourcePressure(name_, usage.resource_pressure_); +} + +void OverloadManagerImpl::Resource::onFailure(const EnvoyException& error) { + pending_update_ = false; + ENVOY_LOG(info, "Failed to update resource {}: {}", name_, error.what()); + + // TODO(eziskind): add stat +} + +} // namespace Server +} // namespace Envoy diff --git a/source/server/overload_manager_impl.h b/source/server/overload_manager_impl.h new file mode 100644 index 0000000000000..439933602e6c7 --- /dev/null +++ b/source/server/overload_manager_impl.h @@ -0,0 +1,100 @@ +#pragma once + +#include +#include +#include +#include + +#include "envoy/config/overload/v2alpha/overload.pb.validate.h" +#include "envoy/event/dispatcher.h" +#include "envoy/server/overload_manager.h" +#include "envoy/server/resource_monitor.h" + +#include "common/common/logger.h" + +namespace Envoy { +namespace Server { + +class OverloadAction { +public: + OverloadAction(const envoy::config::overload::v2alpha::OverloadAction& config); + + // Updates the current pressure for the given resource and returns whether the action + // has changed state. + bool updateResourcePressure(const std::string& name, double pressure); + + // Returns whether the action is currently active or not. + bool isActive() const; + + class Trigger { + public: + virtual ~Trigger() {} + + // Updates the current value of the metric and returns whether the trigger has changed state. + virtual bool updateValue(double value) PURE; + + // Returns whether the trigger is currently fired or not. + virtual bool isFired() const PURE; + }; + typedef std::unique_ptr TriggerPtr; + +private: + std::unordered_map triggers_; + std::unordered_set fired_triggers_; +}; + +class OverloadManagerImpl : Logger::Loggable, public OverloadManager { +public: + OverloadManagerImpl(Event::Dispatcher& dispatcher, + const envoy::config::overload::v2alpha::OverloadManager& config); + + void start(); + + // Server::OverloadManager + void registerForAction(const std::string& action, Event::Dispatcher& dispatcher, + OverloadActionCb callback) override; + +private: + class Resource : public ResourceMonitor::Callbacks { + public: + Resource(const std::string& name, ResourceMonitorPtr monitor, OverloadManagerImpl& manager) + : name_(name), monitor_(std::move(monitor)), manager_(manager), pending_update_(false) {} + + // ResourceMonitor::Callbacks + void onSuccess(const ResourceUsage& usage) override; + void onFailure(const EnvoyException& error) override; + + void update(); + + private: + const std::string name_; + ResourceMonitorPtr monitor_; + OverloadManagerImpl& manager_; + bool pending_update_; + }; + + struct ActionCallback { + ActionCallback(Event::Dispatcher& dispatcher, OverloadActionCb callback) + : dispatcher_(dispatcher), callback_(callback) {} + Event::Dispatcher& dispatcher_; + OverloadActionCb callback_; + }; + + void updateResourcePressure(const std::string& resource, double pressure); + + bool started_; + Event::Dispatcher& dispatcher_; + const std::chrono::milliseconds refresh_interval_; + Event::TimerPtr timer_; + std::unordered_map resources_; + std::unordered_map actions_; + + typedef std::unordered_multimap ResourceToActionMap; + ResourceToActionMap resource_to_actions_; + + typedef std::unordered_multimap ActionToCallbackMap; + ActionToCallbackMap action_to_callbacks_; +}; + +} // namespace Server +} // namespace Envoy diff --git a/source/server/resource_monitor_config_impl.h b/source/server/resource_monitor_config_impl.h new file mode 100644 index 0000000000000..2fcfcc443907b --- /dev/null +++ b/source/server/resource_monitor_config_impl.h @@ -0,0 +1,21 @@ +#pragma once + +#include "envoy/server/resource_monitor_config.h" + +namespace Envoy { +namespace Server { +namespace Configuration { + +class ResourceMonitorFactoryContextImpl : public ResourceMonitorFactoryContext { +public: + ResourceMonitorFactoryContextImpl(Event::Dispatcher& dispatcher) : dispatcher_(dispatcher) {} + + Event::Dispatcher& dispatcher() override { return dispatcher_; } + +private: + Event::Dispatcher& dispatcher_; +}; + +} // namespace Configuration +} // namespace Server +} // namespace Envoy diff --git a/source/server/server.cc b/source/server/server.cc index 7c682ceed809f..896359fc7ea66 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -171,7 +171,7 @@ InstanceUtil::loadBootstrapConfig(envoy::config::bootstrap::v2::Bootstrap& boots throw EnvoyException("V1 config (detected) with --config-yaml is not supported"); } Json::ObjectSharedPtr config_json = Json::Factory::loadFromFile(options.configPath()); - Config::BootstrapJson::translateBootstrap(*config_json, bootstrap); + Config::BootstrapJson::translateBootstrap(*config_json, bootstrap, options.statsOptions()); MessageUtil::validate(bootstrap); return BootstrapVersion::V1; } @@ -289,7 +289,8 @@ void InstanceImpl::initialize(Options& options, bootstrap_.node(), stats(), Config::Utility::factoryForGrpcApiConfigSource(*async_client_manager_, hds_config, stats()) ->create(), - dispatcher())); + dispatcher(), runtime(), stats(), sslContextManager(), secretManager(), random(), + info_factory_, access_log_manager_)); } for (Stats::SinkPtr& sink : main_config->statsSinks()) { diff --git a/source/server/server.h b/source/server/server.h index 58ca46beb5d29..e3f05572d1bb0 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -221,6 +221,7 @@ class InstanceImpl : Logger::Loggable, public Instance { ConfigTracker::EntryOwnerPtr config_tracker_entry_; SystemTime bootstrap_config_update_time_; Grpc::AsyncClientManagerPtr async_client_manager_; + Upstream::ProdClusterInfoFactory info_factory_; Upstream::HdsDelegatePtr hds_delegate_; }; diff --git a/test/BUILD b/test/BUILD index b097ed29979af..f7044560db8fa 100644 --- a/test/BUILD +++ b/test/BUILD @@ -20,6 +20,9 @@ envoy_cc_test_library( "main.cc", "test_runner.h", ], + external_deps = [ + "abseil_symbolize", + ], deps = [ "//source/common/common:logger_lib", "//source/common/common:thread_lib", diff --git a/test/common/access_log/access_log_formatter_test.cc b/test/common/access_log/access_log_formatter_test.cc index 6cc11f9234643..44c0ddf9eda84 100644 --- a/test/common/access_log/access_log_formatter_test.cc +++ b/test/common/access_log/access_log_formatter_test.cc @@ -73,6 +73,28 @@ TEST(AccessLogFormatterTest, requestInfoFormatter) { EXPECT_EQ("-", response_duration_format.format(header, header, header, request_info)); } + { + RequestInfoFormatter ttlb_duration_format("RESPONSE_TX_DURATION"); + + absl::optional dur_upstream = std::chrono::nanoseconds(10000000); + EXPECT_CALL(request_info, firstUpstreamRxByteReceived()).WillRepeatedly(Return(dur_upstream)); + absl::optional dur_downstream = std::chrono::nanoseconds(25000000); + EXPECT_CALL(request_info, lastDownstreamTxByteSent()).WillRepeatedly(Return(dur_downstream)); + + EXPECT_EQ("15", ttlb_duration_format.format(header, header, header, request_info)); + } + + { + RequestInfoFormatter ttlb_duration_format("RESPONSE_TX_DURATION"); + + absl::optional dur_upstream; + EXPECT_CALL(request_info, firstUpstreamRxByteReceived()).WillRepeatedly(Return(dur_upstream)); + absl::optional dur_downstream; + EXPECT_CALL(request_info, lastDownstreamTxByteSent()).WillRepeatedly(Return(dur_downstream)); + + EXPECT_EQ("-", ttlb_duration_format.format(header, header, header, request_info)); + } + { RequestInfoFormatter bytes_received_format("BYTES_RECEIVED"); EXPECT_CALL(request_info, bytesReceived()).WillOnce(Return(1)); @@ -115,7 +137,7 @@ TEST(AccessLogFormatterTest, requestInfoFormatter) { { RequestInfoFormatter response_flags_format("RESPONSE_FLAGS"); - ON_CALL(request_info, getResponseFlag(RequestInfo::ResponseFlag::LocalReset)) + ON_CALL(request_info, hasResponseFlag(RequestInfo::ResponseFlag::LocalReset)) .WillByDefault(Return(true)); EXPECT_EQ("LR", response_flags_format.format(header, header, header, request_info)); } diff --git a/test/common/access_log/access_log_impl_test.cc b/test/common/access_log/access_log_impl_test.cc index 8b51a6b3d0a41..cf9b939d825cd 100644 --- a/test/common/access_log/access_log_impl_test.cc +++ b/test/common/access_log/access_log_impl_test.cc @@ -52,8 +52,8 @@ class AccessLogImplTest : public testing::Test { public: AccessLogImplTest() : file_(new Filesystem::MockFile()) { ON_CALL(context_, runtime()).WillByDefault(ReturnRef(runtime_)); - EXPECT_CALL(context_, accessLogManager()).WillOnce(ReturnRef(log_manager_)); - EXPECT_CALL(log_manager_, createAccessLog(_)).WillOnce(Return(file_)); + ON_CALL(context_, accessLogManager()).WillByDefault(ReturnRef(log_manager_)); + ON_CALL(log_manager_, createAccessLog(_)).WillByDefault(Return(file_)); ON_CALL(*file_, write(_)).WillByDefault(SaveArg<0>(&output_)); } @@ -724,6 +724,149 @@ name: envoy.file_access_log log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); } +TEST_F(AccessLogImplTest, ResponseFlagFilterAnyFlag) { + const std::string yaml = R"EOF( +name: envoy.file_access_log +filter: + response_flag_filter: {} +config: + path: /dev/null + )EOF"; + + InstanceSharedPtr log = AccessLogFactory::fromProto(parseAccessLogFromV2Yaml(yaml), context_); + + EXPECT_CALL(*file_, write(_)).Times(0); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); + + request_info_.setResponseFlag(RequestInfo::ResponseFlag::NoRouteFound); + EXPECT_CALL(*file_, write(_)); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); +} + +TEST_F(AccessLogImplTest, ResponseFlagFilterSpecificFlag) { + const std::string yaml = R"EOF( +name: envoy.file_access_log +filter: + response_flag_filter: + flags: + - UO +config: + path: /dev/null + )EOF"; + + InstanceSharedPtr log = AccessLogFactory::fromProto(parseAccessLogFromV2Yaml(yaml), context_); + + EXPECT_CALL(*file_, write(_)).Times(0); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); + + request_info_.setResponseFlag(RequestInfo::ResponseFlag::NoRouteFound); + EXPECT_CALL(*file_, write(_)).Times(0); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); + + request_info_.setResponseFlag(RequestInfo::ResponseFlag::UpstreamOverflow); + EXPECT_CALL(*file_, write(_)); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); +} + +TEST_F(AccessLogImplTest, ResponseFlagFilterSeveralFlags) { + const std::string yaml = R"EOF( +name: envoy.file_access_log +filter: + response_flag_filter: + flags: + - UO + - RL +config: + path: /dev/null + )EOF"; + + InstanceSharedPtr log = AccessLogFactory::fromProto(parseAccessLogFromV2Yaml(yaml), context_); + + EXPECT_CALL(*file_, write(_)).Times(0); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); + + request_info_.setResponseFlag(RequestInfo::ResponseFlag::NoRouteFound); + EXPECT_CALL(*file_, write(_)).Times(0); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); + + request_info_.setResponseFlag(RequestInfo::ResponseFlag::UpstreamOverflow); + EXPECT_CALL(*file_, write(_)); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info_); +} + +TEST_F(AccessLogImplTest, ResponseFlagFilterAllFlagsInPGV) { + const std::string yaml = R"EOF( +name: envoy.file_access_log +filter: + response_flag_filter: + flags: + - LH + - UH + - UT + - LR + - UR + - UF + - UC + - UO + - NR + - DI + - FI + - RL + - UAEX +config: + path: /dev/null + )EOF"; + + static_assert(RequestInfo::ResponseFlag::LastFlag == 0x1000, + "A flag has been added. Fix this code."); + + std::vector all_response_flags = { + RequestInfo::ResponseFlag::FailedLocalHealthCheck, + RequestInfo::ResponseFlag::NoHealthyUpstream, + RequestInfo::ResponseFlag::UpstreamRequestTimeout, + RequestInfo::ResponseFlag::LocalReset, + RequestInfo::ResponseFlag::UpstreamRemoteReset, + RequestInfo::ResponseFlag::UpstreamConnectionFailure, + RequestInfo::ResponseFlag::UpstreamConnectionTermination, + RequestInfo::ResponseFlag::UpstreamOverflow, + RequestInfo::ResponseFlag::NoRouteFound, + RequestInfo::ResponseFlag::DelayInjected, + RequestInfo::ResponseFlag::FaultInjected, + RequestInfo::ResponseFlag::RateLimited, + RequestInfo::ResponseFlag::UnauthorizedExternalService, + }; + + InstanceSharedPtr log = AccessLogFactory::fromProto(parseAccessLogFromV2Yaml(yaml), context_); + + for (const auto response_flag : all_response_flags) { + TestRequestInfo request_info; + request_info.setResponseFlag(response_flag); + EXPECT_CALL(*file_, write(_)); + log->log(&request_headers_, &response_headers_, &response_trailers_, request_info); + } +} + +TEST_F(AccessLogImplTest, ResponseFlagFilterUnsupportedFlag) { + const std::string yaml = R"EOF( +name: envoy.file_access_log +filter: + response_flag_filter: + flags: + - UnsupportedFlag +config: + path: /dev/null + )EOF"; + + EXPECT_THROW_WITH_MESSAGE( + AccessLogFactory::fromProto(parseAccessLogFromV2Yaml(yaml), context_), + ProtoValidationException, + "Proto constraint validation failed (AccessLogFilterValidationError.ResponseFlagFilter: " + "[\"embedded message failed validation\"] | caused by " + "ResponseFlagFilterValidationError.Flags[i]: [\"value must be in list \" [\"LH\" \"UH\" " + "\"UT\" \"LR\" \"UR\" \"UF\" \"UC\" \"UO\" \"NR\" \"DI\" \"FI\" \"RL\" \"UAEX\"]]): " + "response_flag_filter {\n flags: \"UnsupportedFlag\"\n}\n"); +} + } // namespace } // namespace AccessLog } // namespace Envoy diff --git a/test/common/access_log/test_util.h b/test/common/access_log/test_util.h index b6de049ee64ed..dc750cc900d22 100644 --- a/test/common/access_log/test_util.h +++ b/test/common/access_log/test_util.h @@ -21,17 +21,20 @@ class TestRequestInfo : public RequestInfo::RequestInfo { SystemTime startTime() const override { return start_time_; } MonotonicTime startTimeMonotonic() const override { return start_time_monotonic_; } - void addBytesReceived(uint64_t) override { NOT_IMPLEMENTED; } + void addBytesReceived(uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } uint64_t bytesReceived() const override { return 1; } absl::optional protocol() const override { return protocol_; } void protocol(Http::Protocol protocol) override { protocol_ = protocol; } absl::optional responseCode() const override { return response_code_; } - void addBytesSent(uint64_t) override { NOT_IMPLEMENTED; } + void addBytesSent(uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } uint64_t bytesSent() const override { return 2; } - - bool getResponseFlag(Envoy::RequestInfo::ResponseFlag response_flag) const override { + bool intersectResponseFlags(uint64_t response_flags) const override { + return (response_flags_ & response_flags) != 0; + } + bool hasResponseFlag(Envoy::RequestInfo::ResponseFlag response_flag) const override { return response_flags_ & response_flag; } + bool hasAnyResponseFlag() const override { return response_flags_ != 0; } void setResponseFlag(Envoy::RequestInfo::ResponseFlag response_flag) override { response_flags_ |= response_flag; } diff --git a/test/common/buffer/owned_impl_test.cc b/test/common/buffer/owned_impl_test.cc index 9977ea5112c99..9d02fc1c1c099 100644 --- a/test/common/buffer/owned_impl_test.cc +++ b/test/common/buffer/owned_impl_test.cc @@ -83,34 +83,34 @@ TEST_F(OwnedImplTest, Write) { Buffer::OwnedImpl buffer; buffer.add("example"); EXPECT_CALL(os_sys_calls, writev(_, _, _)).WillOnce(Return(7)); - int rc = buffer.write(-1); - EXPECT_EQ(7, rc); + Api::SysCallResult result = buffer.write(-1); + EXPECT_EQ(7, result.rc_); EXPECT_EQ(0, buffer.length()); buffer.add("example"); EXPECT_CALL(os_sys_calls, writev(_, _, _)).WillOnce(Return(6)); - rc = buffer.write(-1); - EXPECT_EQ(6, rc); + result = buffer.write(-1); + EXPECT_EQ(6, result.rc_); EXPECT_EQ(1, buffer.length()); EXPECT_CALL(os_sys_calls, writev(_, _, _)).WillOnce(Return(0)); - rc = buffer.write(-1); - EXPECT_EQ(0, rc); + result = buffer.write(-1); + EXPECT_EQ(0, result.rc_); EXPECT_EQ(1, buffer.length()); EXPECT_CALL(os_sys_calls, writev(_, _, _)).WillOnce(Return(-1)); - rc = buffer.write(-1); - EXPECT_EQ(-1, rc); + result = buffer.write(-1); + EXPECT_EQ(-1, result.rc_); EXPECT_EQ(1, buffer.length()); EXPECT_CALL(os_sys_calls, writev(_, _, _)).WillOnce(Return(1)); - rc = buffer.write(-1); - EXPECT_EQ(1, rc); + result = buffer.write(-1); + EXPECT_EQ(1, result.rc_); EXPECT_EQ(0, buffer.length()); EXPECT_CALL(os_sys_calls, writev(_, _, _)).Times(0); - rc = buffer.write(-1); - EXPECT_EQ(0, rc); + result = buffer.write(-1); + EXPECT_EQ(0, result.rc_); EXPECT_EQ(0, buffer.length()); } @@ -120,18 +120,18 @@ TEST_F(OwnedImplTest, Read) { Buffer::OwnedImpl buffer; EXPECT_CALL(os_sys_calls, readv(_, _, _)).WillOnce(Return(0)); - int rc = buffer.read(-1, 100); - EXPECT_EQ(0, rc); + Api::SysCallResult result = buffer.read(-1, 100); + EXPECT_EQ(0, result.rc_); EXPECT_EQ(0, buffer.length()); EXPECT_CALL(os_sys_calls, readv(_, _, _)).WillOnce(Return(-1)); - rc = buffer.read(-1, 100); - EXPECT_EQ(-1, rc); + result = buffer.read(-1, 100); + EXPECT_EQ(-1, result.rc_); EXPECT_EQ(0, buffer.length()); EXPECT_CALL(os_sys_calls, readv(_, _, _)).Times(0); - rc = buffer.read(-1, 0); - EXPECT_EQ(0, rc); + result = buffer.read(-1, 0); + EXPECT_EQ(0, result.rc_); EXPECT_EQ(0, buffer.length()); } diff --git a/test/common/buffer/watermark_buffer_test.cc b/test/common/buffer/watermark_buffer_test.cc index 1514de59a15a6..b70f423889fae 100644 --- a/test/common/buffer/watermark_buffer_test.cc +++ b/test/common/buffer/watermark_buffer_test.cc @@ -131,11 +131,11 @@ TEST_F(WatermarkBufferTest, WatermarkFdFunctions) { int bytes_written_total = 0; while (bytes_written_total < 20) { - int bytes_written = buffer_.write(pipe_fds[1]); - if (bytes_written < 0) { - ASSERT_EQ(EAGAIN, errno); + Api::SysCallResult result = buffer_.write(pipe_fds[1]); + if (result.rc_ < 0) { + ASSERT_EQ(EAGAIN, result.errno_); } else { - bytes_written_total += bytes_written; + bytes_written_total += result.rc_; } } EXPECT_EQ(1, times_high_watermark_called_); @@ -144,7 +144,8 @@ TEST_F(WatermarkBufferTest, WatermarkFdFunctions) { int bytes_read_total = 0; while (bytes_read_total < 20) { - bytes_read_total += buffer_.read(pipe_fds[0], 20); + Api::SysCallResult result = buffer_.read(pipe_fds[0], 20); + bytes_read_total += result.rc_; } EXPECT_EQ(2, times_high_watermark_called_); EXPECT_EQ(20, buffer_.length()); diff --git a/test/common/common/BUILD b/test/common/common/BUILD index aa7f952e6fcd5..685f9970b5a2b 100644 --- a/test/common/common/BUILD +++ b/test/common/common/BUILD @@ -10,6 +10,23 @@ load( envoy_package() +envoy_cc_test( + name = "backoff_strategy_test", + srcs = ["backoff_strategy_test.cc"], + deps = [ + "//source/common/common:backoff_lib", + "//test/mocks/runtime:runtime_mocks", + ], +) + +envoy_cc_test( + name = "assert_test", + srcs = ["assert_test.cc"], + deps = [ + "//source/common/common:assert_lib", + ], +) + envoy_cc_test( name = "base64_test", srcs = ["base64_test.cc"], @@ -71,6 +88,20 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "matchers_test", + srcs = ["matchers_test.cc"], + deps = [ + "//source/common/common:matchers_lib", + "//source/common/config:metadata_lib", + "//source/common/protobuf:utility_lib", + "@envoy_api//envoy/api/v2/core:base_cc", + "@envoy_api//envoy/type/matcher:metadata_cc", + "@envoy_api//envoy/type/matcher:number_cc", + "@envoy_api//envoy/type/matcher:string_cc", + ], +) + envoy_cc_test( name = "utility_test", srcs = ["utility_test.cc"], @@ -125,8 +156,10 @@ envoy_cc_test( name = "block_memory_hash_set_test", srcs = ["block_memory_hash_set_test.cc"], deps = [ + "//include/envoy/stats:stats_interface", "//source/common/common:block_memory_hash_set_lib", "//source/common/common:hash_lib", + "//source/common/stats:stats_lib", ], ) diff --git a/test/common/common/assert_test.cc b/test/common/common/assert_test.cc new file mode 100644 index 0000000000000..49f3c8d7d748b --- /dev/null +++ b/test/common/common/assert_test.cc @@ -0,0 +1,16 @@ +#include "common/common/assert.h" + +#include "gtest/gtest.h" + +namespace Envoy { + +TEST(Assert, VariousLogs) { + Logger::StderrSinkDelegate stderr_sink(Logger::Registry::getSink()); // For coverage build. + EXPECT_DEATH({ RELEASE_ASSERT(0, ""); }, ".*assert failure: 0.*"); + EXPECT_DEATH({ RELEASE_ASSERT(0, "With some logs"); }, + ".*assert failure: 0. Details: With some logs.*"); + EXPECT_DEATH({ RELEASE_ASSERT(0 == EAGAIN, fmt::format("using {}", "fmt")); }, + ".*assert failure: 0 == EAGAIN. Details: using fmt.*"); +} + +} // namespace Envoy diff --git a/test/common/common/backoff_strategy_test.cc b/test/common/common/backoff_strategy_test.cc new file mode 100644 index 0000000000000..010768349983d --- /dev/null +++ b/test/common/common/backoff_strategy_test.cc @@ -0,0 +1,59 @@ +#include "common/common/backoff_strategy.h" + +#include "test/mocks/runtime/mocks.h" + +#include "gtest/gtest.h" + +using testing::NiceMock; +using testing::Return; + +namespace Envoy { + +TEST(BackOffStrategyTest, JitteredBackOffBasicFlow) { + NiceMock random; + ON_CALL(random, random()).WillByDefault(Return(27)); + + JitteredBackOffStrategy jittered_back_off(25, 30, random); + EXPECT_EQ(2, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(27, jittered_back_off.nextBackOffMs()); +} + +TEST(BackOffStrategyTest, JitteredBackOffBasicReset) { + NiceMock random; + ON_CALL(random, random()).WillByDefault(Return(27)); + + JitteredBackOffStrategy jittered_back_off(25, 30, random); + EXPECT_EQ(2, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(27, jittered_back_off.nextBackOffMs()); + + jittered_back_off.reset(); + EXPECT_EQ(2, jittered_back_off.nextBackOffMs()); // Should start from start +} + +TEST(BackOffStrategyTest, JitteredBackOffWithMaxInterval) { + NiceMock random; + ON_CALL(random, random()).WillByDefault(Return(1024)); + + JitteredBackOffStrategy jittered_back_off(5, 100, random); + EXPECT_EQ(4, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(4, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(9, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(49, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(94, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(94, jittered_back_off.nextBackOffMs()); // Should return Max here +} + +TEST(BackOffStrategyTest, JitteredBackOffWithMaxIntervalReset) { + NiceMock random; + ON_CALL(random, random()).WillByDefault(Return(1024)); + + JitteredBackOffStrategy jittered_back_off(5, 100, random); + EXPECT_EQ(4, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(4, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(9, jittered_back_off.nextBackOffMs()); + EXPECT_EQ(49, jittered_back_off.nextBackOffMs()); + + jittered_back_off.reset(); + EXPECT_EQ(4, jittered_back_off.nextBackOffMs()); // Should start from start +} +} // namespace Envoy diff --git a/test/common/common/block_memory_hash_set_test.cc b/test/common/common/block_memory_hash_set_test.cc index c4e96943902f4..f0082da91ad56 100644 --- a/test/common/common/block_memory_hash_set_test.cc +++ b/test/common/common/block_memory_hash_set_test.cc @@ -4,9 +4,12 @@ #include #include +#include "envoy/stats/stats.h" + #include "common/common/block_memory_hash_set.h" #include "common/common/fmt.h" #include "common/common/hash.h" +#include "common/stats/stats_impl.h" #include "absl/strings/string_view.h" #include "gtest/gtest.h" @@ -19,12 +22,15 @@ class BlockMemoryHashSetTest : public testing::Test { // TestValue that doesn't define a hash. struct TestValueBase { absl::string_view key() const { return name; } - void initialize(absl::string_view key) { - uint64_t xfer = std::min(sizeof(name) - 1, key.size()); - memcpy(name, key.data(), xfer); - name[xfer] = '\0'; + void initialize(absl::string_view key, const Stats::StatsOptions& stats_options) { + ASSERT(key.size() <= stats_options.maxNameLength()); + memcpy(name, key.data(), key.size()); + name[key.size()] = '\0'; + } + static uint64_t structSizeWithOptions(const Stats::StatsOptions& stats_options) { + UNREFERENCED_PARAMETER(stats_options); + return sizeof(TestValue); } - static uint64_t size() { return sizeof(TestValue); } int64_t number; char name[256]; @@ -43,9 +49,10 @@ class BlockMemoryHashSetTest : public testing::Test { typedef BlockMemoryHashSet::ValueCreatedPair ValueCreatedPair; template void setUp() { - options_.capacity = 100; - options_.num_slots = 5; - const uint32_t mem_size = BlockMemoryHashSet::numBytes(options_); + hash_set_options_.capacity = 100; + hash_set_options_.num_slots = 5; + const uint32_t mem_size = + BlockMemoryHashSet::numBytes(hash_set_options_, stats_options_); memory_.reset(new uint8_t[mem_size]); memset(memory_.get(), 0, mem_size); } @@ -59,10 +66,11 @@ class BlockMemoryHashSetTest : public testing::Test { std::string ret; static const uint32_t sentinal = BlockMemoryHashSet::Sentinal; std::string control_string = - fmt::format("{} size={} free_cell_index={}", hs.control_->options.toString(), + fmt::format("{} size={} free_cell_index={}", hs.control_->hash_set_options.toString(), hs.control_->size, hs.control_->free_cell_index); - ret = fmt::format("options={}\ncontrol={}\n", hs.control_->options.toString(), control_string); - for (uint32_t i = 0; i < hs.control_->options.num_slots; ++i) { + ret = fmt::format("options={}\ncontrol={}\n", hs.control_->hash_set_options.toString(), + control_string); + for (uint32_t i = 0; i < hs.control_->hash_set_options.num_slots; ++i) { ret += fmt::format("slot {}:", i); for (uint32_t j = hs.slots_[i]; j != sentinal; j = hs.getCell(j).next_cell_index) { ret += " " + std::string(hs.getCell(j).value.key()); @@ -72,23 +80,27 @@ class BlockMemoryHashSetTest : public testing::Test { return ret; } - BlockMemoryHashSetOptions options_; + BlockMemoryHashSetOptions hash_set_options_; + Stats::StatsOptionsImpl stats_options_; std::unique_ptr memory_; }; TEST_F(BlockMemoryHashSetTest, initAndAttach) { setUp(); { - BlockMemoryHashSet hash_set1(options_, true, memory_.get()); // init - BlockMemoryHashSet hash_set2(options_, false, memory_.get()); // attach + BlockMemoryHashSet hash_set1(hash_set_options_, true, memory_.get(), + stats_options_); // init + BlockMemoryHashSet hash_set2(hash_set_options_, false, memory_.get(), + stats_options_); // attach } // If we tweak an option, we can no longer attach it. bool constructor_completed = false; bool constructor_threw = false; try { - options_.capacity = 99; - BlockMemoryHashSet hash_set3(options_, false, memory_.get()); + hash_set_options_.capacity = 99; + BlockMemoryHashSet hash_set3(hash_set_options_, false, memory_.get(), + stats_options_); constructor_completed = false; } catch (const std::exception& e) { constructor_threw = true; @@ -100,7 +112,7 @@ TEST_F(BlockMemoryHashSetTest, initAndAttach) { TEST_F(BlockMemoryHashSetTest, putRemove) { setUp(); { - BlockMemoryHashSet hash_set1(options_, true, memory_.get()); + BlockMemoryHashSet hash_set1(hash_set_options_, true, memory_.get(), stats_options_); hash_set1.sanityCheck(); EXPECT_EQ(0, hash_set1.size()); EXPECT_EQ(nullptr, hash_set1.get("no such key")); @@ -121,7 +133,8 @@ TEST_F(BlockMemoryHashSetTest, putRemove) { { // Now init a new hash-map with the same memory. - BlockMemoryHashSet hash_set2(options_, false, memory_.get()); + BlockMemoryHashSet hash_set2(hash_set_options_, false, memory_.get(), + stats_options_); EXPECT_EQ(1, hash_set2.size()); EXPECT_EQ(nullptr, hash_set2.get("no such key")); EXPECT_EQ(6789, hash_set2.get("good key")->number) << hashSetToString(hash_set2); @@ -136,21 +149,21 @@ TEST_F(BlockMemoryHashSetTest, putRemove) { TEST_F(BlockMemoryHashSetTest, tooManyValues) { setUp(); - BlockMemoryHashSet hash_set1(options_, true, memory_.get()); + BlockMemoryHashSet hash_set1(hash_set_options_, true, memory_.get(), stats_options_); std::vector keys; - for (uint32_t i = 0; i < options_.capacity + 1; ++i) { + for (uint32_t i = 0; i < hash_set_options_.capacity + 1; ++i) { keys.push_back(fmt::format("key{}", i)); } - for (uint32_t i = 0; i < options_.capacity; ++i) { + for (uint32_t i = 0; i < hash_set_options_.capacity; ++i) { TestValue* value = hash_set1.insert(keys[i]).first; ASSERT_NE(nullptr, value); value->number = i; } hash_set1.sanityCheck(); - EXPECT_EQ(options_.capacity, hash_set1.size()); + EXPECT_EQ(hash_set_options_.capacity, hash_set1.size()); - for (uint32_t i = 0; i < options_.capacity; ++i) { + for (uint32_t i = 0; i < hash_set_options_.capacity; ++i) { const TestValue* value = hash_set1.get(keys[i]); ASSERT_NE(nullptr, value); EXPECT_EQ(i, value->number); @@ -158,29 +171,30 @@ TEST_F(BlockMemoryHashSetTest, tooManyValues) { hash_set1.sanityCheck(); // We can't fit one more value. - EXPECT_EQ(nullptr, hash_set1.insert(keys[options_.capacity]).first); + EXPECT_EQ(nullptr, hash_set1.insert(keys[hash_set_options_.capacity]).first); hash_set1.sanityCheck(); - EXPECT_EQ(options_.capacity, hash_set1.size()); + EXPECT_EQ(hash_set_options_.capacity, hash_set1.size()); // Now remove everything one by one. - for (uint32_t i = 0; i < options_.capacity; ++i) { + for (uint32_t i = 0; i < hash_set_options_.capacity; ++i) { EXPECT_TRUE(hash_set1.remove(keys[i])); } hash_set1.sanityCheck(); EXPECT_EQ(0, hash_set1.size()); // Now we can put in that last key we weren't able to before. - TestValue* value = hash_set1.insert(keys[options_.capacity]).first; + TestValue* value = hash_set1.insert(keys[hash_set_options_.capacity]).first; EXPECT_NE(nullptr, value); value->number = 314519; EXPECT_EQ(1, hash_set1.size()); - EXPECT_EQ(314519, hash_set1.get(keys[options_.capacity])->number); + EXPECT_EQ(314519, hash_set1.get(keys[hash_set_options_.capacity])->number); hash_set1.sanityCheck(); } TEST_F(BlockMemoryHashSetTest, severalKeysZeroHash) { setUp(); - BlockMemoryHashSet hash_set1(options_, true, memory_.get()); + BlockMemoryHashSet hash_set1(hash_set_options_, true, memory_.get(), + stats_options_); hash_set1.insert("one").first->number = 1; hash_set1.insert("two").first->number = 2; hash_set1.insert("three").first->number = 3; @@ -194,8 +208,9 @@ TEST_F(BlockMemoryHashSetTest, severalKeysZeroHash) { TEST_F(BlockMemoryHashSetTest, sanityCheckZeroedMemoryDeathTest) { setUp(); - BlockMemoryHashSet hash_set1(options_, true, memory_.get()); - memset(memory_.get(), 0, hash_set1.numBytes()); + BlockMemoryHashSet hash_set1(hash_set_options_, true, memory_.get(), + stats_options_); + memset(memory_.get(), 0, hash_set1.numBytes(stats_options_)); EXPECT_DEATH(hash_set1.sanityCheck(), ""); } diff --git a/test/common/common/matchers_test.cc b/test/common/common/matchers_test.cc new file mode 100644 index 0000000000000..5e081276dbf7b --- /dev/null +++ b/test/common/common/matchers_test.cc @@ -0,0 +1,167 @@ +#include "envoy/api/v2/core/base.pb.h" +#include "envoy/type/matcher/metadata.pb.h" +#include "envoy/type/matcher/number.pb.h" +#include "envoy/type/matcher/string.pb.h" + +#include "common/common/matchers.h" +#include "common/config/metadata.h" +#include "common/protobuf/protobuf.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Matcher { +namespace { + +TEST(MetadataTest, MatchNullValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_null_value(ProtobufWkt::NullValue::NULL_VALUE); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_null_match(); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +TEST(MetadataTest, MatchDoubleValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_number_value(9); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_double_match()->set_exact(1); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_double_match()->set_exact(9); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + + auto r = matcher.mutable_value()->mutable_double_match()->mutable_range(); + r->set_start(9.1); + r->set_end(10); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + + r = matcher.mutable_value()->mutable_double_match()->mutable_range(); + r->set_start(8.9); + r->set_end(9); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + + r = matcher.mutable_value()->mutable_double_match()->mutable_range(); + r->set_start(9); + r->set_end(9.1); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +TEST(MetadataTest, MatchStringExactValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_string_value("prod"); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_string_match()->set_exact("prod"); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +TEST(MetadataTest, MatchStringPrefixValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_string_value("prodabc"); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_string_match()->set_prefix("prodx"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_string_match()->set_prefix("prod"); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +TEST(MetadataTest, MatchStringSuffixValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_string_value("abcprod"); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_string_match()->set_suffix("prodx"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->mutable_string_match()->set_suffix("prod"); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + ; +} + +TEST(MetadataTest, MatchBoolValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_bool_value(true); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->set_bool_match(false); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->set_bool_match(true); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +TEST(MetadataTest, MatchPresentValue) { + envoy::api::v2::core::Metadata metadata; + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.a", "label") + .set_string_value("test"); + Envoy::Config::Metadata::mutableMetadataValue(metadata, "envoy.filter.b", "label") + .set_number_value(1); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("envoy.filter.b"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->set_present_match(false); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + matcher.mutable_value()->set_present_match(true); + EXPECT_TRUE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); + + matcher.clear_path(); + matcher.add_path()->set_key("unknown"); + EXPECT_FALSE(Envoy::Matchers::MetadataMatcher(matcher).match(metadata)); +} + +} // namespace +} // namespace Matcher +} // namespace Envoy diff --git a/test/common/common/utility_speed_test.cc b/test/common/common/utility_speed_test.cc index 1e8b31d1dde88..54a2bc5354b9b 100644 --- a/test/common/common/utility_speed_test.cc +++ b/test/common/common/utility_speed_test.cc @@ -111,7 +111,7 @@ BENCHMARK(BM_RTrimStringViewAlreadyTrimmedAndMakeString); static void BM_FindToken(benchmark::State& state) { const absl::string_view cache_control(CacheControl, CacheControlLength); for (auto _ : state) { - RELEASE_ASSERT(Envoy::StringUtil::findToken(cache_control, ",", "no-transform")); + RELEASE_ASSERT(Envoy::StringUtil::findToken(cache_control, ",", "no-transform"), ""); } } BENCHMARK(BM_FindToken); @@ -153,7 +153,7 @@ static bool findTokenWithoutSplitting(absl::string_view str, char delim, absl::s static void BM_FindTokenWithoutSplitting(benchmark::State& state) { const absl::string_view cache_control(CacheControl, CacheControlLength); for (auto _ : state) { - RELEASE_ASSERT(findTokenWithoutSplitting(cache_control, ',', "no-transform", true)); + RELEASE_ASSERT(findTokenWithoutSplitting(cache_control, ',', "no-transform", true), ""); } } BENCHMARK(BM_FindTokenWithoutSplitting); @@ -168,7 +168,7 @@ static void BM_FindTokenValueNestedSplit(benchmark::State& state) { max_age = Envoy::StringUtil::trim(name_value[1]); } } - RELEASE_ASSERT(max_age == "300"); + RELEASE_ASSERT(max_age == "300", ""); } } BENCHMARK(BM_FindTokenValueNestedSplit); @@ -184,7 +184,7 @@ static void BM_FindTokenValueSearchForEqual(benchmark::State& state) { max_age = Envoy::StringUtil::trim(token.substr(equals + 1)); } } - RELEASE_ASSERT(max_age == "300"); + RELEASE_ASSERT(max_age == "300", ""); } } BENCHMARK(BM_FindTokenValueSearchForEqual); @@ -199,7 +199,7 @@ static void BM_FindTokenValueNoSplit(benchmark::State& state) { max_age = Envoy::StringUtil::trim(token); } } - RELEASE_ASSERT(max_age == "300"); + RELEASE_ASSERT(max_age == "300", ""); } } BENCHMARK(BM_FindTokenValueNoSplit); diff --git a/test/common/config/BUILD b/test/common/config/BUILD index eff9ccd98328e..555bc63f7ec76 100644 --- a/test/common/config/BUILD +++ b/test/common/config/BUILD @@ -47,6 +47,7 @@ envoy_cc_test( "//test/mocks/config:config_mocks", "//test/mocks/event:event_mocks", "//test/mocks/grpc:grpc_mocks", + "//test/mocks/runtime:runtime_mocks", "//test/test_common:logging_lib", "//test/test_common:utility_lib", "@envoy_api//envoy/api/v2:discovery_cc", diff --git a/test/common/config/grpc_mux_impl_test.cc b/test/common/config/grpc_mux_impl_test.cc index ed1606109df6d..d0a1be52d76e3 100644 --- a/test/common/config/grpc_mux_impl_test.cc +++ b/test/common/config/grpc_mux_impl_test.cc @@ -10,6 +10,7 @@ #include "test/mocks/config/mocks.h" #include "test/mocks/event/mocks.h" #include "test/mocks/grpc/mocks.h" +#include "test/mocks/runtime/mocks.h" #include "test/test_common/logging.h" #include "test/test_common/utility.h" @@ -44,7 +45,7 @@ class GrpcMuxImplTest : public testing::Test { dispatcher_, *Protobuf::DescriptorPool::generated_pool()->FindMethodByName( "envoy.service.discovery.v2.AggregatedDiscoveryService.StreamAggregatedResources"), - time_source_)); + random_, time_source_)); } void expectSendMessage(const std::string& type_url, @@ -72,6 +73,7 @@ class GrpcMuxImplTest : public testing::Test { envoy::api::v2::core::Node node_; NiceMock dispatcher_; + Runtime::MockRandomGenerator random_; Grpc::MockAsyncClient* async_client_; Event::MockTimer* timer_; Event::TimerCb timer_cb_; @@ -112,6 +114,7 @@ TEST_F(GrpcMuxImplTest, ResetStream) { expectSendMessage("baz", {"z"}, ""); grpc_mux_->start(); + EXPECT_CALL(random_, random()); EXPECT_CALL(*timer_, enableTimer(_)); grpc_mux_->onRemoteClose(Grpc::Status::GrpcStatus::Canceled, ""); EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); diff --git a/test/common/config/grpc_subscription_impl_test.cc b/test/common/config/grpc_subscription_impl_test.cc index 546dd5e67d209..fd0974a961f7a 100644 --- a/test/common/config/grpc_subscription_impl_test.cc +++ b/test/common/config/grpc_subscription_impl_test.cc @@ -14,15 +14,20 @@ class GrpcSubscriptionImplTest : public GrpcSubscriptionTestHarness, public test TEST_F(GrpcSubscriptionImplTest, StreamCreationFailure) { InSequence s; EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, onConfigUpdateFailed(_)); + EXPECT_CALL(random_, random()); EXPECT_CALL(*timer_, enableTimer(_)); subscription_->start({"cluster0", "cluster1"}, callbacks_); + verifyStats(2, 0, 0, 1, 0); // Ensure this doesn't cause an issue by sending a request, since we don't // have a gRPC stream. subscription_->updateResources({"cluster2"}); + // Retry and succeed. EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + expectSendMessage({"cluster2"}, ""); timer_cb_(); verifyStats(3, 0, 0, 1, 0); @@ -35,6 +40,7 @@ TEST_F(GrpcSubscriptionImplTest, RemoteStreamClose) { Http::HeaderMapPtr trailers{new Http::TestHeaderMapImpl{}}; subscription_->grpcMux().onReceiveTrailingMetadata(std::move(trailers)); EXPECT_CALL(*timer_, enableTimer(_)); + EXPECT_CALL(random_, random()); subscription_->grpcMux().onRemoteClose(Grpc::Status::GrpcStatus::Canceled, ""); verifyStats(1, 0, 0, 0, 0); // Retry and succeed. diff --git a/test/common/config/grpc_subscription_test_harness.h b/test/common/config/grpc_subscription_test_harness.h index c1fa1694fe5e3..06e22e0ef28a1 100644 --- a/test/common/config/grpc_subscription_test_harness.h +++ b/test/common/config/grpc_subscription_test_harness.h @@ -40,7 +40,7 @@ class GrpcSubscriptionTestHarness : public SubscriptionTestHarness { })); subscription_.reset( new GrpcEdsSubscriptionImpl(node_, std::unique_ptr(async_client_), - dispatcher_, *method_descriptor_, stats_)); + dispatcher_, random_, *method_descriptor_, stats_)); } ~GrpcSubscriptionTestHarness() { EXPECT_CALL(async_stream_, sendMessage(_, false)); } @@ -129,6 +129,7 @@ class GrpcSubscriptionTestHarness : public SubscriptionTestHarness { Grpc::MockAsyncClient* async_client_; NiceMock cm_; Event::MockDispatcher dispatcher_; + Runtime::MockRandomGenerator random_; Event::MockTimer* timer_; Event::TimerCb timer_cb_; envoy::api::v2::core::Node node_; diff --git a/test/common/config/subscription_factory_test.cc b/test/common/config/subscription_factory_test.cc index 40139a53d3e7f..0dc79fc2b32e7 100644 --- a/test/common/config/subscription_factory_test.cc +++ b/test/common/config/subscription_factory_test.cc @@ -280,6 +280,7 @@ TEST_F(SubscriptionFactoryTest, GrpcSubscription) { })); return async_client_factory; })); + EXPECT_CALL(random_, random()); EXPECT_CALL(dispatcher_, createTimer_(_)); EXPECT_CALL(callbacks_, onConfigUpdateFailed(_)); subscriptionFromConfigSource(config)->start({"static_cluster"}, callbacks_); diff --git a/test/common/config/utility_test.cc b/test/common/config/utility_test.cc index d25fc023fba19..26cef9151bc8e 100644 --- a/test/common/config/utility_test.cc +++ b/test/common/config/utility_test.cc @@ -81,7 +81,8 @@ TEST(UtilityTest, TranslateApiConfigSource) { api_config_source_grpc); EXPECT_EQ(envoy::api::v2::core::ApiConfigSource::GRPC, api_config_source_grpc.api_type()); EXPECT_EQ(30000, DurationUtil::durationToMilliseconds(api_config_source_grpc.refresh_delay())); - EXPECT_EQ("test_grpc_cluster", api_config_source_grpc.cluster_names(0)); + EXPECT_EQ("test_grpc_cluster", + api_config_source_grpc.grpc_services(0).envoy_grpc().cluster_name()); } TEST(UtilityTest, createTagProducer) { @@ -95,15 +96,16 @@ TEST(UtilityTest, createTagProducer) { } TEST(UtilityTest, ObjNameLength) { - - std::string name = "listenerwithareallyreallylongnamemorethanmaxcharsallowedbyschema"; + Stats::StatsOptionsImpl stats_options; + std::string name = "listenerwithareallyreallyreallyreallyreallyreallyreallyreallyreallyreallyreal" + "lyreallyreallyreallyreallyreallylongnamemorethanmaxcharsallowedbyschema"; std::string err_prefix; std::string err_suffix = fmt::format(": Length of {} ({}) exceeds allowed maximum length ({})", - name, name.length(), Stats::RawStatData::maxObjNameLength()); + name, name.length(), stats_options.maxNameLength()); { err_prefix = "test"; - EXPECT_THROW_WITH_MESSAGE(Utility::checkObjNameLength(err_prefix, name), EnvoyException, - err_prefix + err_suffix); + EXPECT_THROW_WITH_MESSAGE(Utility::checkObjNameLength(err_prefix, name, stats_options), + EnvoyException, err_prefix + err_suffix); } { @@ -113,8 +115,9 @@ TEST(UtilityTest, ObjNameLength) { auto json_object_ptr = Json::Factory::loadFromString(json); envoy::api::v2::Listener listener; - EXPECT_THROW_WITH_MESSAGE(Config::LdsJson::translateListener(*json_object_ptr, listener), - EnvoyException, err_prefix + err_suffix); + EXPECT_THROW_WITH_MESSAGE( + Config::LdsJson::translateListener(*json_object_ptr, listener, stats_options), + EnvoyException, err_prefix + err_suffix); } { @@ -122,8 +125,9 @@ TEST(UtilityTest, ObjNameLength) { std::string json = R"EOF({ "name": ")EOF" + name + R"EOF(", "domains": [], "routes": []})EOF"; auto json_object_ptr = Json::Factory::loadFromString(json); envoy::api::v2::route::VirtualHost vhost; - EXPECT_THROW_WITH_MESSAGE(Config::RdsJson::translateVirtualHost(*json_object_ptr, vhost), - EnvoyException, err_prefix + err_suffix); + EXPECT_THROW_WITH_MESSAGE( + Config::RdsJson::translateVirtualHost(*json_object_ptr, vhost, stats_options), + EnvoyException, err_prefix + err_suffix); } { @@ -135,8 +139,8 @@ TEST(UtilityTest, ObjNameLength) { envoy::api::v2::Cluster cluster; envoy::api::v2::core::ConfigSource eds_config; EXPECT_THROW_WITH_MESSAGE( - Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster), EnvoyException, - err_prefix + err_suffix); + Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster, stats_options), + EnvoyException, err_prefix + err_suffix); } { @@ -144,8 +148,9 @@ TEST(UtilityTest, ObjNameLength) { std::string json = R"EOF({ "route_config_name": ")EOF" + name + R"EOF(", "cluster": "foo"})EOF"; auto json_object_ptr = Json::Factory::loadFromString(json); envoy::config::filter::network::http_connection_manager::v2::Rds rds; - EXPECT_THROW_WITH_MESSAGE(Config::Utility::translateRdsConfig(*json_object_ptr, rds), - EnvoyException, err_prefix + err_suffix); + EXPECT_THROW_WITH_MESSAGE( + Config::Utility::translateRdsConfig(*json_object_ptr, rds, stats_options), EnvoyException, + err_prefix + err_suffix); } } @@ -159,9 +164,10 @@ TEST(UtilityTest, UnixClusterDns) { auto json_object_ptr = Json::Factory::loadFromString(json); envoy::api::v2::Cluster cluster; envoy::api::v2::core::ConfigSource eds_config; + Stats::StatsOptionsImpl stats_options; EXPECT_THROW_WITH_MESSAGE( - Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster), EnvoyException, - "unresolved URL must be TCP scheme, got: unix:///test.sock"); + Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster, stats_options), + EnvoyException, "unresolved URL must be TCP scheme, got: unix:///test.sock"); } TEST(UtilityTest, UnixClusterStatic) { @@ -174,7 +180,8 @@ TEST(UtilityTest, UnixClusterStatic) { auto json_object_ptr = Json::Factory::loadFromString(json); envoy::api::v2::Cluster cluster; envoy::api::v2::core::ConfigSource eds_config; - Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster); + Stats::StatsOptionsImpl stats_options; + Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster, stats_options); EXPECT_EQ("/test.sock", cluster.hosts(0).pipe().path()); } @@ -216,8 +223,10 @@ TEST(UtilityTest, FactoryForGrpcApiConfigSource) { api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); api_config_source.add_cluster_names(); // this also logs a warning for setting REST cluster names for a gRPC API config. - EXPECT_NO_THROW( - Utility::factoryForGrpcApiConfigSource(async_client_manager, api_config_source, scope)); + EXPECT_THROW_WITH_REGEX( + Utility::factoryForGrpcApiConfigSource(async_client_manager, api_config_source, scope), + EnvoyException, + "envoy::api::v2::core::ConfigSource::GRPC must not have a cluster name specified."); } { @@ -225,11 +234,10 @@ TEST(UtilityTest, FactoryForGrpcApiConfigSource) { api_config_source.set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); api_config_source.add_cluster_names(); api_config_source.add_cluster_names(); - // this also logs a warning for setting REST cluster names for a gRPC API config. EXPECT_THROW_WITH_REGEX( Utility::factoryForGrpcApiConfigSource(async_client_manager, api_config_source, scope), EnvoyException, - "envoy::api::v2::core::ConfigSource must have a singleton cluster name specified"); + "envoy::api::v2::core::ConfigSource::GRPC must not have a cluster name specified."); } { @@ -272,17 +280,14 @@ TEST(CheckApiConfigSourceSubscriptionBackingClusterTest, GrpcClusterTestAcrossTy // API of type GRPC api_config_source->set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); - api_config_source->add_cluster_names("foo_cluster"); // GRPC cluster without GRPC services. EXPECT_THROW_WITH_MESSAGE( Utility::checkApiConfigSourceSubscriptionBackingCluster(cluster_map, *api_config_source), - EnvoyException, - "envoy::api::v2::core::ConfigSource must have a statically defined non-EDS cluster: " - "'foo_cluster' does not exist, was added via api, or is an EDS cluster"); + EnvoyException, "API configs must have either a gRPC service or a cluster name defined"); // Non-existent cluster. - api_config_source->add_grpc_services(); + api_config_source->add_grpc_services()->mutable_envoy_grpc()->set_cluster_name("foo_cluster"); EXPECT_THROW_WITH_MESSAGE( Utility::checkApiConfigSourceSubscriptionBackingCluster(cluster_map, *api_config_source), EnvoyException, @@ -315,6 +320,13 @@ TEST(CheckApiConfigSourceSubscriptionBackingClusterTest, GrpcClusterTestAcrossTy EXPECT_CALL(*cluster.info_, addedViaApi()); EXPECT_CALL(*cluster.info_, type()); Utility::checkApiConfigSourceSubscriptionBackingCluster(cluster_map, *api_config_source); + + // API with cluster_names set should be rejected. + api_config_source->add_cluster_names("foo_cluster"); + EXPECT_THROW_WITH_MESSAGE( + Utility::checkApiConfigSourceSubscriptionBackingCluster(cluster_map, *api_config_source), + EnvoyException, + "envoy::api::v2::core::ConfigSource::GRPC must not have a cluster name specified."); } TEST(CheckApiConfigSourceSubscriptionBackingClusterTest, RestClusterTestAcrossTypes) { diff --git a/test/common/decompressor/zlib_decompressor_impl_test.cc b/test/common/decompressor/zlib_decompressor_impl_test.cc index e0cb95f5d5de0..a3f608e149f3c 100644 --- a/test/common/decompressor/zlib_decompressor_impl_test.cc +++ b/test/common/decompressor/zlib_decompressor_impl_test.cc @@ -28,7 +28,7 @@ class ZlibDecompressorImplTest : public testing::Test { std::string original_text{}; for (uint64_t i = 0; i < 30; ++i) { TestUtility::feedBufferWithRandomCharacters(buffer, default_input_size * i, i); - original_text.append(TestUtility::bufferToString(buffer)); + original_text.append(buffer.toString()); compressor.compress(buffer, Compressor::State::Flush); accumulation_buffer.add(buffer); drainBuffer(buffer); @@ -45,7 +45,7 @@ class ZlibDecompressorImplTest : public testing::Test { decompressor.init(window_bits); decompressor.decompress(accumulation_buffer, buffer); - std::string decompressed_text{TestUtility::bufferToString(buffer)}; + std::string decompressed_text{buffer.toString()}; ASSERT_EQ(compressor.checksum(), decompressor.checksum()); ASSERT_EQ(original_text.length(), decompressed_text.length()); @@ -122,7 +122,7 @@ TEST_F(ZlibDecompressorImplTest, CompressAndDecompress) { std::string original_text{}; for (uint64_t i = 0; i < 20; ++i) { TestUtility::feedBufferWithRandomCharacters(buffer, default_input_size * i, i); - original_text.append(TestUtility::bufferToString(buffer)); + original_text.append(buffer.toString()); compressor.compress(buffer, Compressor::State::Flush); accumulation_buffer.add(buffer); drainBuffer(buffer); @@ -142,7 +142,7 @@ TEST_F(ZlibDecompressorImplTest, CompressAndDecompress) { decompressor.init(gzip_window_bits); decompressor.decompress(accumulation_buffer, buffer); - std::string decompressed_text{TestUtility::bufferToString(buffer)}; + std::string decompressed_text{buffer.toString()}; ASSERT_EQ(compressor.checksum(), decompressor.checksum()); ASSERT_EQ(original_text.length(), decompressed_text.length()); @@ -162,7 +162,7 @@ TEST_F(ZlibDecompressorImplTest, DecompressWithSmallOutputBuffer) { std::string original_text{}; for (uint64_t i = 0; i < 20; ++i) { TestUtility::feedBufferWithRandomCharacters(buffer, default_input_size * i, i); - original_text.append(TestUtility::bufferToString(buffer)); + original_text.append(buffer.toString()); compressor.compress(buffer, Compressor::State::Flush); accumulation_buffer.add(buffer); drainBuffer(buffer); @@ -182,7 +182,7 @@ TEST_F(ZlibDecompressorImplTest, DecompressWithSmallOutputBuffer) { decompressor.init(gzip_window_bits); decompressor.decompress(accumulation_buffer, buffer); - std::string decompressed_text{TestUtility::bufferToString(buffer)}; + std::string decompressed_text{buffer.toString()}; ASSERT_EQ(compressor.checksum(), decompressor.checksum()); ASSERT_EQ(original_text.length(), decompressed_text.length()); @@ -244,7 +244,7 @@ TEST_F(ZlibDecompressorImplTest, CompressDecompressOfMultipleSlices) { ASSERT_EQ(0, buffer.length()); decompressor.decompress(accumulation_buffer, buffer); - std::string decompressed_text{TestUtility::bufferToString(buffer)}; + std::string decompressed_text{buffer.toString()}; ASSERT_EQ(compressor.checksum(), decompressor.checksum()); ASSERT_EQ(original_text.length(), decompressed_text.length()); diff --git a/test/common/grpc/grpc_client_integration.h b/test/common/grpc/grpc_client_integration.h index cc2f5949cbe87..54fd6936374f2 100644 --- a/test/common/grpc/grpc_client_integration.h +++ b/test/common/grpc/grpc_client_integration.h @@ -30,7 +30,7 @@ class BaseGrpcClientIntegrationParamTest { break; } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } }; diff --git a/test/common/grpc/grpc_client_integration_test.cc b/test/common/grpc/grpc_client_integration_test.cc index 163089293b04f..30d92a2abe12a 100644 --- a/test/common/grpc/grpc_client_integration_test.cc +++ b/test/common/grpc/grpc_client_integration_test.cc @@ -308,10 +308,10 @@ TEST_P(GrpcClientIntegrationTest, ResetAfterCloseLocal) { initialize(); auto stream = createStream(empty_metadata_); stream->grpc_stream_->closeStream(); - stream->fake_stream_->waitForEndStream(dispatcher_helper_.dispatcher_); + ASSERT_TRUE(stream->fake_stream_->waitForEndStream(dispatcher_helper_.dispatcher_)); stream->grpc_stream_->resetStream(); dispatcher_helper_.dispatcher_.run(Event::Dispatcher::RunType::NonBlock); - stream->fake_stream_->waitForReset(); + ASSERT_TRUE(stream->fake_stream_->waitForReset()); } // Validate that request cancel() works. @@ -323,7 +323,7 @@ TEST_P(GrpcClientIntegrationTest, CancelRequest) { EXPECT_CALL(*request->child_span_, finishSpan()); request->grpc_request_->cancel(); dispatcher_helper_.dispatcher_.run(Event::Dispatcher::RunType::NonBlock); - request->fake_stream_->waitForReset(); + ASSERT_TRUE(request->fake_stream_->waitForReset()); } // Parameterize the loopback test server socket address and gRPC client type. @@ -352,7 +352,8 @@ TEST_P(GrpcSslClientIntegrationTest, BasicSslRequestWithClientCert) { class GrpcAccessTokenClientIntegrationTest : public GrpcSslClientIntegrationTest { public: void expectExtraHeaders(FakeStream& fake_stream) override { - fake_stream.waitForHeadersComplete(); + AssertionResult result = fake_stream.waitForHeadersComplete(); + RELEASE_ASSERT(result, result.message()); Http::TestHeaderMapImpl stream_headers(fake_stream.headers()); if (access_token_value_ != "") { if (access_token_value_2_ == "") { @@ -396,7 +397,7 @@ TEST_P(GrpcAccessTokenClientIntegrationTest, AccessTokenAuthRequest) { SKIP_IF_GRPC_CLIENT(ClientType::EnvoyGrpc); access_token_value_ = "accesstokenvalue"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().AccessTokenExample; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -408,7 +409,7 @@ TEST_P(GrpcAccessTokenClientIntegrationTest, AccessTokenAuthStream) { SKIP_IF_GRPC_CLIENT(ClientType::EnvoyGrpc); access_token_value_ = "accesstokenvalue"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().AccessTokenExample; initialize(); auto stream = createStream(empty_metadata_); stream->sendServerInitialMetadata(empty_metadata_); @@ -424,7 +425,7 @@ TEST_P(GrpcAccessTokenClientIntegrationTest, MultipleAccessTokens) { access_token_value_ = "accesstokenvalue"; access_token_value_2_ = "accesstokenvalue2"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().AccessTokenExample; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -437,7 +438,7 @@ TEST_P(GrpcAccessTokenClientIntegrationTest, ExtraCredentialParams) { access_token_value_ = "accesstokenvalue"; refresh_token_value_ = "refreshtokenvalue"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().AccessTokenExample; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -448,7 +449,7 @@ TEST_P(GrpcAccessTokenClientIntegrationTest, ExtraCredentialParams) { TEST_P(GrpcAccessTokenClientIntegrationTest, NoAccessTokens) { SKIP_IF_GRPC_CLIENT(ClientType::EnvoyGrpc); credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().ACCESS_TOKEN_EXAMPLE; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().AccessTokenExample; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); diff --git a/test/common/grpc/grpc_client_integration_test_harness.h b/test/common/grpc/grpc_client_integration_test_harness.h index f6c9c3fa152c6..b13b802b3fc4c 100644 --- a/test/common/grpc/grpc_client_integration_test_harness.h +++ b/test/common/grpc/grpc_client_integration_test_harness.h @@ -82,7 +82,9 @@ class HelloworldStream : public MockAsyncStreamCallbacks grpc_stream_->sendMessage(request_msg, end_stream); helloworld::HelloRequest received_msg; - fake_stream_->waitForGrpcMessage(dispatcher_helper_.dispatcher_, received_msg); + AssertionResult result = + fake_stream_->waitForGrpcMessage(dispatcher_helper_.dispatcher_, received_msg); + RELEASE_ASSERT(result, result.message()); EXPECT_THAT(request_msg, ProtoEq(received_msg)); } @@ -162,7 +164,8 @@ class HelloworldStream : public MockAsyncStreamCallbacks void closeStream() { grpc_stream_->closeStream(); - fake_stream_->waitForEndStream(dispatcher_helper_.dispatcher_); + AssertionResult result = fake_stream_->waitForEndStream(dispatcher_helper_.dispatcher_); + RELEASE_ASSERT(result, result.message()); } DispatcherHelper& dispatcher_helper_; @@ -224,8 +227,10 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { void TearDown() override { if (fake_connection_) { - fake_connection_->close(); - fake_connection_->waitForDisconnect(); + AssertionResult result = fake_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); fake_connection_.reset(); } } @@ -286,12 +291,13 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { return std::make_unique(dispatcher_, *google_tls_, stub_factory, stats_scope_, createGoogleGrpcConfig()); #else - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; #endif } void expectInitialHeaders(FakeStream& fake_stream, const TestMetadata& initial_metadata) { - fake_stream.waitForHeadersComplete(); + AssertionResult result = fake_stream.waitForHeadersComplete(); + RELEASE_ASSERT(result, result.message()); Http::TestHeaderMapImpl stream_headers(fake_stream.headers()); EXPECT_EQ("POST", stream_headers.get_(":method")); EXPECT_EQ("/helloworld.Greeter/SayHello", stream_headers.get_(":path")); @@ -333,9 +339,12 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { EXPECT_NE(request->grpc_request_, nullptr); if (!fake_connection_) { - fake_connection_ = fake_upstream_->waitForHttpConnection(dispatcher_); + AssertionResult result = fake_upstream_->waitForHttpConnection(dispatcher_, fake_connection_); + RELEASE_ASSERT(result, result.message()); } - fake_streams_.push_back(fake_connection_->waitForNewStream(dispatcher_)); + fake_streams_.emplace_back(); + AssertionResult result = fake_connection_->waitForNewStream(dispatcher_, fake_streams_.back()); + RELEASE_ASSERT(result, result.message()); auto& fake_stream = *fake_streams_.back(); request->fake_stream_ = &fake_stream; @@ -343,7 +352,8 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { expectExtraHeaders(fake_stream); helloworld::HelloRequest received_msg; - fake_stream.waitForGrpcMessage(dispatcher_, received_msg); + result = fake_stream.waitForGrpcMessage(dispatcher_, received_msg); + RELEASE_ASSERT(result, result.message()); EXPECT_THAT(request_msg, ProtoEq(received_msg)); return request; @@ -362,9 +372,12 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { EXPECT_NE(stream->grpc_stream_, nullptr); if (!fake_connection_) { - fake_connection_ = fake_upstream_->waitForHttpConnection(dispatcher_); + AssertionResult result = fake_upstream_->waitForHttpConnection(dispatcher_, fake_connection_); + RELEASE_ASSERT(result, result.message()); } - fake_streams_.push_back(fake_connection_->waitForNewStream(dispatcher_)); + fake_streams_.emplace_back(); + AssertionResult result = fake_connection_->waitForNewStream(dispatcher_, fake_streams_.back()); + RELEASE_ASSERT(result, result.message()); auto& fake_stream = *fake_streams_.back(); stream->fake_stream_ = &fake_stream; @@ -399,6 +412,7 @@ class GrpcClientIntegrationTest : public GrpcClientIntegrationParamTest { Upstream::MockThreadLocalCluster thread_local_cluster_; NiceMock local_info_; Runtime::MockLoader runtime_; + Ssl::ContextManagerImpl context_manager_{runtime_}; NiceMock random_; Http::AsyncClientPtr http_async_client_; Http::ConnectionPool::InstancePtr http_conn_pool_; @@ -421,6 +435,7 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { // doesn't like dangling contexts at destruction. GrpcClientIntegrationTest::TearDown(); fake_upstream_.reset(); + async_client_transport_socket_.reset(); client_connection_.reset(); mock_cluster_info_->transport_socket_factory_.reset(); } @@ -483,7 +498,6 @@ class GrpcSslClientIntegrationTest : public GrpcClientIntegrationTest { bool use_client_cert_{}; Secret::MockSecretManager secret_manager_; - Ssl::ContextManagerImpl context_manager_{runtime_}; }; } // namespace diff --git a/test/common/http/BUILD b/test/common/http/BUILD index a48257922c3cd..7ccee9f708ce3 100644 --- a/test/common/http/BUILD +++ b/test/common/http/BUILD @@ -2,9 +2,11 @@ licenses(["notice"]) # Apache 2 load( "//bazel:envoy_build_system.bzl", + "envoy_cc_fuzz_test", "envoy_cc_test", "envoy_cc_test_library", "envoy_package", + "envoy_proto_library", ) envoy_package() @@ -168,6 +170,22 @@ envoy_cc_test( ], ) +envoy_proto_library( + name = "header_map_impl_fuzz_proto", + srcs = ["header_map_impl_fuzz.proto"], + external_deps = ["well_known_protos"], +) + +envoy_cc_fuzz_test( + name = "header_map_impl_fuzz_test", + srcs = ["header_map_impl_fuzz_test.cc"], + corpus = "header_map_impl_corpus", + deps = [ + ":header_map_impl_fuzz_proto", + "//source/common/http:header_map_lib", + ], +) + envoy_cc_test( name = "header_utility_test", srcs = ["header_utility_test.cc"], diff --git a/test/common/http/codec_client_test.cc b/test/common/http/codec_client_test.cc index 9bf053a7bcebf..e902232c7705a 100644 --- a/test/common/http/codec_client_test.cc +++ b/test/common/http/codec_client_test.cc @@ -354,7 +354,7 @@ TEST_P(CodecNetworkTest, SendData) { Buffer::OwnedImpl data(full_data); upstream_connection_->write(data, false); EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { - EXPECT_EQ(full_data, TestUtility::bufferToString(data)); + EXPECT_EQ(full_data, data.toString()); dispatcher_->exit(); })); dispatcher_->run(Event::Dispatcher::RunType::Block); @@ -375,12 +375,9 @@ TEST_P(CodecNetworkTest, SendHeadersAndClose) { upstream_connection_->close(Network::ConnectionCloseType::FlushWrite); EXPECT_CALL(*codec_, dispatch(_)) .Times(2) - .WillOnce(Invoke([&](Buffer::Instance& data) -> void { - EXPECT_EQ(full_data, TestUtility::bufferToString(data)); - })) - .WillOnce(Invoke([&](Buffer::Instance& data) -> void { - EXPECT_EQ("", TestUtility::bufferToString(data)); - })); + .WillOnce( + Invoke([&](Buffer::Instance& data) -> void { EXPECT_EQ(full_data, data.toString()); })) + .WillOnce(Invoke([&](Buffer::Instance& data) -> void { EXPECT_EQ("", data.toString()); })); // Because the headers are not complete, the disconnect will reset the stream. // Note even if the final \r\n were appended to the header data, enough of the // codec state is mocked out that the data would not be framed and the stream @@ -411,12 +408,9 @@ TEST_P(CodecNetworkTest, SendHeadersAndCloseUnderReadDisable) { EXPECT_CALL(*codec_, dispatch(_)) .Times(2) - .WillOnce(Invoke([&](Buffer::Instance& data) -> void { - EXPECT_EQ(full_data, TestUtility::bufferToString(data)); - })) - .WillOnce(Invoke([&](Buffer::Instance& data) -> void { - EXPECT_EQ("", TestUtility::bufferToString(data)); - })); + .WillOnce( + Invoke([&](Buffer::Instance& data) -> void { EXPECT_EQ(full_data, data.toString()); })) + .WillOnce(Invoke([&](Buffer::Instance& data) -> void { EXPECT_EQ("", data.toString()); })); EXPECT_CALL(inner_encoder_.stream_, resetStream(_)).WillOnce(InvokeWithoutArgs([&]() -> void { for (auto callbacks : inner_encoder_.stream_.callbacks_) { callbacks->onResetStream(StreamResetReason::RemoteReset); diff --git a/test/common/http/conn_manager_impl_test.cc b/test/common/http/conn_manager_impl_test.cc index 9fac8816374ab..ee957d1be4f01 100644 --- a/test/common/http/conn_manager_impl_test.cc +++ b/test/common/http/conn_manager_impl_test.cc @@ -254,7 +254,8 @@ class HttpConnectionManagerImplTest : public Test, public ConnectionManagerConfi std::chrono::milliseconds drainTimeout() override { return std::chrono::milliseconds(100); } FilterChainFactory& filterFactory() override { return filter_factory_; } bool generateRequestId() override { return true; } - const absl::optional& idleTimeout() override { return idle_timeout_; } + absl::optional idleTimeout() const override { return idle_timeout_; } + std::chrono::milliseconds streamIdleTimeout() const override { return stream_idle_timeout_; } Router::RouteConfigProvider& routeConfigProvider() override { return route_config_provider_; } const std::string& serverName() override { return server_name_; } ConnectionManagerStats& stats() override { return stats_; } @@ -294,6 +295,7 @@ class HttpConnectionManagerImplTest : public Test, public ConnectionManagerConfi std::vector set_current_client_cert_details_; absl::optional user_agent_; absl::optional idle_timeout_; + std::chrono::milliseconds stream_idle_timeout_{}; NiceMock random_; NiceMock local_info_; NiceMock factory_context_; @@ -310,6 +312,8 @@ class HttpConnectionManagerImplTest : public Test, public ConnectionManagerConfi ConnectionManagerListenerStats listener_stats_; bool proxy_100_continue_ = false; Http::Http1Settings http1_settings_; + NiceMock upstream_conn_; // for websocket tests + NiceMock conn_pool_; // for websocket tests // TODO(mattklein123): Not all tests have been converted over to better setup. Convert the rest. MockStreamEncoder response_encoder_; @@ -1093,6 +1097,345 @@ TEST_F(HttpConnectionManagerImplTest, NoPath) { conn_manager_->onData(fake_input, false); } +// No idle timeout when route idle timeout is implied at both global and +// per-route level. The connection manager config is responsible for managing +// the default configuration aspects. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutNotConfigured) { + setup(false, ""); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, createTimer_(_)).Times(0); + EXPECT_CALL(*codec_, dispatch(_)) + .Times(1) + .WillRepeatedly(Invoke([&](Buffer::Instance& data) -> void { + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + decoder->decodeHeaders(std::move(headers), false); + + data.drain(4); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ(0U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// When the global timeout is configured, the timer is enabled before we receive +// headers, if it fires we don't face plant. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutGlobal) { + stream_idle_timeout_ = std::chrono::milliseconds(10); + setup(false, ""); + + EXPECT_CALL(*codec_, dispatch(_)).Times(1).WillRepeatedly(Invoke([&](Buffer::Instance&) -> void { + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(10))); + conn_manager_->newStream(response_encoder_); + + // Expect resetIdleTimer() to be called for the response + // encodeHeaders()/encodeData(). + EXPECT_CALL(*idle_timer, enableTimer(_)).Times(2); + EXPECT_CALL(*idle_timer, disableTimer()); + idle_timer->callback_(); + })); + + // 408 direct response after timeout. + EXPECT_CALL(response_encoder_, encodeHeaders(_, false)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("408", headers.Status()->value().c_str()); + })); + std::string response_body; + EXPECT_CALL(response_encoder_, encodeData(_, true)).WillOnce(AddBufferToString(&response_body)); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ("stream timeout", response_body); + EXPECT_EQ(1U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Per-route timeouts override the global stream idle timeout. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutRouteOverride) { + stream_idle_timeout_ = std::chrono::milliseconds(10); + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(30))); + + EXPECT_CALL(*codec_, dispatch(_)) + .Times(1) + .WillRepeatedly(Invoke([&](Buffer::Instance& data) -> void { + Event::MockTimer* idle_timer = + new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(10))); + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(30))); + decoder->decodeHeaders(std::move(headers), false); + + data.drain(4); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ(0U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Per-route zero timeout overrides the global stream idle timeout. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutRouteZeroOverride) { + stream_idle_timeout_ = std::chrono::milliseconds(10); + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(0))); + + EXPECT_CALL(*codec_, dispatch(_)) + .Times(1) + .WillRepeatedly(Invoke([&](Buffer::Instance& data) -> void { + Event::MockTimer* idle_timer = + new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(10))); + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, disableTimer()); + decoder->decodeHeaders(std::move(headers), false); + + data.drain(4); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ(0U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Validate the per-stream idle timeout after having sent downstream headers. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutAfterDownstreamHeaders) { + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(10))); + + // Codec sends downstream request headers. + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeHeaders(std::move(headers), false); + + // Expect resetIdleTimer() to be called for the response + // encodeHeaders()/encodeData(). + EXPECT_CALL(*idle_timer, enableTimer(_)).Times(2); + EXPECT_CALL(*idle_timer, disableTimer()); + idle_timer->callback_(); + + data.drain(4); + })); + + // 408 direct response after timeout. + EXPECT_CALL(response_encoder_, encodeHeaders(_, false)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("408", headers.Status()->value().c_str()); + })); + std::string response_body; + EXPECT_CALL(response_encoder_, encodeData(_, true)).WillOnce(AddBufferToString(&response_body)); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ("stream timeout", response_body); + EXPECT_EQ(1U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Validate the per-stream idle timer is properly disabled when the stream terminates normally. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutNormalTermination) { + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(10))); + + // Codec sends downstream request headers. + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeHeaders(std::move(headers), false); + + data.drain(4); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_CALL(*idle_timer, disableTimer()); + conn_manager_->onEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(0U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Validate the per-stream idle timeout after having sent downstream +// headers+body. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutAfterDownstreamHeadersAndBody) { + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(10))); + + // Codec sends downstream request headers. + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeHeaders(std::move(headers), false); + + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeData(data, false); + + // Expect resetIdleTimer() to be called for the response + // encodeHeaders()/encodeData(). + EXPECT_CALL(*idle_timer, enableTimer(_)).Times(2); + EXPECT_CALL(*idle_timer, disableTimer()); + idle_timer->callback_(); + + data.drain(4); + })); + + // 408 direct response after timeout. + EXPECT_CALL(response_encoder_, encodeHeaders(_, false)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("408", headers.Status()->value().c_str()); + })); + std::string response_body; + EXPECT_CALL(response_encoder_, encodeData(_, true)).WillOnce(AddBufferToString(&response_body)); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ("stream timeout", response_body); + EXPECT_EQ(1U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Validate the per-stream idle timeout after upstream headers have been sent. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutAfterUpstreamHeaders) { + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(10))); + + // Store the basic request encoder during filter chain setup. + std::shared_ptr filter(new NiceMock()); + + EXPECT_CALL(filter_factory_, createFilterChain(_)) + .WillRepeatedly(Invoke([&](FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamDecoderFilter(filter); + })); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + + // Codec sends downstream request headers, upstream response headers are + // encoded. + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + StreamDecoder* decoder = &conn_manager_->newStream(response_encoder_); + + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeHeaders(std::move(headers), false); + + HeaderMapPtr response_headers{new TestHeaderMapImpl{{":status", "200"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + filter->callbacks_->encodeHeaders(std::move(response_headers), false); + + EXPECT_CALL(*idle_timer, disableTimer()); + idle_timer->callback_(); + + data.drain(4); + })); + + // 200 upstream response. + EXPECT_CALL(response_encoder_, encodeHeaders(_, false)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("200", headers.Status()->value().c_str()); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ(1U, stats_.named_.downstream_rq_idle_timeout_.value()); +} + +// Validate the per-stream idle timeout after a sequence of header/data events. +TEST_F(HttpConnectionManagerImplTest, PerStreamIdleTimeoutAfterBidiData) { + setup(false, ""); + ON_CALL(route_config_provider_.route_config_->route_->route_entry_, idleTimeout()) + .WillByDefault(Return(std::chrono::milliseconds(10))); + proxy_100_continue_ = true; + + // Store the basic request encoder during filter chain setup. + std::shared_ptr filter(new NiceMock()); + + EXPECT_CALL(filter_factory_, createFilterChain(_)) + .WillRepeatedly(Invoke([&](FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addStreamDecoderFilter(filter); + })); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + + // Codec sends downstream request headers, upstream response headers are + // encoded, data events happen in various directions. + Event::MockTimer* idle_timer = new Event::MockTimer(&filter_callbacks_.connection_.dispatcher_); + StreamDecoder* decoder; + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + decoder = &conn_manager_->newStream(response_encoder_); + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":path", "/"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeHeaders(std::move(headers), false); + + HeaderMapPtr response_continue_headers{new TestHeaderMapImpl{{":status", "100"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + filter->callbacks_->encode100ContinueHeaders(std::move(response_continue_headers)); + + HeaderMapPtr response_headers{new TestHeaderMapImpl{{":status", "200"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + filter->callbacks_->encodeHeaders(std::move(response_headers), false); + + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeData(data, false); + + HeaderMapPtr trailers{new TestHeaderMapImpl{{"foo", "bar"}}}; + EXPECT_CALL(*idle_timer, enableTimer(_)); + decoder->decodeTrailers(std::move(trailers)); + + Buffer::OwnedImpl fake_response("world"); + EXPECT_CALL(*idle_timer, enableTimer(_)); + filter->callbacks_->encodeData(fake_response, false); + + EXPECT_CALL(*idle_timer, disableTimer()); + idle_timer->callback_(); + + data.drain(4); + })); + + // 100 continue. + EXPECT_CALL(response_encoder_, encode100ContinueHeaders(_)); + + // 200 upstream response. + EXPECT_CALL(response_encoder_, encodeHeaders(_, false)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("200", headers.Status()->value().c_str()); + })); + + std::string response_body; + EXPECT_CALL(response_encoder_, encodeData(_, false)).WillOnce(AddBufferToString(&response_body)); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + + EXPECT_EQ(1U, stats_.named_.downstream_rq_idle_timeout_.value()); + EXPECT_EQ("world", response_body); +} + TEST_F(HttpConnectionManagerImplTest, RejectWebSocketOnNonWebSocketRoute) { setup(false, ""); @@ -1163,8 +1506,7 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketNoThreadLocalCluster) { TEST_F(HttpConnectionManagerImplTest, WebSocketNoConnInPool) { setup(false, ""); - Upstream::MockHost::MockCreateConnectionData conn_info; - EXPECT_CALL(cluster_manager_, tcpConnForCluster_(_, _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster(_, _, _)).WillOnce(Return(nullptr)); expectOnUpstreamInitFailure(); EXPECT_EQ(1U, stats_.named_.downstream_cx_websocket_active_.value()); @@ -1179,14 +1521,37 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketNoConnInPool) { TEST_F(HttpConnectionManagerImplTest, WebSocketDataAfterConnectFail) { setup(false, ""); - Upstream::MockHost::MockCreateConnectionData conn_info; - EXPECT_CALL(cluster_manager_, tcpConnForCluster_(_, _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster(_, _, _)).WillOnce(Return(&conn_pool_)); + + StreamDecoder* decoder = nullptr; + NiceMock encoder; + + configureRouteForWebsocket(route_config_provider_.route_config_->route_->route_entry_); + + EXPECT_CALL(*codec_, dispatch(_)).WillOnce(Invoke([&](Buffer::Instance& data) -> void { + decoder = &conn_manager_->newStream(encoder); + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, + {":method", "GET"}, + {":path", "/"}, + {"connection", "Upgrade"}, + {"upgrade", "websocket"}}}; + decoder->decodeHeaders(std::move(headers), true); + data.drain(4); + })); + + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); - expectOnUpstreamInitFailure(); EXPECT_EQ(1U, stats_.named_.downstream_cx_websocket_active_.value()); EXPECT_EQ(1U, stats_.named_.downstream_cx_websocket_total_.value()); EXPECT_EQ(0U, stats_.named_.downstream_cx_http1_active_.value()); + EXPECT_CALL(encoder, encodeHeaders(_, true)) + .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { + EXPECT_STREQ("504", headers.Status()->value().c_str()); + })); + + conn_pool_.poolFailure(Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); // This should get dropped, with no ASSERT or crash. @@ -1206,13 +1571,14 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketMetadataMatch) { .WillByDefault(Return( &route_config_provider_.route_config_->route_->route_entry_.metadata_matches_criteria_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_(_, _)) - .WillOnce(Invoke([&](const std::string&, Upstream::LoadBalancerContext* context) - -> Upstream::MockHost::MockCreateConnectionData { + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster(_, _, _)) + .WillOnce(Invoke([&](const std::string&, Upstream::ResourcePriority, + Upstream::LoadBalancerContext* context) + -> Tcp::ConnectionPool::MockInstance* { EXPECT_EQ( context->metadataMatchCriteria(), &route_config_provider_.route_config_->route_->route_entry_.metadata_matches_criteria_); - return {}; + return nullptr; })); expectOnUpstreamInitFailure(); @@ -1223,21 +1589,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketMetadataMatch) { TEST_F(HttpConnectionManagerImplTest, WebSocketConnectTimeoutError) { setup(false, ""); - Event::MockTimer* connect_timer = - new NiceMock(&filter_callbacks_.connection_.dispatcher_); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(*connect_timer, enableTimer(_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); StreamDecoder* decoder = nullptr; NiceMock encoder; @@ -1255,15 +1608,15 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketConnectTimeoutError) { data.drain(4); })); + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + EXPECT_CALL(encoder, encodeHeaders(_, true)) .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { EXPECT_STREQ("504", headers.Status()->value().c_str()); })); + conn_pool_.poolFailure(Tcp::ConnectionPool::PoolFailureReason::Timeout); - Buffer::OwnedImpl fake_input("1234"); - conn_manager_->onData(fake_input, false); - - connect_timer->callback_(); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } @@ -1271,21 +1624,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketConnectTimeoutError) { TEST_F(HttpConnectionManagerImplTest, WebSocketConnectionFailure) { setup(false, ""); - Event::MockTimer* connect_timer = - new NiceMock(&filter_callbacks_.connection_.dispatcher_); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(*connect_timer, enableTimer(_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); StreamDecoder* decoder = nullptr; NiceMock encoder; @@ -1303,16 +1643,16 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketConnectionFailure) { data.drain(4); })); + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); + EXPECT_CALL(encoder, encodeHeaders(_, true)) .WillOnce(Invoke([](const HeaderMap& headers, bool) -> void { EXPECT_STREQ("504", headers.Status()->value().c_str()); })); - Buffer::OwnedImpl fake_input("1234"); - conn_manager_->onData(fake_input, false); + conn_pool_.poolFailure(Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); - // expectOnUpstreamInitFailure("504"); - upstream_connection->raiseEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } @@ -1322,9 +1662,6 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { StreamDecoder* decoder = nullptr; NiceMock encoder; - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, {":method", "GET"}, {":path", "/"}, @@ -1332,14 +1669,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { {"upgrade", "websocket"}}}; auto raw_header_ptr = headers.get(); - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); configureRouteForWebsocket(route_config_provider_.route_config_->route_->route_entry_); @@ -1356,7 +1687,9 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { Buffer::OwnedImpl fake_input("1234"); conn_manager_->onData(fake_input, false); - upstream_connection->raiseEvent(Network::ConnectionEvent::Connected); + + conn_pool_.host_->hostname_ = "newhost"; + conn_pool_.poolReady(upstream_conn_); // rewritten authority header when auto_host_rewrite is true EXPECT_STREQ("newhost", raw_header_ptr->Host()->value().c_str()); @@ -1372,21 +1705,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketPrefixAndAutoHostRewrite) { TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyData) { setup(false, ""); - Event::MockTimer* connect_timer = - new NiceMock(&filter_callbacks_.connection_.dispatcher_); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(*connect_timer, enableTimer(_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); StreamDecoder* decoder = nullptr; NiceMock encoder; @@ -1413,10 +1733,11 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyData) { conn_manager_->onData(fake_input, false); - EXPECT_CALL(*upstream_connection, write(_, false)); - EXPECT_CALL(*upstream_connection, write(BufferEqual(&early_data), false)); + EXPECT_CALL(upstream_conn_, write(_, false)); + EXPECT_CALL(upstream_conn_, write(BufferEqual(&early_data), false)); EXPECT_CALL(filter_callbacks_.connection_, readDisable(false)); - upstream_connection->raiseEvent(Network::ConnectionEvent::Connected); + conn_pool_.poolReady(upstream_conn_); + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } @@ -1424,21 +1745,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyData) { TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyDataConnectionFail) { setup(false, ""); - Event::MockTimer* connect_timer = - new NiceMock(&filter_callbacks_.connection_.dispatcher_); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(*connect_timer, enableTimer(_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); StreamDecoder* decoder = nullptr; NiceMock encoder; @@ -1465,7 +1773,7 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyDataConnectionFail) { conn_manager_->onData(fake_input, false); - upstream_connection->raiseEvent(Network::ConnectionEvent::RemoteClose); + conn_pool_.poolFailure(Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); // This should get dropped, with no crash or ASSERT. @@ -1478,21 +1786,8 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyDataConnectionFail) { TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyEndStream) { setup(false, ""); - Event::MockTimer* connect_timer = - new NiceMock(&filter_callbacks_.connection_.dispatcher_); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - - conn_info.connection_ = upstream_connection; - conn_info.host_description_.reset(new Upstream::HostImpl( - cluster_manager_.thread_local_cluster_.cluster_.info_, "newhost", - Network::Utility::resolveUrl("tcp://127.0.0.1:80"), - envoy::api::v2::core::Metadata::default_instance(), 1, - envoy::api::v2::core::Locality().default_instance(), - envoy::api::v2::endpoint::Endpoint::HealthCheckConfig().default_instance())); - EXPECT_CALL(*connect_timer, enableTimer(_)); - EXPECT_CALL(cluster_manager_, tcpConnForCluster_("fake_cluster", _)).WillOnce(Return(conn_info)); + EXPECT_CALL(cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)); StreamDecoder* decoder = nullptr; NiceMock encoder; @@ -1514,13 +1809,72 @@ TEST_F(HttpConnectionManagerImplTest, WebSocketEarlyEndStream) { Buffer::OwnedImpl fake_input("1234"); conn_manager_->onData(fake_input, true); - EXPECT_CALL(*upstream_connection, write(_, false)); - EXPECT_CALL(*upstream_connection, write(_, true)).Times(0); - upstream_connection->raiseEvent(Network::ConnectionEvent::Connected); + EXPECT_CALL(upstream_conn_, write(_, false)); + EXPECT_CALL(upstream_conn_, write(_, true)).Times(0); + conn_pool_.poolReady(upstream_conn_); filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); conn_manager_.reset(); } +// Make sure for upgrades, we do not append Connection: Close when draining. +TEST_F(HttpConnectionManagerImplTest, FooUpgradeDrainClose) { + setup(false, "envoy-custom-server", false); + + // Store the basic request encoder during filter chain setup. + MockStreamFilter* filter = new MockStreamFilter(); + EXPECT_CALL(drain_close_, drainClose()).WillOnce(Return(true)); + + EXPECT_CALL(*filter, decodeHeaders(_, false)) + .WillRepeatedly(Invoke([&](HeaderMap&, bool) -> FilterHeadersStatus { + return FilterHeadersStatus::StopIteration; + })); + + EXPECT_CALL(*filter, encodeHeaders(_, false)) + .WillRepeatedly(Invoke( + [&](HeaderMap&, bool) -> FilterHeadersStatus { return FilterHeadersStatus::Continue; })); + + NiceMock encoder; + EXPECT_CALL(encoder, encodeHeaders(_, false)) + .WillOnce(Invoke([&](const HeaderMap& headers, bool) -> void { + EXPECT_NE(nullptr, headers.Connection()); + EXPECT_STREQ("upgrade", headers.Connection()->value().c_str()); + })); + + EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)).Times(1); + EXPECT_CALL(*filter, setEncoderFilterCallbacks(_)).Times(1); + + EXPECT_CALL(filter_factory_, createUpgradeFilterChain(_, _)) + .WillRepeatedly( + Invoke([&](absl::string_view, FilterChainFactoryCallbacks& callbacks) -> bool { + callbacks.addStreamFilter(StreamFilterSharedPtr{filter}); + return true; + })); + + // When dispatch is called on the codec, we pretend to get a new stream and then fire a headers + // only request into it. Then we respond into the filter. + StreamDecoder* decoder = nullptr; + EXPECT_CALL(*codec_, dispatch(_)).WillRepeatedly(Invoke([&](Buffer::Instance& data) -> void { + decoder = &conn_manager_->newStream(encoder); + + HeaderMapPtr headers{new TestHeaderMapImpl{{":authority", "host"}, + {":method", "GET"}, + {":path", "/"}, + {"connection", "Upgrade"}, + {"upgrade", "foo"}}}; + decoder->decodeHeaders(std::move(headers), false); + + HeaderMapPtr response_headers{ + new TestHeaderMapImpl{{":status", "101"}, {"Connection", "upgrade"}, {"upgrade", "foo"}}}; + filter->decoder_callbacks_->encodeHeaders(std::move(response_headers), false); + + data.drain(4); + })); + + // Kick off the incoming data. Use extra data which should cause a redispatch. + Buffer::OwnedImpl fake_input("1234"); + conn_manager_->onData(fake_input, false); +} + TEST_F(HttpConnectionManagerImplTest, DrainClose) { setup(true, ""); @@ -1972,8 +2326,7 @@ TEST_F(HttpConnectionManagerImplTest, FilterAddBodyDuringDecodeData) { EXPECT_CALL(*decoder_filters_[0], decodeData(_, true)) .WillOnce(Invoke([&](Buffer::Instance& data, bool) -> FilterDataStatus { decoder_filters_[0]->callbacks_->addDecodedData(data, true); - EXPECT_EQ(TestUtility::bufferToString(*decoder_filters_[0]->callbacks_->decodingBuffer()), - "helloworld"); + EXPECT_EQ(decoder_filters_[0]->callbacks_->decodingBuffer()->toString(), "helloworld"); return FilterDataStatus::Continue; })); EXPECT_CALL(*decoder_filters_[1], decodeHeaders(_, false)) @@ -1992,8 +2345,7 @@ TEST_F(HttpConnectionManagerImplTest, FilterAddBodyDuringDecodeData) { EXPECT_CALL(*encoder_filters_[0], encodeData(_, true)) .WillOnce(Invoke([&](Buffer::Instance& data, bool) -> FilterDataStatus { encoder_filters_[0]->callbacks_->addEncodedData(data, true); - EXPECT_EQ(TestUtility::bufferToString(*encoder_filters_[0]->callbacks_->encodingBuffer()), - "goodbye"); + EXPECT_EQ(encoder_filters_[0]->callbacks_->encodingBuffer()->toString(), "goodbye"); return FilterDataStatus::Continue; })); EXPECT_CALL(*encoder_filters_[1], encodeHeaders(_, false)) diff --git a/test/common/http/conn_manager_utility_test.cc b/test/common/http/conn_manager_utility_test.cc index 1d6978c736016..411d0f483a3b0 100644 --- a/test/common/http/conn_manager_utility_test.cc +++ b/test/common/http/conn_manager_utility_test.cc @@ -44,7 +44,8 @@ class MockConnectionManagerConfig : public ConnectionManagerConfig { MOCK_METHOD0(drainTimeout, std::chrono::milliseconds()); MOCK_METHOD0(filterFactory, FilterChainFactory&()); MOCK_METHOD0(generateRequestId, bool()); - MOCK_METHOD0(idleTimeout, const absl::optional&()); + MOCK_CONST_METHOD0(idleTimeout, absl::optional()); + MOCK_CONST_METHOD0(streamIdleTimeout, std::chrono::milliseconds()); MOCK_METHOD0(routeConfigProvider, Router::RouteConfigProvider&()); MOCK_METHOD0(serverName, const std::string&()); MOCK_METHOD0(stats, ConnectionManagerStats&()); @@ -195,7 +196,7 @@ TEST_F(ConnectionManagerUtilityTest, ViaEmpty) { EXPECT_FALSE(request_headers.has(Headers::get().Via)); TestHeaderMapImpl response_headers; - ConnectionManagerUtility::mutateResponseHeaders(response_headers, request_headers, via_); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, via_); EXPECT_FALSE(response_headers.has(Headers::get().Via)); } @@ -212,9 +213,9 @@ TEST_F(ConnectionManagerUtilityTest, ViaAppend) { TestHeaderMapImpl response_headers; // Pretend we're doing a 100-continue transform here. - ConnectionManagerUtility::mutateResponseHeaders(response_headers, request_headers, ""); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); // The actual response header processing. - ConnectionManagerUtility::mutateResponseHeaders(response_headers, request_headers, via_); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, via_); EXPECT_EQ("foo", response_headers.get_(Headers::get().Via)); } @@ -548,24 +549,81 @@ TEST_F(ConnectionManagerUtilityTest, RemoveConnectionUpgradeForHttp2Requests) { // Test cleaning response headers. TEST_F(ConnectionManagerUtilityTest, MutateResponseHeaders) { TestHeaderMapImpl response_headers{ - {"connection", "foo"}, {"transfer-encoding", "foo"}, {"custom_header", "foo"}}; + {"connection", "foo"}, {"transfer-encoding", "foo"}, {"custom_header", "custom_value"}}; TestHeaderMapImpl request_headers{{"x-request-id", "request-id"}}; - ConnectionManagerUtility::mutateResponseHeaders(response_headers, request_headers, ""); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); EXPECT_EQ(1UL, response_headers.size()); - EXPECT_EQ("foo", response_headers.get_("custom_header")); + EXPECT_EQ("custom_value", response_headers.get_("custom_header")); EXPECT_FALSE(response_headers.has("x-request-id")); EXPECT_FALSE(response_headers.has(Headers::get().Via)); } +// Make sure we don't remove connection headers on all Upgrade responses. +TEST_F(ConnectionManagerUtilityTest, DoNotRemoveConnectionUpgradeForWebSocketResponses) { + TestHeaderMapImpl request_headers{{"connection", "UpGrAdE"}, {"upgrade", "foo"}}; + TestHeaderMapImpl response_headers{ + {"connection", "upgrade"}, {"transfer-encoding", "foo"}, {"upgrade", "bar"}}; + EXPECT_TRUE(Utility::isUpgrade(request_headers)); + EXPECT_TRUE(Utility::isUpgrade(response_headers)); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); + + EXPECT_EQ(2UL, response_headers.size()) << response_headers; + EXPECT_EQ("upgrade", response_headers.get_("connection")); + EXPECT_EQ("bar", response_headers.get_("upgrade")); +} + +TEST_F(ConnectionManagerUtilityTest, ClearUpgradeHeadersForNonUpgradeRequests) { + // Test clearing non-upgrade request and response headers + { + TestHeaderMapImpl request_headers{{"x-request-id", "request-id"}}; + TestHeaderMapImpl response_headers{ + {"connection", "foo"}, {"transfer-encoding", "bar"}, {"custom_header", "custom_value"}}; + EXPECT_FALSE(Utility::isUpgrade(request_headers)); + EXPECT_FALSE(Utility::isUpgrade(response_headers)); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); + + EXPECT_EQ(1UL, response_headers.size()) << response_headers; + EXPECT_EQ("custom_value", response_headers.get_("custom_header")); + } + + // Test with the request headers not valid upgrade headers + { + TestHeaderMapImpl request_headers{{"upgrade", "foo"}}; + TestHeaderMapImpl response_headers{{"connection", "upgrade"}, + {"transfer-encoding", "eep"}, + {"upgrade", "foo"}, + {"custom_header", "custom_value"}}; + EXPECT_FALSE(Utility::isUpgrade(request_headers)); + EXPECT_TRUE(Utility::isUpgrade(response_headers)); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); + + EXPECT_EQ(2UL, response_headers.size()) << response_headers; + EXPECT_EQ("custom_value", response_headers.get_("custom_header")); + EXPECT_EQ("foo", response_headers.get_("upgrade")); + } + + // Test with the response headers not valid upgrade headers + { + TestHeaderMapImpl request_headers{{"connection", "UpGrAdE"}, {"upgrade", "foo"}}; + TestHeaderMapImpl response_headers{{"transfer-encoding", "foo"}, {"upgrade", "bar"}}; + EXPECT_TRUE(Utility::isUpgrade(request_headers)); + EXPECT_FALSE(Utility::isUpgrade(response_headers)); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); + + EXPECT_EQ(1UL, response_headers.size()) << response_headers; + EXPECT_EQ("bar", response_headers.get_("upgrade")); + } +} + // Test that we correctly return x-request-id if we were requested to force a trace. TEST_F(ConnectionManagerUtilityTest, MutateResponseHeadersReturnXRequestId) { TestHeaderMapImpl response_headers; TestHeaderMapImpl request_headers{{"x-request-id", "request-id"}, {"x-envoy-force-trace", "true"}}; - ConnectionManagerUtility::mutateResponseHeaders(response_headers, request_headers, ""); + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); EXPECT_EQ("request-id", response_headers.get_("x-request-id")); } @@ -578,7 +636,7 @@ TEST_F(ConnectionManagerUtilityTest, MtlsSanitizeClientCert) { .WillByDefault(Return(Http::ForwardClientCertType::Sanitize)); std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;SAN=abc;URI=abc;DNS=example.com"}}; + TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;URI=abc;DNS=example.com"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); @@ -595,13 +653,12 @@ TEST_F(ConnectionManagerUtilityTest, MtlsForwardOnlyClientCert) { std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); TestHeaderMapImpl headers{ - {"x-forwarded-client-cert", - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be;DNS=example.com"}}; + {"x-forwarded-client-cert", "By=test://foo.com/fe;URI=test://bar.com/be;DNS=example.com"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); - EXPECT_EQ("By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be;DNS=example.com", + EXPECT_EQ("By=test://foo.com/fe;URI=test://bar.com/be;DNS=example.com", headers.get_("x-forwarded-client-cert")); } @@ -621,7 +678,6 @@ TEST_F(ConnectionManagerUtilityTest, MtlsSetForwardClientCert) { ON_CALL(config_, forwardClientCert()) .WillByDefault(Return(Http::ForwardClientCertType::AppendForward)); std::vector details = std::vector(); - details.push_back(Http::ClientCertDetailsType::SAN); details.push_back(Http::ClientCertDetailsType::URI); details.push_back(Http::ClientCertDetailsType::Cert); details.push_back(Http::ClientCertDetailsType::DNS); @@ -633,7 +689,6 @@ TEST_F(ConnectionManagerUtilityTest, MtlsSetForwardClientCert) { EXPECT_TRUE(headers.has("x-forwarded-client-cert")); EXPECT_EQ("By=test://foo.com/be;" "Hash=abcdefg;" - "SAN=test://foo.com/fe;" "URI=test://foo.com/fe;" "Cert=\"%3D%3Dabc%0Ade%3D\";" "DNS=www.example.com", @@ -659,23 +714,21 @@ TEST_F(ConnectionManagerUtilityTest, MtlsAppendForwardClientCert) { ON_CALL(config_, forwardClientCert()) .WillByDefault(Return(Http::ForwardClientCertType::AppendForward)); std::vector details = std::vector(); - details.push_back(Http::ClientCertDetailsType::SAN); details.push_back(Http::ClientCertDetailsType::URI); details.push_back(Http::ClientCertDetailsType::Cert); details.push_back(Http::ClientCertDetailsType::DNS); ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test://foo.com/fe;SAN=test://bar.com/" - "be;URI=test://bar.com/" - "be;DNS=test.com;DNS=test.com"}}; + TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test://foo.com/fe;" + "URI=test://bar.com/be;" + "DNS=test.com;DNS=test.com"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); - EXPECT_EQ( - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be;DNS=test.com;DNS=test.com," - "By=test://foo.com/be;Hash=abcdefg;SAN=test://foo.com/fe;URI=test://foo.com/fe;" - "Cert=\"%3D%3Dabc%0Ade%3D\";DNS=www.example.com", - headers.get_("x-forwarded-client-cert")); + EXPECT_EQ("By=test://foo.com/fe;URI=test://bar.com/be;DNS=test.com;DNS=test.com," + "By=test://foo.com/be;Hash=abcdefg;URI=test://foo.com/fe;" + "Cert=\"%3D%3Dabc%0Ade%3D\";DNS=www.example.com", + headers.get_("x-forwarded-client-cert")); } // This test assumes the following scenario: @@ -693,18 +746,16 @@ TEST_F(ConnectionManagerUtilityTest, MtlsAppendForwardClientCertLocalSanEmpty) { ON_CALL(config_, forwardClientCert()) .WillByDefault(Return(Http::ForwardClientCertType::AppendForward)); std::vector details = std::vector(); - details.push_back(Http::ClientCertDetailsType::SAN); details.push_back(Http::ClientCertDetailsType::URI); ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); TestHeaderMapImpl headers{ - {"x-forwarded-client-cert", - "By=test://foo.com/fe;Hash=xyz;SAN=test://bar.com/be;URI=test://bar.com/be"}}; + {"x-forwarded-client-cert", "By=test://foo.com/fe;Hash=xyz;URI=test://bar.com/be"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); - EXPECT_EQ("By=test://foo.com/fe;Hash=xyz;SAN=test://bar.com/be;URI=test://bar.com/be," - "Hash=abcdefg;SAN=test://foo.com/fe;URI=test://foo.com/fe", + EXPECT_EQ("By=test://foo.com/fe;Hash=xyz;URI=test://bar.com/be," + "Hash=abcdefg;URI=test://foo.com/fe", headers.get_("x-forwarded-client-cert")); } @@ -728,18 +779,17 @@ TEST_F(ConnectionManagerUtilityTest, MtlsSanitizeSetClientCert) { .WillByDefault(Return(Http::ForwardClientCertType::SanitizeSet)); std::vector details = std::vector(); details.push_back(Http::ClientCertDetailsType::Subject); - details.push_back(Http::ClientCertDetailsType::SAN); details.push_back(Http::ClientCertDetailsType::URI); details.push_back(Http::ClientCertDetailsType::Cert); ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be"}}; + TestHeaderMapImpl headers{ + {"x-forwarded-client-cert", "By=test://foo.com/fe;URI=test://bar.com/be"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); EXPECT_EQ("By=test://foo.com/be;Hash=abcdefg;Subject=\"/C=US/ST=CA/L=San " - "Francisco/OU=Lyft/CN=test.lyft.com\";SAN=test://foo.com/fe;URI=test://foo.com/" + "Francisco/OU=Lyft/CN=test.lyft.com\";URI=test://foo.com/" "fe;Cert=\"abcde=\"", headers.get_("x-forwarded-client-cert")); } @@ -762,17 +812,16 @@ TEST_F(ConnectionManagerUtilityTest, MtlsSanitizeSetClientCertPeerSanEmpty) { .WillByDefault(Return(Http::ForwardClientCertType::SanitizeSet)); std::vector details = std::vector(); details.push_back(Http::ClientCertDetailsType::Subject); - details.push_back(Http::ClientCertDetailsType::SAN); details.push_back(Http::ClientCertDetailsType::URI); ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be"}}; + TestHeaderMapImpl headers{ + {"x-forwarded-client-cert", "By=test://foo.com/fe;URI=test://bar.com/be"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); EXPECT_EQ("By=test://foo.com/be;Hash=abcdefg;Subject=\"/C=US/ST=CA/L=San " - "Francisco/OU=Lyft/CN=test.lyft.com\";SAN=;URI=", + "Francisco/OU=Lyft/CN=test.lyft.com\";URI=", headers.get_("x-forwarded-client-cert")); } @@ -785,7 +834,7 @@ TEST_F(ConnectionManagerUtilityTest, TlsSanitizeClientCertWhenForward) { .WillByDefault(Return(Http::ForwardClientCertType::ForwardOnly)); std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;SAN=abc;URI=abc"}}; + TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;URI=abc"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); @@ -801,14 +850,13 @@ TEST_F(ConnectionManagerUtilityTest, TlsAlwaysForwardOnlyClientCert) { .WillByDefault(Return(Http::ForwardClientCertType::AlwaysForwardOnly)); std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be"}}; + TestHeaderMapImpl headers{ + {"x-forwarded-client-cert", "By=test://foo.com/fe;URI=test://bar.com/be"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); - EXPECT_EQ("By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be", - headers.get_("x-forwarded-client-cert")); + EXPECT_EQ("By=test://foo.com/fe;URI=test://bar.com/be", headers.get_("x-forwarded-client-cert")); } // forward_only, append_forward and sanitize_set are only effective in mTLS connection. @@ -818,7 +866,7 @@ TEST_F(ConnectionManagerUtilityTest, NonTlsSanitizeClientCertWhenForward) { .WillByDefault(Return(Http::ForwardClientCertType::ForwardOnly)); std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;SAN=abc;URI=abc"}}; + TestHeaderMapImpl headers{{"x-forwarded-client-cert", "By=test;URI=abc"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); @@ -832,14 +880,13 @@ TEST_F(ConnectionManagerUtilityTest, NonTlsAlwaysForwardClientCert) { .WillByDefault(Return(Http::ForwardClientCertType::AlwaysForwardOnly)); std::vector details; ON_CALL(config_, setCurrentClientCertDetails()).WillByDefault(ReturnRef(details)); - TestHeaderMapImpl headers{{"x-forwarded-client-cert", - "By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be"}}; + TestHeaderMapImpl headers{ + {"x-forwarded-client-cert", "By=test://foo.com/fe;URI=test://bar.com/be"}}; EXPECT_EQ((MutateRequestRet{"10.0.0.3:50000", false}), callMutateRequestHeaders(headers, Protocol::Http2)); EXPECT_TRUE(headers.has("x-forwarded-client-cert")); - EXPECT_EQ("By=test://foo.com/fe;SAN=test://bar.com/be;URI=test://bar.com/be", - headers.get_("x-forwarded-client-cert")); + EXPECT_EQ("By=test://foo.com/fe;URI=test://bar.com/be", headers.get_("x-forwarded-client-cert")); } // Sampling, global on. @@ -958,5 +1005,18 @@ TEST_F(ConnectionManagerUtilityTest, NoTraceOnBrokenUuid) { UuidUtils::isTraceableUuid(request_headers.get_("x-request-id"))); } +TEST_F(ConnectionManagerUtilityTest, RemovesProxyResponseHeaders) { + Http::TestHeaderMapImpl request_headers{{}}; + Http::TestHeaderMapImpl response_headers{{"keep-alive", "timeout=60"}, + {"proxy-connection", "proxy-header"}}; + ConnectionManagerUtility::mutateResponseHeaders(response_headers, &request_headers, ""); + + EXPECT_EQ(UuidTraceStatus::NoTrace, + UuidUtils::isTraceableUuid(request_headers.get_("x-request-id"))); + + EXPECT_FALSE(response_headers.has("keep-alive")); + EXPECT_FALSE(response_headers.has("proxy-connection")); +} + } // namespace Http } // namespace Envoy diff --git a/test/common/http/header_map_impl_corpus/empty b/test/common/http/header_map_impl_corpus/empty new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/common/http/header_map_impl_corpus/example b/test/common/http/header_map_impl_corpus/example new file mode 100644 index 0000000000000..e49ae93468aa0 --- /dev/null +++ b/test/common/http/header_map_impl_corpus/example @@ -0,0 +1,233 @@ +actions { + add_reference { + key: "foo" + value: "bar" + } +} +actions { + add_reference { + key: "foo" + value: "baz" + } +} +actions { + add_reference_key { + key: "foo_string_key" + string_value: "barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr" + } +} +actions { + add_reference_key { + key: "foo_string_key" + string_value: "baz" + } +} +actions { + add_reference_key { + key: "foo_uint64_key" + uint64_value: 42 + } +} +actions { + add_reference_key { + key: "foo_uint64_key" + uint64_value: 37 + } +} +actions { + add_copy { + key: "foo_string_key" + string_value: "barrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr" + } +} +actions { + add_copy { + key: "foo_string_key" + string_value: "baz" + } +} +actions { + add_copy { + key: "foo_uint64_key" + uint64_value: 42 + } +} +actions { + add_copy { + key: "foo_uint64_key" + uint64_value: 37 + } +} +actions { + set_reference { + key: "foo" + value: "bar" + } +} +actions { + set_reference { + key: "foo" + value: "baz" + } +} +actions { + set_reference_key { + key: "foo" + value: "bar" + } +} +actions { + set_reference_key { + key: "foo" + value: "baz" + } +} + +actions { + add_reference { + key: ":method" + value: "bar" + } +} +actions { + add_reference { + key: ":method" + value: "baz" + } +} +actions { + add_reference_key { + key: ":method" + string_value: "bar" + } +} +actions { + add_reference_key { + key: ":method" + string_value: "baz" + } +} +actions { + add_reference_key { + key: ":method" + uint64_value: 42 + } +} +actions { + add_reference_key { + key: ":method" + uint64_value: 37 + } +} +actions { + add_copy { + key: ":method" + string_value: "bar" + } +} +actions { + add_copy { + key: ":method" + string_value: "baz" + } +} +actions { + add_copy { + key: ":method" + uint64_value: 42 + } +} +actions { + add_copy { + key: ":method" + uint64_value: 37 + } +} +actions { + set_reference { + key: ":method" + value: "bar" + } +} +actions { + set_reference { + key: ":method" + value: "baz" + } +} +actions { + set_reference_key { + key: ":method" + value: "bar" + } +} +actions { + set_reference_key { + key: ":method" + value: "baz" + } +} + +actions { + get_and_mutate { + key: ":method" + append: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz" + } +} +actions { + get_and_mutate { + key: ":method" + append: "aa" + } +} +actions { + get_and_mutate { + key: ":method" + clear: {} + } +} +actions { + get_and_mutate { + key: ":method" + find: "a" + } +} +actions { + get_and_mutate { + key: ":method" + set_copy: "a" + } +} +actions { + get_and_mutate { + key: ":method" + set_integer: 0 + } +} +actions { + get_and_mutate { + key: ":method" + set_reference: "a" + } +} +actions { + copy: {} +} +actions { + lookup: ":method" +} +actions { + lookup: "foo" +} +actions { + remove: "f" +} +actions { + remove_prefix: "foo" +} +actions { + remove: ":m" +} +actions { + remove_prefix: ":m" +} diff --git a/test/common/http/header_map_impl_fuzz.proto b/test/common/http/header_map_impl_fuzz.proto new file mode 100644 index 0000000000000..fd1e0155c94e5 --- /dev/null +++ b/test/common/http/header_map_impl_fuzz.proto @@ -0,0 +1,69 @@ +syntax = "proto3"; + +package test.common.http; + +import "google/protobuf/empty.proto"; + +// Structured input for header_map_impl_fuzz_test. + +message AddReference { + string key = 1; + string value = 2; +} + +message AddReferenceKey { + string key = 1; + oneof value_selector { + string string_value = 2; + uint64 uint64_value = 3; + } +} + +message AddCopy { + string key = 1; + oneof value_selector { + string string_value = 2; + uint64 uint64_value = 3; + } +} + +message SetReference { + string key = 1; + string value = 2; +} + +message SetReferenceKey { + string key = 1; + string value = 2; +} + +message GetAndMutate { + string key = 1; + oneof mutate_selector { + string append = 2; + google.protobuf.Empty clear = 3; + string find = 4; + string set_copy = 5; + uint64 set_integer = 6; + string set_reference = 7; + } +} + +message Action { + oneof action_selector { + AddReference add_reference = 1; + AddReferenceKey add_reference_key = 2; + AddCopy add_copy = 3; + SetReference set_reference = 4; + SetReferenceKey set_reference_key = 5; + GetAndMutate get_and_mutate = 6; + google.protobuf.Empty copy = 7; + string lookup = 8; + string remove = 9; + string remove_prefix = 10; + } +} + +message HeaderMapImplFuzzTestCase { + repeated Action actions = 1; +} diff --git a/test/common/http/header_map_impl_fuzz_test.cc b/test/common/http/header_map_impl_fuzz_test.cc new file mode 100644 index 0000000000000..4e8f07658add7 --- /dev/null +++ b/test/common/http/header_map_impl_fuzz_test.cc @@ -0,0 +1,170 @@ +#include + +#include "common/common/assert.h" +#include "common/common/logger.h" +#include "common/http/header_map_impl.h" + +#include "test/common/http/header_map_impl_fuzz.pb.h" +#include "test/fuzz/fuzz_runner.h" + +namespace Envoy { + +// Fuzz the header map implementation. +DEFINE_PROTO_FUZZER(const test::common::http::HeaderMapImplFuzzTestCase& input) { + Http::HeaderMapImplPtr header_map = std::make_unique(); + const auto predefined_exists = [&header_map](const std::string& s) -> bool { + const Http::HeaderEntry* entry; + return header_map->lookup(Http::LowerCaseString(s), &entry) == Http::HeaderMap::Lookup::Found; + }; + std::vector> lower_case_strings; + std::vector> strings; + for (int i = 0; i < input.actions().size(); ++i) { + const auto& action = input.actions(i); + ENVOY_LOG_MISC(debug, "Action {}", action.DebugString()); + switch (action.action_selector_case()) { + case test::common::http::Action::kAddReference: { + const auto& add_reference = action.add_reference(); + // Workaround for https://github.com/envoyproxy/envoy/issues/3919. + if (predefined_exists(add_reference.key())) { + continue; + } + lower_case_strings.emplace_back(std::make_unique(add_reference.key())); + strings.emplace_back(std::make_unique(add_reference.value())); + header_map->addReference(*lower_case_strings.back(), *strings.back()); + break; + } + case test::common::http::Action::kAddReferenceKey: { + const auto& add_reference_key = action.add_reference_key(); + // Workaround for https://github.com/envoyproxy/envoy/issues/3919. + if (predefined_exists(add_reference_key.key())) { + continue; + } + lower_case_strings.emplace_back( + std::make_unique(add_reference_key.key())); + switch (add_reference_key.value_selector_case()) { + case test::common::http::AddReferenceKey::kStringValue: + header_map->addReferenceKey(*lower_case_strings.back(), add_reference_key.string_value()); + break; + case test::common::http::AddReferenceKey::kUint64Value: + header_map->addReferenceKey(*lower_case_strings.back(), add_reference_key.uint64_value()); + break; + default: + break; + } + break; + } + case test::common::http::Action::kAddCopy: { + const auto& add_copy = action.add_copy(); + // Workaround for https://github.com/envoyproxy/envoy/issues/3919. + if (predefined_exists(add_copy.key())) { + continue; + } + const Http::LowerCaseString key{add_copy.key()}; + switch (add_copy.value_selector_case()) { + case test::common::http::AddCopy::kStringValue: + header_map->addCopy(key, add_copy.string_value()); + break; + case test::common::http::AddCopy::kUint64Value: + header_map->addCopy(key, add_copy.uint64_value()); + break; + default: + break; + } + break; + } + case test::common::http::Action::kSetReference: { + const auto& set_reference = action.set_reference(); + lower_case_strings.emplace_back(std::make_unique(set_reference.key())); + strings.emplace_back(std::make_unique(set_reference.value())); + header_map->setReference(*lower_case_strings.back(), *strings.back()); + break; + } + case test::common::http::Action::kSetReferenceKey: { + const auto& set_reference_key = action.set_reference_key(); + lower_case_strings.emplace_back( + std::make_unique(set_reference_key.key())); + header_map->setReferenceKey(*lower_case_strings.back(), set_reference_key.value()); + break; + } + case test::common::http::Action::kGetAndMutate: { + const auto& get_and_mutate = action.get_and_mutate(); + auto* header_entry = header_map->get(Http::LowerCaseString(get_and_mutate.key())); + if (header_entry != nullptr) { + // Do some read-only stuff. + (void)strlen(header_entry->key().c_str()); + (void)strlen(header_entry->value().c_str()); + (void)strlen(header_entry->value().buffer()); + header_entry->key().empty(); + header_entry->value().empty(); + // Do some mutation or parameterized action. + switch (get_and_mutate.mutate_selector_case()) { + case test::common::http::GetAndMutate::kAppend: + header_entry->value().append(get_and_mutate.append().c_str(), + get_and_mutate.append().size()); + break; + case test::common::http::GetAndMutate::kClear: + header_entry->value().clear(); + break; + case test::common::http::GetAndMutate::kFind: + header_entry->value().find(get_and_mutate.find().c_str()); + break; + case test::common::http::GetAndMutate::kSetCopy: + header_entry->value().setCopy(get_and_mutate.set_copy().c_str(), + get_and_mutate.set_copy().size()); + break; + case test::common::http::GetAndMutate::kSetInteger: + header_entry->value().setInteger(get_and_mutate.set_integer()); + break; + case test::common::http::GetAndMutate::kSetReference: + strings.emplace_back(std::make_unique(get_and_mutate.set_reference())); + header_entry->value().setReference(*strings.back()); + break; + default: + break; + } + } + break; + } + case test::common::http::Action::kCopy: { + header_map = std::make_unique( + *reinterpret_cast(header_map.get())); + break; + } + case test::common::http::Action::kLookup: { + const Http::HeaderEntry* header_entry; + header_map->lookup(Http::LowerCaseString(action.lookup()), &header_entry); + break; + } + case test::common::http::Action::kRemove: { + header_map->remove(Http::LowerCaseString(action.remove())); + break; + } + case test::common::http::Action::kRemovePrefix: { + header_map->removePrefix(Http::LowerCaseString(action.remove_prefix())); + break; + } + default: + // Maybe nothing is set? + break; + } + // Exercise some read-only accessors. + header_map->byteSize(); + header_map->size(); + header_map->iterate( + [](const Http::HeaderEntry& header, void * /*context*/) -> Http::HeaderMap::Iterate { + header.key(); + header.value(); + return Http::HeaderMap::Iterate::Continue; + }, + nullptr); + header_map->iterateReverse( + [](const Http::HeaderEntry& header, void * /*context*/) -> Http::HeaderMap::Iterate { + header.key(); + header.value(); + return Http::HeaderMap::Iterate::Continue; + }, + nullptr); + } +} + +} // namespace Envoy diff --git a/test/common/http/header_map_impl_test.cc b/test/common/http/header_map_impl_test.cc index 141e2ff4ffee4..5b0ce1a7092ad 100644 --- a/test/common/http/header_map_impl_test.cc +++ b/test/common/http/header_map_impl_test.cc @@ -261,31 +261,6 @@ TEST(HeaderStringTest, All) { EXPECT_EQ(HeaderString::Type::Reference, string.type()); } - // caseInsensitiveContains - { - const std::string static_string("keep-alive, Upgrade, close"); - HeaderString string(static_string); - EXPECT_TRUE(string.caseInsensitiveContains("keep-alive")); - EXPECT_TRUE(string.caseInsensitiveContains("Keep-alive")); - EXPECT_TRUE(string.caseInsensitiveContains("Upgrade")); - EXPECT_TRUE(string.caseInsensitiveContains("upgrade")); - EXPECT_TRUE(string.caseInsensitiveContains("close")); - EXPECT_TRUE(string.caseInsensitiveContains("Close")); - EXPECT_FALSE(string.caseInsensitiveContains("")); - EXPECT_FALSE(string.caseInsensitiveContains("keep")); - EXPECT_FALSE(string.caseInsensitiveContains("alive")); - EXPECT_FALSE(string.caseInsensitiveContains("grade")); - - const std::string small("close"); - string.setCopy(small.c_str(), small.size()); - EXPECT_FALSE(string.caseInsensitiveContains("keep-alive")); - - const std::string empty(""); - string.setCopy(empty.c_str(), empty.size()); - EXPECT_FALSE(string.caseInsensitiveContains("keep-alive")); - EXPECT_FALSE(string.caseInsensitiveContains("")); - } - // getString { std::string static_string("HELLO"); diff --git a/test/common/http/header_utility_test.cc b/test/common/http/header_utility_test.cc index 6102a1a44903c..b94280d2d41af 100644 --- a/test/common/http/header_utility_test.cc +++ b/test/common/http/header_utility_test.cc @@ -43,35 +43,6 @@ name: test-header EXPECT_EQ(HeaderUtility::HeaderMatchType::Present, header_data.header_match_type_); } -TEST(HeaderDataConstructorTest, ValueSet) { - const std::string yaml = R"EOF( -name: test-header -value: value - )EOF"; - - HeaderUtility::HeaderData header_data = - HeaderUtility::HeaderData(parseHeaderMatcherFromYaml(yaml)); - - EXPECT_EQ("test-header", header_data.name_.get()); - EXPECT_EQ(HeaderUtility::HeaderMatchType::Value, header_data.header_match_type_); - EXPECT_EQ("value", header_data.value_); -} - -TEST(HeaderDataConstructorTest, ValueAndRegexFlagSet) { - const std::string yaml = R"EOF( -name: test-header -value: value -regex: true - )EOF"; - - HeaderUtility::HeaderData header_data = - HeaderUtility::HeaderData(parseHeaderMatcherFromYaml(yaml)); - - EXPECT_EQ("test-header", header_data.name_.get()); - EXPECT_EQ(HeaderUtility::HeaderMatchType::Regex, header_data.header_match_type_); - EXPECT_EQ("", header_data.value_); -} - TEST(HeaderDataConstructorTest, ExactMatchSpecifier) { const std::string yaml = R"EOF( name: test-header diff --git a/test/common/http/http1/codec_impl_test.cc b/test/common/http/http1/codec_impl_test.cc index fe2dfb71d2da2..2427f9c4f4a29 100644 --- a/test/common/http/http1/codec_impl_test.cc +++ b/test/common/http/http1/codec_impl_test.cc @@ -423,6 +423,29 @@ TEST_F(Http1ServerConnectionImplTest, HeadRequestResponse) { EXPECT_EQ("HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\n", output); } +TEST_F(Http1ServerConnectionImplTest, HeadChunkedRequestResponse) { + initialize(); + + NiceMock decoder; + Http::StreamEncoder* response_encoder = nullptr; + EXPECT_CALL(callbacks_, newStream(_)) + .WillOnce(Invoke([&](Http::StreamEncoder& encoder) -> Http::StreamDecoder& { + response_encoder = &encoder; + return decoder; + })); + + Buffer::OwnedImpl buffer("HEAD / HTTP/1.1\r\n\r\n"); + codec_->dispatch(buffer); + EXPECT_EQ(0U, buffer.length()); + + std::string output; + ON_CALL(connection_, write(_, _)).WillByDefault(AddBufferToString(&output)); + + TestHeaderMapImpl headers{{":status", "200"}}; + response_encoder->encodeHeaders(headers, true); + EXPECT_EQ("HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\n\r\n", output); +} + TEST_F(Http1ServerConnectionImplTest, DoubleRequest) { initialize(); @@ -465,6 +488,78 @@ TEST_F(Http1ServerConnectionImplTest, RequestWithTrailers) { EXPECT_EQ(0U, buffer.length()); } +TEST_F(Http1ServerConnectionImplTest, UpgradeRequest) { + initialize(); + + InSequence sequence; + NiceMock decoder; + EXPECT_CALL(callbacks_, newStream(_)).WillOnce(ReturnRef(decoder)); + + EXPECT_CALL(decoder, decodeHeaders_(_, false)).Times(1); + Buffer::OwnedImpl buffer( + "POST / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: foo\r\ncontent-length:5\r\n\r\n"); + codec_->dispatch(buffer); + + Buffer::OwnedImpl expected_data1("12345"); + Buffer::OwnedImpl body("12345"); + EXPECT_CALL(decoder, decodeData(BufferEqual(&expected_data1), false)).Times(1); + codec_->dispatch(body); + + Buffer::OwnedImpl expected_data2("abcd"); + Buffer::OwnedImpl websocket_payload("abcd"); + EXPECT_CALL(decoder, decodeData(BufferEqual(&expected_data2), false)).Times(1); + codec_->dispatch(websocket_payload); +} + +TEST_F(Http1ServerConnectionImplTest, UpgradeRequestWithEarlyData) { + initialize(); + + InSequence sequence; + NiceMock decoder; + EXPECT_CALL(callbacks_, newStream(_)).WillOnce(ReturnRef(decoder)); + + Buffer::OwnedImpl expected_data("12345abcd"); + EXPECT_CALL(decoder, decodeHeaders_(_, false)).Times(1); + EXPECT_CALL(decoder, decodeData(BufferEqual(&expected_data), false)).Times(1); + Buffer::OwnedImpl buffer("POST / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: " + "foo\r\ncontent-length:5\r\n\r\n12345abcd"); + codec_->dispatch(buffer); +} + +TEST_F(Http1ServerConnectionImplTest, UpgradeRequestWithTEChunked) { + initialize(); + + InSequence sequence; + NiceMock decoder; + EXPECT_CALL(callbacks_, newStream(_)).WillOnce(ReturnRef(decoder)); + + // Even with T-E chunked, the data should neither be inspected for (the not + // present in this unit test) chunks, but simply passed through. + Buffer::OwnedImpl expected_data("12345abcd"); + EXPECT_CALL(decoder, decodeHeaders_(_, false)).Times(1); + EXPECT_CALL(decoder, decodeData(BufferEqual(&expected_data), false)).Times(1); + Buffer::OwnedImpl buffer("POST / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: " + "foo\r\ntransfer-encoding: chunked\r\n\r\n12345abcd"); + codec_->dispatch(buffer); +} + +TEST_F(Http1ServerConnectionImplTest, UpgradeRequestWithNoBody) { + initialize(); + + InSequence sequence; + NiceMock decoder; + EXPECT_CALL(callbacks_, newStream(_)).WillOnce(ReturnRef(decoder)); + + // Make sure we avoid the deferred_end_stream_headers_ optimization for + // requests-with-no-body. + Buffer::OwnedImpl expected_data("abcd"); + EXPECT_CALL(decoder, decodeHeaders_(_, false)).Times(1); + EXPECT_CALL(decoder, decodeData(BufferEqual(&expected_data), false)).Times(1); + Buffer::OwnedImpl buffer( + "GET / HTTP/1.1\r\nConnection: upgrade\r\nUpgrade: foo\r\ncontent-length: 0\r\n\r\nabcd"); + codec_->dispatch(buffer); +} + TEST_F(Http1ServerConnectionImplTest, WatermarkTest) { EXPECT_CALL(connection_, bufferLimit()).Times(1).WillOnce(Return(10)); initialize(); @@ -693,6 +788,56 @@ TEST_F(Http1ClientConnectionImplTest, GiantPath) { codec_->dispatch(response); } +TEST_F(Http1ClientConnectionImplTest, UpgradeResponse) { + initialize(); + + InSequence s; + + NiceMock response_decoder; + Http::StreamEncoder& request_encoder = codec_->newStream(response_decoder); + TestHeaderMapImpl headers{{":method", "GET"}, {":path", "/"}, {":authority", "host"}}; + request_encoder.encodeHeaders(headers, true); + + // Send upgrade headers + EXPECT_CALL(response_decoder, decodeHeaders_(_, false)); + Buffer::OwnedImpl response( + "HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: upgrade\r\nUpgrade: websocket\r\n\r\n"); + codec_->dispatch(response); + + // Send body payload + Buffer::OwnedImpl expected_data1("12345"); + Buffer::OwnedImpl body("12345"); + EXPECT_CALL(response_decoder, decodeData(BufferEqual(&expected_data1), false)).Times(1); + codec_->dispatch(body); + + // Send websocket payload + Buffer::OwnedImpl expected_data2("abcd"); + Buffer::OwnedImpl websocket_payload("abcd"); + EXPECT_CALL(response_decoder, decodeData(BufferEqual(&expected_data2), false)).Times(1); + codec_->dispatch(websocket_payload); +} + +// Same data as above, but make sure directDispatch immediately hands off any +// outstanding data. +TEST_F(Http1ClientConnectionImplTest, UpgradeResponseWithEarlyData) { + initialize(); + + InSequence s; + + NiceMock response_decoder; + Http::StreamEncoder& request_encoder = codec_->newStream(response_decoder); + TestHeaderMapImpl headers{{":method", "GET"}, {":path", "/"}, {":authority", "host"}}; + request_encoder.encodeHeaders(headers, true); + + // Send upgrade headers + EXPECT_CALL(response_decoder, decodeHeaders_(_, false)); + Buffer::OwnedImpl expected_data("12345abcd"); + EXPECT_CALL(response_decoder, decodeData(BufferEqual(&expected_data), false)).Times(1); + Buffer::OwnedImpl response("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nConnection: " + "upgrade\r\nUpgrade: websocket\r\n\r\n12345abcd"); + codec_->dispatch(response); +} + TEST_F(Http1ClientConnectionImplTest, WatermarkTest) { EXPECT_CALL(connection_, bufferLimit()).Times(1).WillOnce(Return(10)); initialize(); diff --git a/test/common/http/http2/codec_impl_test.cc b/test/common/http/http2/codec_impl_test.cc index 47d6993129e93..2894363830efb 100644 --- a/test/common/http/http2/codec_impl_test.cc +++ b/test/common/http/http2/codec_impl_test.cc @@ -665,6 +665,30 @@ TEST_P(Http2CodecImplTest, WatermarkUnderEndStream) { response_encoder_->encodeHeaders(response_headers, true); } +class Http2CodecImplStreamLimitTest : public Http2CodecImplTest {}; + +// Regression test for issue #3076. +// +// TODO(PiotrSikora): add tests that exercise both scenarios: before and after receiving +// the HTTP/2 SETTINGS frame. +TEST_P(Http2CodecImplStreamLimitTest, MaxClientStreams) { + for (int i = 0; i < 101; ++i) { + request_encoder_ = &client_.newStream(response_decoder_); + setupDefaultConnectionMocks(); + EXPECT_CALL(server_callbacks_, newStream(_)) + .WillOnce(Invoke([&](StreamEncoder& encoder) -> StreamDecoder& { + response_encoder_ = &encoder; + encoder.getStream().addCallbacks(server_stream_callbacks_); + return request_decoder_; + })); + + TestHeaderMapImpl request_headers; + HttpTestUtility::addDefaultHeaders(request_headers); + EXPECT_CALL(request_decoder_, decodeHeaders_(_, true)); + request_encoder_->encodeHeaders(request_headers, true); + } +} + #define HTTP2SETTINGS_SMALL_WINDOW_COMBINE \ ::testing::Combine(::testing::Values(Http2Settings::DEFAULT_HPACK_TABLE_SIZE), \ ::testing::Values(Http2Settings::DEFAULT_MAX_CONCURRENT_STREAMS), \ @@ -688,6 +712,12 @@ INSTANTIATE_TEST_CASE_P(Http2CodecImplFlowControlTest, Http2CodecImplFlowControl ::testing::Values(Http2Settings::DEFAULT_INITIAL_STREAM_WINDOW_SIZE), \ ::testing::Values(Http2Settings::DEFAULT_INITIAL_CONNECTION_WINDOW_SIZE)) +// Stream limit test only uses the default values because not all combinations of +// edge settings allow for the number of streams needed by the test. +INSTANTIATE_TEST_CASE_P(Http2CodecImplStreamLimitTest, Http2CodecImplStreamLimitTest, + ::testing::Combine(HTTP2SETTINGS_DEFAULT_COMBINE, + HTTP2SETTINGS_DEFAULT_COMBINE)); + INSTANTIATE_TEST_CASE_P(Http2CodecImplTestDefaultSettings, Http2CodecImplTest, ::testing::Combine(HTTP2SETTINGS_DEFAULT_COMBINE, HTTP2SETTINGS_DEFAULT_COMBINE)); @@ -753,6 +783,34 @@ TEST_P(Http2CodecImplTest, TestCodecHeaderLimits) { request_encoder_->encodeHeaders(request_headers, false); } +TEST_P(Http2CodecImplTest, TestCodecHeaderCompression) { + initialize(); + + TestHeaderMapImpl request_headers; + HttpTestUtility::addDefaultHeaders(request_headers); + EXPECT_CALL(request_decoder_, decodeHeaders_(_, true)); + request_encoder_->encodeHeaders(request_headers, true); + + TestHeaderMapImpl response_headers{{":status", "200"}, {"compression", "test"}}; + EXPECT_CALL(response_decoder_, decodeHeaders_(_, true)); + response_encoder_->encodeHeaders(response_headers, true); + + // Sanity check to verify that state of encoders and decoders matches. + EXPECT_EQ(nghttp2_session_get_hd_deflate_dynamic_table_size(server_.session()), + nghttp2_session_get_hd_inflate_dynamic_table_size(client_.session())); + EXPECT_EQ(nghttp2_session_get_hd_deflate_dynamic_table_size(client_.session()), + nghttp2_session_get_hd_inflate_dynamic_table_size(server_.session())); + + // Verify that headers are compressed only when both client and server advertise table size > 0: + if (client_http2settings_.hpack_table_size_ && server_http2settings_.hpack_table_size_) { + EXPECT_NE(0, nghttp2_session_get_hd_deflate_dynamic_table_size(client_.session())); + EXPECT_NE(0, nghttp2_session_get_hd_deflate_dynamic_table_size(server_.session())); + } else { + EXPECT_EQ(0, nghttp2_session_get_hd_deflate_dynamic_table_size(client_.session())); + EXPECT_EQ(0, nghttp2_session_get_hd_deflate_dynamic_table_size(server_.session())); + } +} + } // namespace Http2 } // namespace Http } // namespace Envoy diff --git a/test/common/http/utility_test.cc b/test/common/http/utility_test.cc index cd04127fbd27e..c50be78887abe 100644 --- a/test/common/http/utility_test.cc +++ b/test/common/http/utility_test.cc @@ -46,6 +46,8 @@ TEST(HttpUtility, isWebSocketUpgradeRequest) { EXPECT_FALSE(Utility::isWebSocketUpgradeRequest(TestHeaderMapImpl{{"upgrade", "websocket"}})); EXPECT_FALSE(Utility::isWebSocketUpgradeRequest( TestHeaderMapImpl{{"Connection", "close"}, {"Upgrade", "websocket"}})); + EXPECT_FALSE(Utility::isUpgrade( + TestHeaderMapImpl{{"Connection", "IsNotAnUpgrade"}, {"Upgrade", "websocket"}})); EXPECT_TRUE(Utility::isWebSocketUpgradeRequest( TestHeaderMapImpl{{"Connection", "upgrade"}, {"Upgrade", "websocket"}})); @@ -55,6 +57,23 @@ TEST(HttpUtility, isWebSocketUpgradeRequest) { TestHeaderMapImpl{{"connection", "Upgrade"}, {"upgrade", "WebSocket"}})); } +TEST(HttpUtility, isUpgrade) { + EXPECT_FALSE(Utility::isUpgrade(TestHeaderMapImpl{})); + EXPECT_FALSE(Utility::isUpgrade(TestHeaderMapImpl{{"connection", "upgrade"}})); + EXPECT_FALSE(Utility::isUpgrade(TestHeaderMapImpl{{"upgrade", "foo"}})); + EXPECT_FALSE(Utility::isUpgrade(TestHeaderMapImpl{{"Connection", "close"}, {"Upgrade", "foo"}})); + EXPECT_FALSE( + Utility::isUpgrade(TestHeaderMapImpl{{"Connection", "IsNotAnUpgrade"}, {"Upgrade", "foo"}})); + EXPECT_FALSE(Utility::isUpgrade( + TestHeaderMapImpl{{"Connection", "Is Not An Upgrade"}, {"Upgrade", "foo"}})); + + EXPECT_TRUE(Utility::isUpgrade(TestHeaderMapImpl{{"Connection", "upgrade"}, {"Upgrade", "foo"}})); + EXPECT_TRUE(Utility::isUpgrade(TestHeaderMapImpl{{"connection", "upgrade"}, {"upgrade", "foo"}})); + EXPECT_TRUE(Utility::isUpgrade(TestHeaderMapImpl{{"connection", "Upgrade"}, {"upgrade", "FoO"}})); + EXPECT_TRUE(Utility::isUpgrade( + TestHeaderMapImpl{{"connection", "keep-alive, Upgrade"}, {"upgrade", "FOO"}})); +} + TEST(HttpUtility, appendXff) { { TestHeaderMapImpl headers; diff --git a/test/common/network/addr_family_aware_socket_option_impl_test.cc b/test/common/network/addr_family_aware_socket_option_impl_test.cc index 3c6c16a0d9c9f..90f2334b7944e 100644 --- a/test/common/network/addr_family_aware_socket_option_impl_test.cc +++ b/test/common/network/addr_family_aware_socket_option_impl_test.cc @@ -49,6 +49,8 @@ TEST_F(AddrFamilyAwareSocketOptionImplTest, V4EmptyOptionNames) { // If a platform doesn't support IPv4 and IPv6 socket option variants for an IPv4 address, we fail TEST_F(AddrFamilyAwareSocketOptionImplTest, V6EmptyOptionNames) { + EXPECT_CALL(os_sys_calls_, socket(_, _, _)); + EXPECT_CALL(os_sys_calls_, close(_)); Address::Ipv6Instance address("::1:2:3:4", 5678); const int fd = address.socket(Address::SocketType::Stream); EXPECT_CALL(socket_, fd()).WillRepeatedly(Return(fd)); diff --git a/test/common/network/address_impl_test.cc b/test/common/network/address_impl_test.cc index 338ad1fe3e185..839f4437699cd 100644 --- a/test/common/network/address_impl_test.cc +++ b/test/common/network/address_impl_test.cc @@ -64,9 +64,9 @@ void testSocketBindAndConnect(Network::Address::IpVersion ip_version, bool v6onl } // Bind the socket to the desired address and port. - const int rc = addr_port->bind(listen_fd); - const int err = errno; - ASSERT_EQ(rc, 0) << addr_port->asString() << "\nerror: " << strerror(err) << "\nerrno: " << err; + const Api::SysCallResult result = addr_port->bind(listen_fd); + ASSERT_EQ(result.rc_, 0) << addr_port->asString() << "\nerror: " << strerror(result.errno_) + << "\nerrno: " << result.errno_; // Do a bare listen syscall. Not bothering to accept connections as that would // require another thread. @@ -85,9 +85,9 @@ void testSocketBindAndConnect(Network::Address::IpVersion ip_version, bool v6onl makeFdBlocking(client_fd); // Connect to the server. - const int rc = addr_port->connect(client_fd); - const int err = errno; - ASSERT_EQ(rc, 0) << addr_port->asString() << "\nerror: " << strerror(err) << "\nerrno: " << err; + const Api::SysCallResult result = addr_port->connect(client_fd); + ASSERT_EQ(result.rc_, 0) << addr_port->asString() << "\nerror: " << strerror(result.errno_) + << "\nerrno: " << result.errno_; }; client_connect(addr_port); @@ -314,9 +314,9 @@ TEST(PipeInstanceTest, UnlinksExistingFile) { ASSERT_GE(listen_fd, 0) << address.asString(); ScopedFdCloser closer(listen_fd); - const int rc = address.bind(listen_fd); - const int err = errno; - ASSERT_EQ(rc, 0) << address.asString() << "\nerror: " << strerror(err) << "\nerrno: " << err; + const Api::SysCallResult result = address.bind(listen_fd); + ASSERT_EQ(result.rc_, 0) << address.asString() << "\nerror: " << strerror(result.errno_) + << "\nerrno: " << result.errno_; }; const std::string path = TestEnvironment::unixDomainSocketPath("UnlinksExistingFile.sock"); diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index d19c552b91e61..a7c4f022a3ea2 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -678,7 +678,7 @@ TEST_P(ConnectionImplTest, WriteWithWatermarks) { EXPECT_CALL(*client_write_buffer_, move(_)) .WillRepeatedly(DoAll(AddBufferToStringWithoutDraining(&data_written), Invoke(client_write_buffer_, &MockWatermarkBuffer::baseMove))); - EXPECT_CALL(*client_write_buffer_, write(_)).WillOnce(Invoke([&](int fd) -> int { + EXPECT_CALL(*client_write_buffer_, write(_)).WillOnce(Invoke([&](int fd) -> Api::SysCallResult { dispatcher_->exit(); return client_write_buffer_->failWrite(fd); })); @@ -764,7 +764,7 @@ TEST_P(ConnectionImplTest, WatermarkFuzzing) { .WillOnce(Invoke(client_write_buffer_, &MockWatermarkBuffer::baseMove)); EXPECT_CALL(*client_write_buffer_, write(_)) .WillOnce(DoAll(Invoke([&](int) -> void { client_write_buffer_->drain(bytes_to_flush); }), - Return(bytes_to_flush))) + Return(Api::SysCallResult{bytes_to_flush, 0}))) .WillRepeatedly(testing::Invoke(client_write_buffer_, &MockWatermarkBuffer::failWrite)); client_connection_->write(buffer_to_write, false); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); diff --git a/test/common/network/dns_impl_test.cc b/test/common/network/dns_impl_test.cc index f275a86fa733c..d8de2de5b253d 100644 --- a/test/common/network/dns_impl_test.cc +++ b/test/common/network/dns_impl_test.cc @@ -359,8 +359,8 @@ class CustomInstance : public Address::Instance { } const std::string& asString() const override { return antagonistic_name_; } const std::string& logicalName() const override { return antagonistic_name_; } - int bind(int fd) const override { return instance_.bind(fd); } - int connect(int fd) const override { return instance_.connect(fd); } + Api::SysCallResult bind(int fd) const override { return instance_.bind(fd); } + Api::SysCallResult connect(int fd) const override { return instance_.connect(fd); } const Address::Ip* ip() const override { return instance_.ip(); } int socket(Address::SocketType type) const override { return instance_.socket(type); } Address::Type type() const override { return instance_.type(); } diff --git a/test/common/network/filter_manager_impl_test.cc b/test/common/network/filter_manager_impl_test.cc index 4768f6dd4e176..4bd2dbc5c4acf 100644 --- a/test/common/network/filter_manager_impl_test.cc +++ b/test/common/network/filter_manager_impl_test.cc @@ -154,6 +154,8 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { InSequence s; NiceMock factory_context; NiceMock connection; + NiceMock upstream_connection; + NiceMock conn_pool; FilterManagerImpl manager(connection, *this); std::string rl_json = R"EOF( @@ -202,21 +204,15 @@ TEST_F(NetworkFilterManagerTest, RateLimitAndTcpProxy) { EXPECT_EQ(manager.initializeReadFilters(), true); - NiceMock* upstream_connection = - new NiceMock(); - Upstream::MockHost::MockCreateConnectionData conn_info; - conn_info.connection_ = upstream_connection; - conn_info.host_description_ = Upstream::makeTestHost( - factory_context.cluster_manager_.thread_local_cluster_.cluster_.info_, "tcp://127.0.0.1:80"); - EXPECT_CALL(factory_context.cluster_manager_, tcpConnForCluster_("fake_cluster", _)) - .WillOnce(Return(conn_info)); + EXPECT_CALL(factory_context.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool)); request_callbacks->complete(RateLimit::LimitStatus::OK); - upstream_connection->raiseEvent(Network::ConnectionEvent::Connected); + conn_pool.poolReady(upstream_connection); Buffer::OwnedImpl buffer("hello"); - EXPECT_CALL(*upstream_connection, write(BufferEqual(&buffer), _)); + EXPECT_CALL(upstream_connection, write(BufferEqual(&buffer), _)); read_buffer_.add("hello"); manager.onRead(); } diff --git a/test/common/network/lc_trie_speed_test.cc b/test/common/network/lc_trie_speed_test.cc index 4fa2ac0c13b67..e859a56bc3ac0 100644 --- a/test/common/network/lc_trie_speed_test.cc +++ b/test/common/network/lc_trie_speed_test.cc @@ -7,9 +7,9 @@ namespace { std::vector addresses; -std::unique_ptr lc_trie; +std::unique_ptr> lc_trie; -std::unique_ptr lc_trie_nested_prefixes; +std::unique_ptr> lc_trie_nested_prefixes; } // namespace @@ -21,7 +21,7 @@ static void BM_LcTrieLookup(benchmark::State& state) { for (auto _ : state) { i++; i %= addresses.size(); - output_tags += lc_trie->getTags(addresses[i]).size(); + output_tags += lc_trie->getData(addresses[i]).size(); } benchmark::DoNotOptimize(output_tags); } @@ -34,7 +34,7 @@ static void BM_LcTrieLookupWithNestedPrefixes(benchmark::State& state) { for (auto _ : state) { i++; i %= addresses.size(); - output_tags += lc_trie_nested_prefixes->getTags(addresses[i]).size(); + output_tags += lc_trie_nested_prefixes->getData(addresses[i]).size(); } benchmark::DoNotOptimize(output_tags); } @@ -63,10 +63,10 @@ int main(int argc, char** argv) { } } - lc_trie = std::make_unique(tag_data); + lc_trie = std::make_unique>(tag_data); tag_data.emplace_back(std::pair>( {"tag_0", {Envoy::Network::Address::CidrRange::create("0.0.0.0/0")}})); - lc_trie_nested_prefixes = std::make_unique(tag_data); + lc_trie_nested_prefixes = std::make_unique>(tag_data); benchmark::Initialize(&argc, argv); if (benchmark::ReportUnrecognizedArguments(argc, argv)) { diff --git a/test/common/network/lc_trie_test.cc b/test/common/network/lc_trie_test.cc index a1073ccbc5c9c..a407e7832214f 100644 --- a/test/common/network/lc_trie_test.cc +++ b/test/common/network/lc_trie_test.cc @@ -14,7 +14,7 @@ namespace LcTrie { class LcTrieTest : public testing::Test { public: void setup(const std::vector>& cidr_range_strings, - double fill_factor = 0, uint32_t root_branch_factor = 0) { + bool exclusive = false, double fill_factor = 0, uint32_t root_branch_factor = 0) { std::vector>> output; for (size_t i = 0; i < cidr_range_strings.size(); i++) { std::pair> ip_tags; @@ -26,9 +26,9 @@ class LcTrieTest : public testing::Test { } // Use custom fill factors and root branch factors if they are in the valid range. if ((fill_factor > 0) && (fill_factor <= 1) && (root_branch_factor > 0)) { - trie_.reset(new LcTrie(output, fill_factor, root_branch_factor)); + trie_.reset(new LcTrie(output, exclusive, fill_factor, root_branch_factor)); } else { - trie_.reset(new LcTrie(output)); + trie_.reset(new LcTrie(output, exclusive)); } } @@ -37,13 +37,13 @@ class LcTrieTest : public testing::Test { for (const auto& kv : test_output) { std::vector expected(kv.second); std::sort(expected.begin(), expected.end()); - std::vector actual(trie_->getTags(Utility::parseInternetAddress(kv.first))); + std::vector actual(trie_->getData(Utility::parseInternetAddress(kv.first))); std::sort(actual.begin(), actual.end()); EXPECT_EQ(expected, actual); } } - std::unique_ptr trie_; + std::unique_ptr> trie_; }; // Use the default constructor values. @@ -102,7 +102,7 @@ TEST_F(LcTrieTest, RootBranchingFactor) { {"232.0.0.0/8"}, // tag_13 {"233.0.0.0/8"}, // tag_14 }; - setup(cidr_range_strings, fill_factor, root_branching_factor); + setup(cidr_range_strings, false, fill_factor, root_branching_factor); std::vector>> test_case = { {"0.0.0.0", {"tag_0"}}, {"16.0.0.1", {"tag_1"}}, @@ -284,19 +284,41 @@ TEST_F(LcTrieTest, NestedPrefixesWithCatchAll) { {"::0/0"}, // tag_4 {"2001:db8::/96", "2001:db8::8000/97"}, // tag_5 {"2001:db8::ffff/128"}, // tag_6 - {"2001:db8:1::/48"} // tag_7 + {"2001:db8:1::/48"}, // tag_7 + {"203.0.113.0/24"} // tag_8 (same subnet as tag_1) }; setup(cidr_range_strings); std::vector>> test_case = { - {"203.0.113.0", {"tag_0", "tag_1"}}, - {"203.0.113.192", {"tag_0", "tag_1", "tag_2"}}, - {"203.0.113.255", {"tag_0", "tag_1", "tag_2"}}, + {"203.0.0.0", {"tag_0"}}, + {"203.0.113.0", {"tag_0", "tag_1", "tag_8"}}, + {"203.0.113.192", {"tag_0", "tag_1", "tag_2", "tag_8"}}, + {"203.0.113.255", {"tag_0", "tag_1", "tag_2", "tag_8"}}, {"198.51.100.1", {"tag_0", "tag_3"}}, {"2001:db8::ffff", {"tag_4", "tag_5", "tag_6"}}, - {"2001:db8:1::ffff", {"tag_4", "tag_7"}} + {"2001:db8:1::ffff", {"tag_4", "tag_7"}}}; + expectIPAndTags(test_case); +} +TEST_F(LcTrieTest, ExclusiveNestedPrefixesWithCatchAll) { + std::vector> cidr_range_strings = { + {"0.0.0.0/0"}, // tag_0 + {"203.0.113.0/24"}, // tag_1 + {"203.0.113.128/25"}, // tag_2 + {"198.51.100.0/24"}, // tag_3 + {"::0/0"}, // tag_4 + {"2001:db8::/96", "2001:db8::8000/97"}, // tag_5 + {"2001:db8::ffff/128"}, // tag_6 + {"2001:db8:1::/48"}, // tag_7 + {"203.0.113.0/24"} // tag_8 (same subnet as tag_1) }; + setup(cidr_range_strings, true); + + std::vector>> test_case = { + {"203.0.0.0", {"tag_0"}}, {"203.0.113.0", {"tag_1", "tag_8"}}, + {"203.0.113.192", {"tag_2"}}, {"203.0.113.255", {"tag_2"}}, + {"198.51.100.1", {"tag_3"}}, {"2001:db8::ffff", {"tag_6"}}, + {"2001:db8:1::ffff", {"tag_7"}}}; expectIPAndTags(test_case); } @@ -315,7 +337,7 @@ TEST_F(LcTrieTest, MaximumEntriesExceptionDefault) { std::pair> ip_tag = std::make_pair("bad_tag", prefixes); std::vector>> ip_tags_input{ip_tag}; - EXPECT_THROW_WITH_MESSAGE(new LcTrie(ip_tags_input), EnvoyException, + EXPECT_THROW_WITH_MESSAGE(new LcTrie(ip_tags_input), EnvoyException, "The input vector has '524288' CIDR range entries. " "LC-Trie can only support '262144' CIDR ranges with " "the specified fill factor."); @@ -339,7 +361,7 @@ TEST_F(LcTrieTest, MaximumEntriesExceptionOverride) { std::pair> ip_tag = std::make_pair("bad_tag", prefixes); std::vector>> ip_tags_input{ip_tag}; - EXPECT_THROW_WITH_MESSAGE(new LcTrie(ip_tags_input, 0.01), EnvoyException, + EXPECT_THROW_WITH_MESSAGE(new LcTrie(ip_tags_input, false, 0.01), EnvoyException, "The input vector has '8192' CIDR range entries. " "LC-Trie can only support '5242' CIDR ranges with " "the specified fill factor."); diff --git a/test/common/network/socket_option_impl_test.cc b/test/common/network/socket_option_impl_test.cc index bb29ad08a9ee7..9f804c3b18f0a 100644 --- a/test/common/network/socket_option_impl_test.cc +++ b/test/common/network/socket_option_impl_test.cc @@ -8,8 +8,9 @@ class SocketOptionImplTest : public SocketOptionTest {}; TEST_F(SocketOptionImplTest, BadFd) { absl::string_view zero("\0\0\0\0", 4); - EXPECT_EQ(-1, SocketOptionImpl::setSocketOption(socket_, {}, zero)); - EXPECT_EQ(ENOTSUP, errno); + Api::SysCallResult result = SocketOptionImpl::setSocketOption(socket_, {}, zero); + EXPECT_EQ(-1, result.rc_); + EXPECT_EQ(ENOTSUP, result.errno_); } TEST_F(SocketOptionImplTest, SetOptionSuccessTrue) { diff --git a/test/common/protobuf/utility_test.cc b/test/common/protobuf/utility_test.cc index 92dd351b53b0b..15dbfa1a55f50 100644 --- a/test/common/protobuf/utility_test.cc +++ b/test/common/protobuf/utility_test.cc @@ -13,6 +13,15 @@ namespace Envoy { +TEST(UtilityTest, convertPercentNaN) { + envoy::api::v2::Cluster::CommonLbConfig common_config_; + common_config_.mutable_healthy_panic_threshold()->set_value( + std::numeric_limits::quiet_NaN()); + EXPECT_THROW(PROTOBUF_PERCENT_TO_ROUNDED_INTEGER_OR_DEFAULT(common_config_, + healthy_panic_threshold, 100, 50), + EnvoyException); +} + TEST(UtilityTest, RepeatedPtrUtilDebugString) { Protobuf::RepeatedPtrField repeated; EXPECT_EQ("[]", RepeatedPtrUtil::debugString(repeated)); @@ -211,6 +220,12 @@ TEST(UtilityTest, JsonConvertSuccess) { EXPECT_EQ(42, dest_duration.seconds()); } +TEST(UtilityTest, JsonConvertUnknownFieldSuccess) { + const ProtobufWkt::Struct obj = MessageUtil::keyValueStruct("test_key", "test_value"); + envoy::config::bootstrap::v2::Bootstrap bootstrap; + EXPECT_NO_THROW(MessageUtil::jsonConvert(obj, bootstrap)); +} + TEST(UtilityTest, JsonConvertFail) { ProtobufWkt::Duration source_duration; source_duration.set_seconds(-281474976710656); diff --git a/test/common/request_info/request_info_impl_test.cc b/test/common/request_info/request_info_impl_test.cc index 1dd67ac3250c0..687a6264a3c86 100644 --- a/test/common/request_info/request_info_impl_test.cc +++ b/test/common/request_info/request_info_impl_test.cc @@ -96,14 +96,22 @@ TEST(RequestInfoImplTest, ResponseFlagTest) { RateLimited}; RequestInfoImpl request_info(Http::Protocol::Http2); + EXPECT_FALSE(request_info.hasAnyResponseFlag()); + EXPECT_FALSE(request_info.intersectResponseFlags(0)); for (ResponseFlag flag : responseFlags) { // Test cumulative setting of response flags. - EXPECT_FALSE(request_info.getResponseFlag(flag)) + EXPECT_FALSE(request_info.hasResponseFlag(flag)) << fmt::format("Flag: {} was already set", flag); request_info.setResponseFlag(flag); - EXPECT_TRUE(request_info.getResponseFlag(flag)) + EXPECT_TRUE(request_info.hasResponseFlag(flag)) << fmt::format("Flag: {} was expected to be set", flag); } + EXPECT_TRUE(request_info.hasAnyResponseFlag()); + + RequestInfoImpl request_info2(Http::Protocol::Http2); + request_info2.setResponseFlag(FailedLocalHealthCheck); + + EXPECT_TRUE(request_info2.intersectResponseFlags(FailedLocalHealthCheck)); } TEST(RequestInfoImplTest, MiscSettersAndGetters) { diff --git a/test/common/request_info/utility_test.cc b/test/common/request_info/utility_test.cc index 47dba21103991..e9058a1f1ed5a 100644 --- a/test/common/request_info/utility_test.cc +++ b/test/common/request_info/utility_test.cc @@ -34,14 +34,14 @@ TEST(ResponseFlagUtilsTest, toShortStringConversion) { for (const auto& test_case : expected) { NiceMock request_info; - ON_CALL(request_info, getResponseFlag(test_case.first)).WillByDefault(Return(true)); + ON_CALL(request_info, hasResponseFlag(test_case.first)).WillByDefault(Return(true)); EXPECT_EQ(test_case.second, ResponseFlagUtils::toShortString(request_info)); } // No flag is set. { NiceMock request_info; - ON_CALL(request_info, getResponseFlag(_)).WillByDefault(Return(false)); + ON_CALL(request_info, hasResponseFlag(_)).WillByDefault(Return(false)); EXPECT_EQ("-", ResponseFlagUtils::toShortString(request_info)); } @@ -49,14 +49,42 @@ TEST(ResponseFlagUtilsTest, toShortStringConversion) { // These are not real use cases, but are used to cover multiple response flags case. { NiceMock request_info; - ON_CALL(request_info, getResponseFlag(ResponseFlag::DelayInjected)).WillByDefault(Return(true)); - ON_CALL(request_info, getResponseFlag(ResponseFlag::FaultInjected)).WillByDefault(Return(true)); - ON_CALL(request_info, getResponseFlag(ResponseFlag::UpstreamRequestTimeout)) + ON_CALL(request_info, hasResponseFlag(ResponseFlag::DelayInjected)).WillByDefault(Return(true)); + ON_CALL(request_info, hasResponseFlag(ResponseFlag::FaultInjected)).WillByDefault(Return(true)); + ON_CALL(request_info, hasResponseFlag(ResponseFlag::UpstreamRequestTimeout)) .WillByDefault(Return(true)); EXPECT_EQ("UT,DI,FI", ResponseFlagUtils::toShortString(request_info)); } } +TEST(ResponseFlagsUtilsTest, toResponseFlagConversion) { + static_assert(ResponseFlag::LastFlag == 0x1000, "A flag has been added. Fix this code."); + + std::vector> expected = { + std::make_pair("LH", ResponseFlag::FailedLocalHealthCheck), + std::make_pair("UH", ResponseFlag::NoHealthyUpstream), + std::make_pair("UT", ResponseFlag::UpstreamRequestTimeout), + std::make_pair("LR", ResponseFlag::LocalReset), + std::make_pair("UR", ResponseFlag::UpstreamRemoteReset), + std::make_pair("UF", ResponseFlag::UpstreamConnectionFailure), + std::make_pair("UC", ResponseFlag::UpstreamConnectionTermination), + std::make_pair("UO", ResponseFlag::UpstreamOverflow), + std::make_pair("NR", ResponseFlag::NoRouteFound), + std::make_pair("DI", ResponseFlag::DelayInjected), + std::make_pair("FI", ResponseFlag::FaultInjected), + std::make_pair("RL", ResponseFlag::RateLimited), + std::make_pair("UAEX", ResponseFlag::UnauthorizedExternalService), + }; + + EXPECT_FALSE(ResponseFlagUtils::toResponseFlag("NonExistentFlag").has_value()); + + for (const auto& test_case : expected) { + absl::optional response_flag = ResponseFlagUtils::toResponseFlag(test_case.first); + EXPECT_TRUE(response_flag.has_value()); + EXPECT_EQ(test_case.second, response_flag.value()); + } +} + TEST(UtilityTest, formatDownstreamAddressNoPort) { EXPECT_EQ("1.2.3.4", Utility::formatDownstreamAddressNoPort(Network::Address::Ipv4Instance("1.2.3.4"))); diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index 6160e603efe1a..eed552b9bc55b 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -47,7 +47,9 @@ Http::TestHeaderMapImpl genHeaders(const std::string& host, const std::string& p envoy::api::v2::RouteConfiguration parseRouteConfigurationFromJson(const std::string& json_string) { envoy::api::v2::RouteConfiguration route_config; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Envoy::Config::RdsJson::translateRouteConfiguration(*json_object_ptr, route_config); + Stats::StatsOptionsImpl stats_options; + Envoy::Config::RdsJson::translateRouteConfiguration(*json_object_ptr, route_config, + stats_options); return route_config; } @@ -530,7 +532,7 @@ TEST(RouteMatcherTest, TestRoutesWithInvalidRegex) { EnvoyException, "Invalid regex '\\^/\\(\\+invalid\\)':"); } -// Validates behavior of request_headers_to_add at router, vhost, and route levels. +// Validates behavior of request_headers_to_add at router, vhost, and route action levels. TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { std::string json = R"EOF( { @@ -550,14 +552,14 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { "request_headers_to_add": [ {"key": "x-global-header1", "value": "route-override"}, {"key": "x-vhost-header1", "value": "route-override"}, - {"key": "x-route-header", "value": "route-new_endpoint"} + {"key": "x-route-action-header", "value": "route-new_endpoint"} ] }, { "path": "/", "cluster": "root_www2", "request_headers_to_add": [ - {"key": "x-route-header", "value": "route-allpath"} + {"key": "x-route-action-header", "value": "route-allpath"} ] }, { @@ -577,7 +579,7 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { "prefix": "/", "cluster": "www2_staging", "request_headers_to_add": [ - {"key": "x-route-header", "value": "route-allprefix"} + {"key": "x-route-action-header", "value": "route-allprefix"} ] } ] @@ -627,7 +629,7 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { route->finalizeRequestHeaders(headers, request_info, true); EXPECT_EQ("route-override", headers.get_("x-global-header1")); EXPECT_EQ("route-override", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-new_endpoint", headers.get_("x-route-header")); + EXPECT_EQ("route-new_endpoint", headers.get_("x-route-action-header")); } // Multiple routes can have same route-level headers with different values. @@ -637,7 +639,7 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { route->finalizeRequestHeaders(headers, request_info, true); EXPECT_EQ("vhost-override", headers.get_("x-global-header1")); EXPECT_EQ("vhost1-www2", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-allpath", headers.get_("x-route-header")); + EXPECT_EQ("route-allpath", headers.get_("x-route-action-header")); } // Multiple virtual hosts can have same virtual host level headers with different values. @@ -647,7 +649,7 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { route->finalizeRequestHeaders(headers, request_info, true); EXPECT_EQ("global1", headers.get_("x-global-header1")); EXPECT_EQ("vhost1-www2_staging", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-allprefix", headers.get_("x-route-header")); + EXPECT_EQ("route-allprefix", headers.get_("x-route-action-header")); } // Global headers. @@ -660,8 +662,8 @@ TEST(RouteMatcherTest, TestAddRemoveRequestHeaders) { } } -// Validates behavior of request_headers_to_add at router, vhost, and route levels when append -// is disabled. +// Validates behavior of request_headers_to_add at router, vhost, route, and route action levels +// when append is disabled. TEST(RouteMatcherTest, TestRequestHeadersToAddWithAppendFalse) { std::string yaml = R"EOF( name: foo @@ -679,20 +681,36 @@ name: foo append: false routes: - match: { prefix: "/endpoint" } + request_headers_to_add: + - header: + key: x-global-header + value: route-endpoint + append: false + - header: + key: x-vhost-header + value: route-endpoint + append: false + - header: + key: x-route-header + value: route-endpoint + append: false route: cluster: www2 request_headers_to_add: - header: key: x-global-header - value: route-endpoint + value: route-action-endpoint append: false - header: key: x-vhost-header - value: route-endpoint + value: route-action-endpoint append: false - header: key: x-route-header - value: route-endpoint + value: route-action-endpoint + - header: + key: x-route-action-header + value: route-action-endpoint append: false - match: { prefix: "/" } route: { cluster: www2 } @@ -712,7 +730,7 @@ name: foo // Request header manipulation testing. { - // Global and virtual host override route. + // Global and virtual host override route, route overrides route action. { Http::TestHeaderMapImpl headers = genHeaders("www.lyft.com", "/endpoint", "GET"); const RouteEntry* route = config.route(headers, 0)->routeEntry(); @@ -720,6 +738,7 @@ name: foo EXPECT_EQ("global", headers.get_("x-global-header")); EXPECT_EQ("vhost-www2", headers.get_("x-vhost-header")); EXPECT_EQ("route-endpoint", headers.get_("x-route-header")); + EXPECT_EQ("route-action-endpoint", headers.get_("x-route-action-header")); } // Global overrides virtual host. @@ -734,7 +753,7 @@ name: foo } // Validates behavior of response_headers_to_add and response_headers_to_remove at router, vhost, -// and route levels. +// route, and route action levels. TEST(RouteMatcherTest, TestAddRemoveResponseHeaders) { std::string yaml = R"EOF( name: foo @@ -751,6 +770,10 @@ name: foo response_headers_to_remove: ["x-vhost-remove"] routes: - match: { prefix: "/new_endpoint" } + response_headers_to_add: + - header: + key: x-route-header + value: route-override route: prefix_rewrite: "/api/new_endpoint" cluster: www2 @@ -762,14 +785,14 @@ name: foo key: x-vhost-header1 value: route-override - header: - key: x-route-header + key: x-route-action-header value: route-new_endpoint - match: { path: "/" } route: cluster: root_www2 response_headers_to_add: - header: - key: x-route-header + key: x-route-action-header value: route-allpath response_headers_to_remove: ["x-route-remove"] - match: { prefix: "/" } @@ -786,7 +809,7 @@ name: foo cluster: www2_staging response_headers_to_add: - header: - key: x-route-header + key: x-route-action-header value: route-allprefix - name: default domains: ["*"] @@ -815,7 +838,8 @@ response_headers_to_remove: ["x-global-remove"] route->finalizeResponseHeaders(headers, request_info); EXPECT_EQ("route-override", headers.get_("x-global-header1")); EXPECT_EQ("route-override", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-new_endpoint", headers.get_("x-route-header")); + EXPECT_EQ("route-new_endpoint", headers.get_("x-route-action-header")); + EXPECT_EQ("route-override", headers.get_("x-route-header")); } // Multiple routes can have same route-level headers with different values. @@ -826,7 +850,7 @@ response_headers_to_remove: ["x-global-remove"] route->finalizeResponseHeaders(headers, request_info); EXPECT_EQ("vhost-override", headers.get_("x-global-header1")); EXPECT_EQ("vhost1-www2", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-allpath", headers.get_("x-route-header")); + EXPECT_EQ("route-allpath", headers.get_("x-route-action-header")); } // Multiple virtual hosts can have same virtual host level headers with different values. @@ -837,7 +861,7 @@ response_headers_to_remove: ["x-global-remove"] route->finalizeResponseHeaders(headers, request_info); EXPECT_EQ("global1", headers.get_("x-global-header1")); EXPECT_EQ("vhost1-www2_staging", headers.get_("x-vhost-header1")); - EXPECT_EQ("route-allprefix", headers.get_("x-route-header")); + EXPECT_EQ("route-allprefix", headers.get_("x-route-action-header")); } // Global headers. @@ -1074,7 +1098,7 @@ TEST(RouteMatcherTest, InvalidHeaderMatchedRoutingConfig) { prefix: "/" headers: - name: test_header - value: "(+not a regex)" + exact_match: "(+not a regex)" route: { cluster: "local_service" } )EOF"; @@ -1087,8 +1111,7 @@ TEST(RouteMatcherTest, InvalidHeaderMatchedRoutingConfig) { prefix: "/" headers: - name: test_header - value: "(+invalid regex)" - regex: true + regex_match: "(+invalid regex)" route: { cluster: "local_service" } )EOF"; @@ -4372,9 +4395,70 @@ name: RegexNoMatch } } +TEST(RouteConfigurationV2, NoIdleTimeout) { + const std::string NoIdleTimeot = R"EOF( +name: NoIdleTimeout +virtual_hosts: + - name: regex + domains: [idle.lyft.com] + routes: + - match: { regex: "/regex"} + route: + cluster: some-cluster + )EOF"; + + NiceMock factory_context; + ConfigImpl config(parseRouteConfigurationFromV2Yaml(NoIdleTimeot), factory_context, true); + Http::TestHeaderMapImpl headers = genRedirectHeaders("idle.lyft.com", "/regex", true, false); + const RouteEntry* route_entry = config.route(headers, 0)->routeEntry(); + EXPECT_EQ(absl::nullopt, route_entry->idleTimeout()); +} + +TEST(RouteConfigurationV2, ZeroIdleTimeout) { + const std::string ZeroIdleTimeot = R"EOF( +name: ZeroIdleTimeout +virtual_hosts: + - name: regex + domains: [idle.lyft.com] + routes: + - match: { regex: "/regex"} + route: + cluster: some-cluster + idle_timeout: 0s + )EOF"; + + NiceMock factory_context; + ConfigImpl config(parseRouteConfigurationFromV2Yaml(ZeroIdleTimeot), factory_context, true); + Http::TestHeaderMapImpl headers = genRedirectHeaders("idle.lyft.com", "/regex", true, false); + const RouteEntry* route_entry = config.route(headers, 0)->routeEntry(); + EXPECT_EQ(0, route_entry->idleTimeout().value().count()); +} + +TEST(RouteConfigurationV2, ExplicitIdleTimeout) { + const std::string ExplicitIdleTimeot = R"EOF( +name: ExplicitIdleTimeout +virtual_hosts: + - name: regex + domains: [idle.lyft.com] + routes: + - match: { regex: "/regex"} + route: + cluster: some-cluster + idle_timeout: 7s + )EOF"; + + NiceMock factory_context; + ConfigImpl config(parseRouteConfigurationFromV2Yaml(ExplicitIdleTimeot), factory_context, true); + Http::TestHeaderMapImpl headers = genRedirectHeaders("idle.lyft.com", "/regex", true, false); + const RouteEntry* route_entry = config.route(headers, 0)->routeEntry(); + EXPECT_EQ(7 * 1000, route_entry->idleTimeout().value().count()); +} + class PerFilterConfigsTest : public testing::Test { public: - PerFilterConfigsTest() : factory_(), registered_factory_(factory_) {} + PerFilterConfigsTest() + : factory_(), registered_factory_(factory_), default_factory_(), + registered_default_factory_(default_factory_) {} struct DerivedFilterConfig : public RouteSpecificFilterConfig { ProtobufWkt::Timestamp config_; @@ -4385,7 +4469,7 @@ class PerFilterConfigsTest : public testing::Test { Http::FilterFactoryCb createFilter(const std::string&, Server::Configuration::FactoryContext&) override { - NOT_IMPLEMENTED; + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } ProtobufTypes::MessagePtr createEmptyRouteConfigProto() override { return ProtobufTypes::MessagePtr{new ProtobufWkt::Timestamp()}; @@ -4398,6 +4482,15 @@ class PerFilterConfigsTest : public testing::Test { return obj; } }; + class DefaultTestFilterConfig : public Extensions::HttpFilters::Common::EmptyHttpFilterConfig { + public: + DefaultTestFilterConfig() : EmptyHttpFilterConfig("test.default.filter") {} + + Http::FilterFactoryCb createFilter(const std::string&, + Server::Configuration::FactoryContext&) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + }; void checkEach(const std::string& yaml, uint32_t expected_entry, uint32_t expected_route, uint32_t expected_vhost) { @@ -4421,8 +4514,24 @@ class PerFilterConfigsTest : public testing::Test { << "config value does not match expected for source: " << source; } + void checkNoPerFilterConfig(const std::string& yaml) { + const ConfigImpl config(parseRouteConfigurationFromV2Yaml(yaml), factory_context_, true); + + const auto route = config.route(genHeaders("www.foo.com", "/", "GET"), 0); + const auto* route_entry = route->routeEntry(); + const auto& vhost = route_entry->virtualHost(); + + EXPECT_EQ(nullptr, + route_entry->perFilterConfigTyped(default_factory_.name())); + EXPECT_EQ(nullptr, route->perFilterConfigTyped(default_factory_.name())); + EXPECT_EQ(nullptr, vhost.perFilterConfigTyped(default_factory_.name())); + } + TestFilterConfig factory_; Registry::InjectFactory registered_factory_; + DefaultTestFilterConfig default_factory_; + Registry::InjectFactory + registered_default_factory_; NiceMock factory_context_; }; @@ -4443,6 +4552,23 @@ name: foo "Didn't find a registered implementation for name: 'unknown.filter'"); } +// Test that a trivially specified NamedHttpFilterConfigFactory ignores per_filter_config without +// error. +TEST_F(PerFilterConfigsTest, DefaultFilterImplementation) { + std::string yaml = R"EOF( +name: foo +virtual_hosts: + - name: bar + domains: ["*"] + routes: + - match: { prefix: "/" } + route: { cluster: baz } + per_filter_config: { test.default.filter: { unknown_key: 123} } +)EOF"; + + checkNoPerFilterConfig(yaml); +} + TEST_F(PerFilterConfigsTest, RouteLocalConfig) { std::string yaml = R"EOF( name: foo @@ -4498,6 +4624,7 @@ name: foo checkEach(yaml, 1213, 1213, 1415); } + } // namespace } // namespace Router } // namespace Envoy diff --git a/test/common/router/header_formatter_test.cc b/test/common/router/header_formatter_test.cc index 1e80c0bb37e0a..000549573560e 100644 --- a/test/common/router/header_formatter_test.cc +++ b/test/common/router/header_formatter_test.cc @@ -83,8 +83,9 @@ TEST_F(RequestInfoHeaderFormatterTest, TestFormatWithUpstreamMetadataVariable) { std::shared_ptr> host( new NiceMock()); - envoy::api::v2::core::Metadata metadata = TestUtility::parseYaml( - R"EOF( + auto metadata = std::make_shared( + TestUtility::parseYaml( + R"EOF( filter_metadata: namespace: key: value @@ -99,11 +100,11 @@ TEST_F(RequestInfoHeaderFormatterTest, TestFormatWithUpstreamMetadataVariable) { list_key: [ list_element ] struct_key: deep_key: deep_value - )EOF"); + )EOF")); // Prove we're testing the expected types. const auto& nested_struct = - Envoy::Config::Metadata::metadataValue(metadata, "namespace", "nested").struct_value(); + Envoy::Config::Metadata::metadataValue(*metadata, "namespace", "nested").struct_value(); EXPECT_EQ(nested_struct.fields().at("str_key").kind_case(), ProtobufWkt::Value::kStringValue); EXPECT_EQ(nested_struct.fields().at("bool_key1").kind_case(), ProtobufWkt::Value::kBoolValue); EXPECT_EQ(nested_struct.fields().at("bool_key2").kind_case(), ProtobufWkt::Value::kBoolValue); @@ -114,7 +115,7 @@ TEST_F(RequestInfoHeaderFormatterTest, TestFormatWithUpstreamMetadataVariable) { EXPECT_EQ(nested_struct.fields().at("struct_key").kind_case(), ProtobufWkt::Value::kStructValue); ON_CALL(request_info, upstreamHost()).WillByDefault(Return(host)); - ON_CALL(*host, metadata()).WillByDefault(ReturnRef(metadata)); + ON_CALL(*host, metadata()).WillByDefault(Return(metadata)); // Top-level value. testFormatting(request_info, "UPSTREAM_METADATA([\"namespace\", \"key\"])", "value"); @@ -331,13 +332,14 @@ TEST(HeaderParserTest, TestParseInternal) { ON_CALL(request_info, upstreamHost()).WillByDefault(Return(host)); // Metadata with percent signs in the key. - envoy::api::v2::core::Metadata metadata = TestUtility::parseYaml( - R"EOF( + auto metadata = std::make_shared( + TestUtility::parseYaml( + R"EOF( filter_metadata: ns: key: value - )EOF"); - ON_CALL(*host, metadata()).WillByDefault(ReturnRef(metadata)); + )EOF")); + ON_CALL(*host, metadata()).WillByDefault(Return(metadata)); // "2018-04-03T23:06:09.123Z". const SystemTime start_time(std::chrono::milliseconds(1522796769123)); @@ -411,9 +413,9 @@ TEST(HeaderParserTest, EvaluateEmptyHeaders) { std::shared_ptr> host( new NiceMock()); NiceMock request_info; - envoy::api::v2::core::Metadata metadata; + auto metadata = std::make_shared(); ON_CALL(request_info, upstreamHost()).WillByDefault(Return(host)); - ON_CALL(*host, metadata()).WillByDefault(ReturnRef(metadata)); + ON_CALL(*host, metadata()).WillByDefault(Return(metadata)); req_header_parser->evaluateHeaders(headerMap, request_info); EXPECT_FALSE(headerMap.has("x-key")); } @@ -485,13 +487,14 @@ match: { prefix: "/new_endpoint" } ON_CALL(request_info, upstreamHost()).WillByDefault(Return(host)); // Metadata with percent signs in the key. - envoy::api::v2::core::Metadata metadata = TestUtility::parseYaml( - R"EOF( + auto metadata = std::make_shared( + TestUtility::parseYaml( + R"EOF( filter_metadata: namespace: "%key%": value - )EOF"); - ON_CALL(*host, metadata()).WillByDefault(ReturnRef(metadata)); + )EOF")); + ON_CALL(*host, metadata()).WillByDefault(Return(metadata)); req_header_parser->evaluateHeaders(headerMap, request_info); diff --git a/test/common/router/rds_impl_test.cc b/test/common/router/rds_impl_test.cc index 9acd91d2960a9..c0ff21dd58b9e 100644 --- a/test/common/router/rds_impl_test.cc +++ b/test/common/router/rds_impl_test.cc @@ -35,12 +35,12 @@ namespace Router { namespace { envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager -parseHttpConnectionManagerFromJson(const std::string& json_string) { +parseHttpConnectionManagerFromJson(const std::string& json_string, const Stats::Scope& scope) { envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager http_connection_manager; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Envoy::Config::FilterJson::translateHttpConnectionManager(*json_object_ptr, - http_connection_manager); + Envoy::Config::FilterJson::translateHttpConnectionManager( + *json_object_ptr, http_connection_manager, scope.statsOptions()); return http_connection_manager; } @@ -107,15 +107,16 @@ class RdsImplTest : public RdsTestBase { interval_timer_ = new Event::MockTimer(&factory_context_.dispatcher_); EXPECT_CALL(factory_context_.init_manager_, registerTarget(_)); rds_ = - RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json), + RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json, scope_), factory_context_, "foo.", *route_config_provider_manager_); expectRequest(); factory_context_.init_manager_.initialize(); } + NiceMock scope_; NiceMock server_; std::unique_ptr route_config_provider_manager_; - RouteConfigProviderSharedPtr rds_; + RouteConfigProviderPtr rds_; }; TEST_F(RdsImplTest, RdsAndStatic) { @@ -131,10 +132,10 @@ TEST_F(RdsImplTest, RdsAndStatic) { } )EOF"; - EXPECT_THROW(RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json), - factory_context_, "foo.", - *route_config_provider_manager_), - EnvoyException); + EXPECT_THROW( + RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json, scope_), + factory_context_, "foo.", *route_config_provider_manager_), + EnvoyException); } TEST_F(RdsImplTest, LocalInfoNotDefined) { @@ -154,10 +155,10 @@ TEST_F(RdsImplTest, LocalInfoNotDefined) { factory_context_.local_info_.node_.set_cluster(""); factory_context_.local_info_.node_.set_id(""); - EXPECT_THROW(RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json), - factory_context_, "foo.", - *route_config_provider_manager_), - EnvoyException); + EXPECT_THROW( + RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json, scope_), + factory_context_, "foo.", *route_config_provider_manager_), + EnvoyException); } TEST_F(RdsImplTest, UnknownCluster) { @@ -178,7 +179,7 @@ TEST_F(RdsImplTest, UnknownCluster) { Upstream::ClusterManager::ClusterInfoMap cluster_map; EXPECT_CALL(factory_context_.cluster_manager_, clusters()).WillOnce(Return(cluster_map)); EXPECT_THROW_WITH_MESSAGE( - RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json), + RouteConfigProviderUtil::create(parseHttpConnectionManagerFromJson(config_json, scope_), factory_context_, "foo.", *route_config_provider_manager_), EnvoyException, "envoy::api::v2::core::ConfigSource must have a statically defined non-EDS " @@ -362,7 +363,7 @@ class RouteConfigProviderManagerImplTest : public RdsTestBase { )EOF"; Json::ObjectSharedPtr config = Json::Factory::loadFromString(config_json); - Envoy::Config::Utility::translateRdsConfig(*config, rds_); + Envoy::Config::Utility::translateRdsConfig(*config, rds_, stats_options_); // Get a RouteConfigProvider. This one should create an entry in the RouteConfigProviderManager. Upstream::ClusterManager::ClusterInfoMap cluster_map; @@ -374,8 +375,8 @@ class RouteConfigProviderManagerImplTest : public RdsTestBase { EXPECT_CALL(*cluster.info_, addedViaApi()); EXPECT_CALL(*cluster.info_, type()); interval_timer_ = new Event::MockTimer(&factory_context_.dispatcher_); - provider_ = route_config_provider_manager_->getRdsRouteConfigProvider(rds_, factory_context_, - "foo_prefix."); + provider_ = route_config_provider_manager_->createRdsRouteConfigProvider(rds_, factory_context_, + "foo_prefix."); } RouteConfigProviderManagerImplTest() { @@ -386,9 +387,10 @@ class RouteConfigProviderManagerImplTest : public RdsTestBase { ~RouteConfigProviderManagerImplTest() { factory_context_.thread_local_.shutdownThread(); } + Stats::StatsOptionsImpl stats_options_; envoy::config::filter::network::http_connection_manager::v2::Rds rds_; std::unique_ptr route_config_provider_manager_; - RouteConfigProviderSharedPtr provider_; + RouteConfigProviderPtr provider_; }; envoy::api::v2::RouteConfiguration parseRouteConfigurationFromV2Yaml(const std::string& yaml) { @@ -427,8 +429,8 @@ name: foo EXPECT_CALL(factory_context_, systemTimeSource()).WillRepeatedly(ReturnRef(system_time_source_)); // Only static route. - RouteConfigProviderSharedPtr static_config = - route_config_provider_manager_->getStaticRouteConfigProvider( + RouteConfigProviderPtr static_config = + route_config_provider_manager_->createStaticRouteConfigProvider( parseRouteConfigurationFromV2Yaml(config_yaml), factory_context_); message_ptr = factory_context_.admin_.config_tracker_.config_tracker_callbacks_["routes"](); const auto& route_config_dump2 = @@ -506,14 +508,34 @@ TEST_F(RouteConfigProviderManagerImplTest, Basic) { // Get a RouteConfigProvider. This one should create an entry in the RouteConfigProviderManager. setup(); - // Because this get has the same cluster and route_config_name, the provider returned is just a - // shared_ptr to the same provider as the one above. - RouteConfigProviderSharedPtr provider2 = - route_config_provider_manager_->getRdsRouteConfigProvider(rds_, factory_context_, - "foo_prefix"); - // So this means that both shared_ptrs should be the same. - EXPECT_EQ(provider_, provider2); - EXPECT_EQ(2UL, provider_.use_count()); + EXPECT_FALSE(provider_->configInfo().has_value()); + + Protobuf::RepeatedPtrField route_configs; + route_configs.Add()->MergeFrom(parseRouteConfigurationFromV2Yaml(R"EOF( +name: foo_route_config +virtual_hosts: + - name: bar + domains: ["*"] + routes: + - match: { prefix: "/" } + route: { cluster: baz } +)EOF")); + + RdsRouteConfigSubscription& subscription = + dynamic_cast(*provider_).subscription(); + + subscription.onConfigUpdate(route_configs, "1"); + + RouteConfigProviderPtr provider2 = route_config_provider_manager_->createRdsRouteConfigProvider( + rds_, factory_context_, "foo_prefix"); + + // provider2 should have route config immediately after create + EXPECT_TRUE(provider2->configInfo().has_value()); + + // So this means that both provider have same subscription. + EXPECT_EQ(&dynamic_cast(*provider_).subscription(), + &dynamic_cast(*provider2).subscription()); + EXPECT_EQ(&provider_->configInfo().value().config_, &provider2->configInfo().value().config_); std::string config_json2 = R"EOF( { @@ -525,7 +547,7 @@ TEST_F(RouteConfigProviderManagerImplTest, Basic) { Json::ObjectSharedPtr config2 = Json::Factory::loadFromString(config_json2); envoy::config::filter::network::http_connection_manager::v2::Rds rds2; - Envoy::Config::Utility::translateRdsConfig(*config2, rds2); + Envoy::Config::Utility::translateRdsConfig(*config2, rds2, stats_options_); Upstream::ClusterManager::ClusterInfoMap cluster_map; Upstream::MockCluster cluster; @@ -535,34 +557,32 @@ TEST_F(RouteConfigProviderManagerImplTest, Basic) { EXPECT_CALL(*cluster.info_, addedViaApi()); EXPECT_CALL(*cluster.info_, type()); new Event::MockTimer(&factory_context_.dispatcher_); - RouteConfigProviderSharedPtr provider3 = - route_config_provider_manager_->getRdsRouteConfigProvider(rds2, factory_context_, - "foo_prefix"); + RouteConfigProviderPtr provider3 = route_config_provider_manager_->createRdsRouteConfigProvider( + rds2, factory_context_, "foo_prefix"); EXPECT_NE(provider3, provider_); - EXPECT_EQ(2UL, provider_.use_count()); - EXPECT_EQ(1UL, provider3.use_count()); + dynamic_cast(*provider3) + .subscription() + .onConfigUpdate(route_configs, "provider3"); - std::vector configured_providers = - route_config_provider_manager_->getRdsRouteConfigProviders(); - EXPECT_EQ(2UL, configured_providers.size()); - EXPECT_EQ(3UL, provider_.use_count()); - EXPECT_EQ(2UL, provider3.use_count()); + EXPECT_EQ(2UL, + route_config_provider_manager_->dumpRouteConfigs()->dynamic_route_configs().size()); provider_.reset(); provider2.reset(); - configured_providers.clear(); // All shared_ptrs to the provider pointed at by provider1, and provider2 have been deleted, so // now we should only have the provider pointed at by provider3. - configured_providers = route_config_provider_manager_->getRdsRouteConfigProviders(); - EXPECT_EQ(1UL, configured_providers.size()); - EXPECT_EQ(provider3, configured_providers.front()); + auto dynamic_route_configs = + route_config_provider_manager_->dumpRouteConfigs()->dynamic_route_configs(); + EXPECT_EQ(1UL, dynamic_route_configs.size()); + + // Make sure the left one is provider3 + EXPECT_EQ("provider3", dynamic_route_configs[0].version_info()); provider3.reset(); - configured_providers.clear(); - configured_providers = route_config_provider_manager_->getRdsRouteConfigProviders(); - EXPECT_EQ(0UL, configured_providers.size()); + EXPECT_EQ(0UL, + route_config_provider_manager_->dumpRouteConfigs()->dynamic_route_configs().size()); } // Negative test for protoc-gen-validate constraints. @@ -573,7 +593,8 @@ TEST_F(RouteConfigProviderManagerImplTest, ValidateFail) { auto* route_config = route_configs.Add(); route_config->set_name("foo_route_config"); route_config->mutable_virtual_hosts()->Add(); - EXPECT_THROW(provider_impl.onConfigUpdate(route_configs, ""), ProtoValidationException); + EXPECT_THROW(provider_impl.subscription().onConfigUpdate(route_configs, ""), + ProtoValidationException); } TEST_F(RouteConfigProviderManagerImplTest, onConfigUpdateEmpty) { @@ -581,7 +602,7 @@ TEST_F(RouteConfigProviderManagerImplTest, onConfigUpdateEmpty) { factory_context_.init_manager_.initialize(); auto& provider_impl = dynamic_cast(*provider_.get()); EXPECT_CALL(factory_context_.init_manager_.initialized_, ready()); - provider_impl.onConfigUpdate({}, ""); + provider_impl.subscription().onConfigUpdate({}, ""); EXPECT_EQ( 1UL, factory_context_.scope_.counter("foo_prefix.rds.foo_route_config.update_empty").value()); } @@ -594,8 +615,8 @@ TEST_F(RouteConfigProviderManagerImplTest, onConfigUpdateWrongSize) { route_configs.Add(); route_configs.Add(); EXPECT_CALL(factory_context_.init_manager_.initialized_, ready()); - EXPECT_THROW_WITH_MESSAGE(provider_impl.onConfigUpdate(route_configs, ""), EnvoyException, - "Unexpected RDS resource length: 2"); + EXPECT_THROW_WITH_MESSAGE(provider_impl.subscription().onConfigUpdate(route_configs, ""), + EnvoyException, "Unexpected RDS resource length: 2"); } } // namespace diff --git a/test/common/router/router_ratelimit_test.cc b/test/common/router/router_ratelimit_test.cc index e2bb39813bcd3..5142ae43d6fcf 100644 --- a/test/common/router/router_ratelimit_test.cc +++ b/test/common/router/router_ratelimit_test.cc @@ -102,7 +102,8 @@ class RateLimitConfiguration : public testing::Test { void SetUpTest(const std::string json) { envoy::api::v2::RouteConfiguration route_config; auto json_object_ptr = Json::Factory::loadFromString(json); - Envoy::Config::RdsJson::translateRouteConfiguration(*json_object_ptr, route_config); + Envoy::Config::RdsJson::translateRouteConfiguration(*json_object_ptr, route_config, + stats_options); config_.reset(new ConfigImpl(route_config, factory_context_, true)); } @@ -111,6 +112,7 @@ class RateLimitConfiguration : public testing::Test { Http::TestHeaderMapImpl header_; const RouteEntry* route_; Network::Address::Ipv4Instance default_remote_address_{"10.0.0.1"}; + Stats::StatsOptionsImpl stats_options; }; TEST_F(RateLimitConfiguration, NoApplicableRateLimit) { diff --git a/test/common/ssl/context_impl_test.cc b/test/common/ssl/context_impl_test.cc index de82f328a7570..8efaad4f499fa 100644 --- a/test/common/ssl/context_impl_test.cc +++ b/test/common/ssl/context_impl_test.cc @@ -99,7 +99,7 @@ TEST_F(SslContextImplTest, TestExpiringCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); // This is a total hack, but right now we generate the cert and it expires in 15 days only in the // first second that it's valid. This can become invalid and then cause slower tests to fail. @@ -122,7 +122,7 @@ TEST_F(SslContextImplTest, TestExpiredCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); EXPECT_EQ(0U, context->daysUntilFirstCertExpires()); } @@ -141,7 +141,7 @@ TEST_F(SslContextImplTest, TestGetCertInformation) { ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); // This is similar to the hack above, but right now we generate the ca_cert and it expires in 15 // days only in the first second that it's valid. We will partially match for up until Days until // Expiration: 1. @@ -166,7 +166,7 @@ TEST_F(SslContextImplTest, TestNoCert) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ClientContextPtr context(manager.createSslClientContext(store, cfg)); + ClientContextSharedPtr context(manager.createSslClientContext(store, cfg)); EXPECT_EQ("", context->getCaCertInformation()); EXPECT_EQ("", context->getCertChainInformation()); } @@ -178,7 +178,7 @@ class SslServerContextImplTicketTest : public SslContextImplTest { Secret::MockSecretManager secret_manager; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - ServerContextPtr server_ctx( + ServerContextSharedPtr server_ctx( manager.createSslServerContext(store, cfg, std::vector{})); } @@ -500,7 +500,7 @@ TEST(ServerContextImplTest, TlsCertificateNonEmpty) { Runtime::MockLoader runtime; ContextManagerImpl manager(runtime); Stats::IsolatedStoreImpl store; - EXPECT_THROW_WITH_MESSAGE(ServerContextPtr server_ctx(manager.createSslServerContext( + EXPECT_THROW_WITH_MESSAGE(ServerContextSharedPtr server_ctx(manager.createSslServerContext( store, client_context_config, std::vector{})), EnvoyException, "Server TlsCertificates must have a certificate specified"); diff --git a/test/common/ssl/ssl_socket_test.cc b/test/common/ssl/ssl_socket_test.cc index 095600f02cfe2..35925d673ef75 100644 --- a/test/common/ssl/ssl_socket_test.cc +++ b/test/common/ssl/ssl_socket_test.cc @@ -46,10 +46,10 @@ namespace { void testUtil(const std::string& client_ctx_json, const std::string& server_ctx_json, const std::string& expected_digest, const std::string& expected_uri, - const std::string& expected_local_uri, const std::string& expected_subject, - const std::string& expected_local_subject, const std::string& expected_peer_cert, - const std::string& expected_stats, bool expect_success, - const Network::Address::IpVersion version) { + const std::string& expected_local_uri, const std::string& expected_serial_number, + const std::string& expected_subject, const std::string& expected_local_subject, + const std::string& expected_peer_cert, const std::string& expected_stats, + bool expect_success, const Network::Address::IpVersion version) { Stats::IsolatedStoreImpl stats_store; Runtime::MockLoader runtime; Secret::MockSecretManager secret_manager; @@ -101,6 +101,8 @@ void testUtil(const std::string& client_ctx_json, const std::string& server_ctx_ if (!expected_local_uri.empty()) { EXPECT_EQ(expected_local_uri, server_connection->ssl()->uriSanLocalCertificate()); } + EXPECT_EQ(expected_serial_number, + server_connection->ssl()->serialNumberPeerCertificate()); if (!expected_subject.empty()) { EXPECT_EQ(expected_subject, server_connection->ssl()->subjectPeerCertificate()); } @@ -140,6 +142,7 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, const std::string& expected_protocol_version, const std::string& expected_server_cert_digest, const std::string& expected_client_cert_uri, + const std::string& expected_requested_server_name, const std::string& expected_alpn_protocol, const std::string& expected_stats, unsigned expected_stats_value, const Network::Address::IpVersion version) { @@ -167,7 +170,7 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, ClientContextConfigImpl client_ctx_config(client_ctx_proto, secret_manager); ClientSslSocketFactory client_ssl_socket_factory(client_ctx_config, manager, stats_store); - ClientContextPtr client_ctx(manager.createSslClientContext(stats_store, client_ctx_config)); + ClientContextSharedPtr client_ctx(manager.createSslClientContext(stats_store, client_ctx_config)); Network::ClientConnectionPtr client_connection = dispatcher.createClientConnection( socket.localAddress(), Network::Address::InstanceConstSharedPtr(), client_ssl_socket_factory.createTransportSocket(), nullptr); @@ -188,6 +191,7 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, Network::MockConnectionCallbacks server_connection_callbacks; EXPECT_CALL(callbacks, onAccept_(_, _)) .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket, bool) -> void { + socket->setRequestedServerName(client_ctx_proto.sni()); Network::ConnectionPtr new_connection = dispatcher.createServerConnection( std::move(socket), server_ssl_socket_factory.createTransportSocket()); callbacks.onNewConnection(std::move(new_connection)); @@ -236,7 +240,10 @@ const std::string testUtilV2(const envoy::api::v2::Listener& server_proto, EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { stopSecondTime(); })); EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) - .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { stopSecondTime(); })); + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { + EXPECT_EQ(expected_requested_server_name, server_connection->requestedServerName()); + stopSecondTime(); + })); EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); } else { @@ -309,8 +316,8 @@ TEST_P(SslSocketTest, GetCertDigest) { )EOF"; testUtil(client_ctx_json, server_ctx_json, - "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", "", "", "", - "ssl.handshake", true, GetParam()); + "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", + "f3828eb24fd779cf", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, GetCertDigestInline) { @@ -448,7 +455,7 @@ TEST_P(SslSocketTest, GetCertDigestInline) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client_ctx, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, GetCertDigestServerCertWithIntermediateCA) { @@ -469,8 +476,8 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithIntermediateCA) { )EOF"; testUtil(client_ctx_json, server_ctx_json, - "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", "", "", "", - "ssl.handshake", true, GetParam()); + "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", + "f3828eb24fd779cf", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, GetCertDigestServerCertWithoutCommonName) { @@ -490,8 +497,8 @@ TEST_P(SslSocketTest, GetCertDigestServerCertWithoutCommonName) { )EOF"; testUtil(client_ctx_json, server_ctx_json, - "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", "", "", "", - "ssl.handshake", true, GetParam()); + "4444fbca965d916475f04fb4dd234dd556adb028ceb4300fa8ad6f2983c6aaa3", "", "", + "f3828eb24fd779cf", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, GetUriWithUriSan) { @@ -511,8 +518,8 @@ TEST_P(SslSocketTest, GetUriWithUriSan) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", "", "", "", - "ssl.handshake", true, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", + "de8a932ffc6a57da", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, GetNoUriWithDnsSan) { @@ -532,8 +539,8 @@ TEST_P(SslSocketTest, GetNoUriWithDnsSan) { )EOF"; // The SAN field only has DNS, expect "" for uriSanPeerCertificate(). - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.handshake", true, - GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "f3828eb24fd779d0", "", "", "", + "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, NoCert) { @@ -546,7 +553,7 @@ TEST_P(SslSocketTest, NoCert) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.no_certificate", true, + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.no_certificate", true, GetParam()); } @@ -566,8 +573,8 @@ TEST_P(SslSocketTest, GetUriWithLocalUriSan) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "spiffe://lyft.com/test-team", "", "", "", - "ssl.handshake", true, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "spiffe://lyft.com/test-team", + "f3828eb24fd779cf", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, GetSubjectsWithBothCerts) { @@ -587,7 +594,7 @@ TEST_P(SslSocketTest, GetSubjectsWithBothCerts) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "f3828eb24fd779cf", "CN=Test Server,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US", "CN=Test Server,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US", "", "ssl.handshake", true, GetParam()); @@ -610,7 +617,7 @@ TEST_P(SslSocketTest, GetPeerCert) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "f3828eb24fd779cf", "CN=Test Server,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US", "CN=Test Server,OU=Lyft Engineering,O=Lyft,L=San Francisco,ST=California,C=US", "-----BEGIN%20CERTIFICATE-----%0A" @@ -646,7 +653,7 @@ TEST_P(SslSocketTest, FailedClientAuthCaVerificationNoClientCert) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_no_cert", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_no_cert", false, GetParam()); } @@ -666,8 +673,8 @@ TEST_P(SslSocketTest, FailedClientAuthCaVerification) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_error", false, - GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_error", + false, GetParam()); } TEST_P(SslSocketTest, FailedClientAuthSanVerificationNoClientCert) { @@ -682,7 +689,7 @@ TEST_P(SslSocketTest, FailedClientAuthSanVerificationNoClientCert) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_no_cert", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_no_cert", false, GetParam()); } @@ -703,8 +710,8 @@ TEST_P(SslSocketTest, FailedClientAuthSanVerification) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_san", false, - GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_san", + false, GetParam()); } // By default, expired certificates are not permitted. @@ -714,7 +721,7 @@ TEST_P(SslSocketTest, FailedClientCertificateDefaultExpirationVerification) { configureServerAndExpiredClientCertificate(listener, client); - testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", + testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", "ssl.fail_verify_error", 1, GetParam()); } @@ -732,7 +739,7 @@ TEST_P(SslSocketTest, FailedClientCertificateExpirationVerification) { ->mutable_validation_context() ->set_allow_expired_certificate(false); - testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", + testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", "ssl.fail_verify_error", 1, GetParam()); } @@ -749,8 +756,8 @@ TEST_P(SslSocketTest, ClientCertificateExpirationAllowedVerification) { ->mutable_validation_context() ->set_allow_expired_certificate(true); - testUtilV2(listener, client, "", true, "", "", "spiffe://lyft.com/test-team", "", "ssl.handshake", - 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "spiffe://lyft.com/test-team", "", "", + "ssl.handshake", 2, GetParam()); } // Allow expired certificates, but add a certificate hash requirement so it still fails. @@ -770,7 +777,7 @@ TEST_P(SslSocketTest, FailedClientCertAllowExpiredBadHashVerification) { server_validation_ctx->add_verify_certificate_hash( "0000000000000000000000000000000000000000000000000000000000000000"); - testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", + testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); } @@ -793,7 +800,7 @@ TEST_P(SslSocketTest, FailedClientCertAllowServerExpiredWrongCAVerification) { server_validation_ctx->mutable_trusted_ca()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/fake_ca_cert.pem")); - testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", + testUtilV2(listener, client, "", false, "", "", "spiffe://lyft.com/test-team", "", "", "ssl.fail_verify_error", 1, GetParam()); } @@ -814,8 +821,8 @@ TEST_P(SslSocketTest, ClientCertificateHashVerification) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", "", "", "", - "ssl.handshake", true, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", + "de8a932ffc6a57da", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, ClientCertificateHashVerificationNoCA) { @@ -834,8 +841,8 @@ TEST_P(SslSocketTest, ClientCertificateHashVerificationNoCA) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", "", "", "", - "ssl.handshake", true, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "spiffe://lyft.com/test-team", "", + "de8a932ffc6a57da", "", "", "", "ssl.handshake", true, GetParam()); } TEST_P(SslSocketTest, ClientCertificateHashListVerification) { @@ -869,14 +876,14 @@ TEST_P(SslSocketTest, ClientCertificateHashListVerification) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, ClientCertificateHashListVerificationNoCA) { @@ -908,14 +915,14 @@ TEST_P(SslSocketTest, ClientCertificateHashListVerificationNoCA) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoClientCertificate) { @@ -930,7 +937,7 @@ TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoClientCertificate } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_no_cert", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_no_cert", false, GetParam()); } @@ -945,7 +952,7 @@ TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoCANoClientCertifi } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_no_cert", + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_no_cert", false, GetParam()); } @@ -966,8 +973,8 @@ TEST_P(SslSocketTest, FailedClientCertificateHashVerificationWrongClientCertific } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_cert_hash", - false, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", + "ssl.fail_verify_cert_hash", false, GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoCAWrongClientCertificate) { @@ -986,8 +993,8 @@ TEST_P(SslSocketTest, FailedClientCertificateHashVerificationNoCAWrongClientCert } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_cert_hash", - false, GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", + "ssl.fail_verify_cert_hash", false, GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashVerificationWrongCA) { @@ -1007,8 +1014,8 @@ TEST_P(SslSocketTest, FailedClientCertificateHashVerificationWrongCA) { } )EOF"; - testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.fail_verify_error", false, - GetParam()); + testUtil(client_ctx_json, server_ctx_json, "", "", "", "", "", "", "", "ssl.fail_verify_error", + false, GetParam()); } TEST_P(SslSocketTest, ClientCertificateSpkiVerification) { @@ -1042,14 +1049,14 @@ TEST_P(SslSocketTest, ClientCertificateSpkiVerification) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, ClientCertificateSpkiVerificationNoCA) { @@ -1081,14 +1088,14 @@ TEST_P(SslSocketTest, ClientCertificateSpkiVerificationNoCA) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoClientCertificate) { @@ -1113,11 +1120,13 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoClientCertificate envoy::api::v2::auth::UpstreamTlsContext client; - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCANoClientCertificate) { @@ -1140,11 +1149,13 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCANoClientCertifi envoy::api::v2::auth::UpstreamTlsContext client; - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongClientCertificate) { @@ -1175,12 +1186,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongClientCertific client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); } @@ -1210,12 +1221,12 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationNoCAWrongClientCert client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); } @@ -1247,11 +1258,13 @@ TEST_P(SslSocketTest, FailedClientCertificateSpkiVerificationWrongCA) { client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", 1, + GetParam()); } TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerification) { @@ -1287,14 +1300,14 @@ TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerification) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerificationNoCA) { @@ -1328,14 +1341,14 @@ TEST_P(SslSocketTest, ClientCertificateHashAndSpkiVerificationNoCA) { // ssl.handshake logged by both: client & server. testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); // Works even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); testUtilV2(listener, client, "", true, "", "1406294e80c818158697d65d2aaca16748ff132442ab0e2f28bc1109f1d47a2e", - "spiffe://lyft.com/test-team", "", "ssl.handshake", 2, GetParam()); + "spiffe://lyft.com/test-team", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoClientCertificate) { @@ -1360,11 +1373,13 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoClientCert envoy::api::v2::auth::UpstreamTlsContext client; - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCANoClientCertificate) { @@ -1387,11 +1402,13 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCANoClient envoy::api::v2::auth::UpstreamTlsContext client; - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_no_cert", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_no_cert", 1, + GetParam()); } TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongClientCertificate) { @@ -1422,12 +1439,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongClientC client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); } @@ -1457,12 +1474,12 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationNoCAWrongCli client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/no_san_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_cert_hash", 1, + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_cert_hash", 1, GetParam()); } @@ -1494,11 +1511,13 @@ TEST_P(SslSocketTest, FailedClientCertificateHashAndSpkiVerificationWrongCA) { client_cert->mutable_private_key()->set_filename( TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem")); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", 1, + GetParam()); // Fails even with client renegotiation. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.fail_verify_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.fail_verify_error", 1, + GetParam()); } // Make sure that we do not flush code and do an immediate close if we have not completed the @@ -2273,43 +2292,44 @@ TEST_P(SslSocketTest, ProtocolVersions) { // Connection using defaults (client & server) succeeds, negotiating TLSv1.2. // ssl.handshake logged by both: client & server. - testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client & server) succeeds, negotiating TLSv1.2, // even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", 2, GetParam()); client.set_allow_renegotiation(false); // Connection using TLSv1.0 (client) and defaults (server) succeeds. // ssl.handshake logged by both: client & server. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); - testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using TLSv1.1 (client) and defaults (server) succeeds. // ssl.handshake logged by both: client & server. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); - testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using TLSv1.2 (client) and defaults (server) succeeds. // ssl.handshake logged by both: client & server. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); - testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using TLSv1.3 (client) and defaults (server) fails. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.connection_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", 1, + GetParam()); // Connection using TLSv1.3 (client) and TLSv1.0-1.3 (server) succeeds. // ssl.handshake logged by both: client & server. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); - testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client) and TLSv1.0 (server) succeeds. // ssl.handshake logged by both: client & server. @@ -2317,30 +2337,31 @@ TEST_P(SslSocketTest, ProtocolVersions) { client_params->clear_tls_maximum_protocol_version(); server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); - testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client) and TLSv1.1 (server) succeeds. // ssl.handshake logged by both: client & server. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_1); - testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.1", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client) and TLSv1.2 (server) succeeds. // ssl.handshake logged by both: client & server. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_2); - testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.2", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client) and TLSv1.3 (server) fails. server_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); server_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.connection_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", 1, + GetParam()); // Connection using TLSv1.0-TLSv1.3 (client) and TLSv1.3 (server) succeeds. // ssl.handshake logged by both: client & server. client_params->set_tls_minimum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_0); client_params->set_tls_maximum_protocol_version(envoy::api::v2::auth::TlsParameters::TLSv1_3); - testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "TLSv1.3", "", "", "", "", "ssl.handshake", 2, GetParam()); } TEST_P(SslSocketTest, ALPN) { @@ -2360,30 +2381,30 @@ TEST_P(SslSocketTest, ALPN) { // Connection using defaults (client & server) succeeds, no ALPN is negotiated. // ssl.handshake logged by both: client & server. - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client & server) succeeds, no ALPN is negotiated, // even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client.set_allow_renegotiation(false); // Client connects without ALPN to a server with "test" ALPN, no ALPN is negotiated. server_ctx->add_alpn_protocols("test"); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); server_ctx->clear_alpn_protocols(); // Client connects with "test" ALPN to a server without ALPN, no ALPN is negotiated. client_ctx->add_alpn_protocols("test"); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client_ctx->clear_alpn_protocols(); // Client connects with "test" ALPN to a server with "test" ALPN, "test" ALPN is negotiated. // ssl.handshake logged by both: client & server. client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test"); - testUtilV2(listener, client, "", true, "", "", "", "test", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "test", "ssl.handshake", 2, GetParam()); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); @@ -2393,7 +2414,7 @@ TEST_P(SslSocketTest, ALPN) { client.set_allow_renegotiation(true); client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test"); - testUtilV2(listener, client, "", true, "", "", "", "test", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "test", "ssl.handshake", 2, GetParam()); client.set_allow_renegotiation(false); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); @@ -2402,7 +2423,7 @@ TEST_P(SslSocketTest, ALPN) { // ssl.handshake logged by both: client & server. client_ctx->add_alpn_protocols("test"); server_ctx->add_alpn_protocols("test2"); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client_ctx->clear_alpn_protocols(); server_ctx->clear_alpn_protocols(); } @@ -2425,12 +2446,12 @@ TEST_P(SslSocketTest, CipherSuites) { // Connection using defaults (client & server) succeeds. // ssl.handshake logged by both: client & server. - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client & server) succeeds, even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client.set_allow_renegotiation(false); // Client connects with one of the supported cipher suites, connection succeeds. @@ -2438,14 +2459,15 @@ TEST_P(SslSocketTest, CipherSuites) { client_params->add_cipher_suites("ECDHE-RSA-CHACHA20-POLY1305"); server_params->add_cipher_suites("ECDHE-RSA-CHACHA20-POLY1305"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client_params->clear_cipher_suites(); server_params->clear_cipher_suites(); // Client connects with unsupported cipher suite, connection fails. client_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); server_params->add_cipher_suites("ECDHE-RSA-CHACHA20-POLY1305"); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.connection_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", 1, + GetParam()); client_params->clear_cipher_suites(); server_params->clear_cipher_suites(); } @@ -2468,12 +2490,12 @@ TEST_P(SslSocketTest, EcdhCurves) { // Connection using defaults (client & server) succeeds. // ssl.handshake logged by both: client & server. - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); // Connection using defaults (client & server) succeeds, even with client renegotiation. // ssl.handshake logged by both: client & server. client.set_allow_renegotiation(true); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client.set_allow_renegotiation(false); // Client connects with one of the supported ECDH curves, connection succeeds. @@ -2482,7 +2504,7 @@ TEST_P(SslSocketTest, EcdhCurves) { server_params->add_ecdh_curves("X25519"); server_params->add_ecdh_curves("P-256"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); - testUtilV2(listener, client, "", true, "", "", "", "", "ssl.handshake", 2, GetParam()); + testUtilV2(listener, client, "", true, "", "", "", "", "", "ssl.handshake", 2, GetParam()); client_params->clear_ecdh_curves(); server_params->clear_ecdh_curves(); server_params->clear_cipher_suites(); @@ -2491,7 +2513,8 @@ TEST_P(SslSocketTest, EcdhCurves) { client_params->add_ecdh_curves("X25519"); server_params->add_ecdh_curves("P-256"); server_params->add_cipher_suites("ECDHE-RSA-AES128-GCM-SHA256"); - testUtilV2(listener, client, "", false, "", "", "", "", "ssl.connection_error", 1, GetParam()); + testUtilV2(listener, client, "", false, "", "", "", "", "", "ssl.connection_error", 1, + GetParam()); client_params->clear_ecdh_curves(); server_params->clear_ecdh_curves(); server_params->clear_cipher_suites(); @@ -2514,7 +2537,7 @@ TEST_P(SslSocketTest, RevokedCertificate) { "private_key_file": "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } )EOF"; - testUtil(revoked_client_ctx_json, server_ctx_json, "", "", "", "", "", "", + testUtil(revoked_client_ctx_json, server_ctx_json, "", "", "", "db6c9a4af16d9091", "", "", "", "ssl.fail_verify_error", false, GetParam()); // This should succeed, since the cert isn't revoked. @@ -2524,8 +2547,25 @@ TEST_P(SslSocketTest, RevokedCertificate) { "private_key_file": "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key2.pem" } )EOF"; - testUtil(successful_client_ctx_json, server_ctx_json, "", "", "", "", "", "", "ssl.handshake", - true, GetParam()); + testUtil(successful_client_ctx_json, server_ctx_json, "", "", "", "db6c9a4af16d9091", "", "", "", + "ssl.handshake", true, GetParam()); +} + +TEST_P(SslSocketTest, GetRequestedServerName) { + envoy::api::v2::Listener listener; + envoy::api::v2::listener::FilterChain* filter_chain = listener.add_filter_chains(); + envoy::api::v2::auth::TlsCertificate* server_cert = + filter_chain->mutable_tls_context()->mutable_common_tls_context()->add_tls_certificates(); + server_cert->mutable_certificate_chain()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem")); + server_cert->mutable_private_key()->set_filename( + TestEnvironment::substitute("{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem")); + + envoy::api::v2::auth::UpstreamTlsContext client; + client.set_sni("lyft.com"); + + testUtilV2(listener, client, "", true, "", "", "", "lyft.com", "", "ssl.handshake", 2, + GetParam()); } class SslReadBufferLimitTest : public SslCertsTest, @@ -2714,7 +2754,7 @@ class SslReadBufferLimitTest : public SslCertsTest, Network::ListenerPtr listener_; Json::ObjectSharedPtr client_ctx_loader_; std::unique_ptr client_ctx_config_; - ClientContextPtr client_ctx_; + ClientContextSharedPtr client_ctx_; Network::TransportSocketFactoryPtr client_ssl_socket_factory_; Network::ClientConnectionPtr client_connection_; Network::ConnectionPtr server_connection_; diff --git a/test/common/stats/BUILD b/test/common/stats/BUILD index dfd1984f27290..a157fb6cdad8b 100644 --- a/test/common/stats/BUILD +++ b/test/common/stats/BUILD @@ -12,6 +12,7 @@ envoy_cc_test( name = "stats_impl_test", srcs = ["stats_impl_test.cc"], deps = [ + "//source/common/common:hex_lib", "//source/common/stats:stats_lib", "//test/mocks/stats:stats_mocks", "//test/test_common:logging_lib", @@ -26,8 +27,10 @@ envoy_cc_test( deps = [ "//source/common/stats:thread_local_store_lib", "//test/mocks/event:event_mocks", + "//test/mocks/server:server_mocks", "//test/mocks/stats:stats_mocks", "//test/mocks/thread_local:thread_local_mocks", + "//test/test_common:logging_lib", "//test/test_common:utility_lib", ], ) diff --git a/test/common/stats/stats_impl_test.cc b/test/common/stats/stats_impl_test.cc index 4d595c1b73ece..d976a5a5ffdfc 100644 --- a/test/common/stats/stats_impl_test.cc +++ b/test/common/stats/stats_impl_test.cc @@ -5,6 +5,7 @@ #include "envoy/config/metrics/v2/stats.pb.h" #include "envoy/stats/stats_macros.h" +#include "common/common/hex.h" #include "common/config/well_known_names.h" #include "common/stats/stats_impl.h" @@ -65,6 +66,16 @@ TEST(StatsIsolatedStoreImplTest, All) { EXPECT_EQ(2UL, store.gauges().size()); } +TEST(StatsIsolatedStoreImplTest, LongStatName) { + IsolatedStoreImpl store; + Stats::StatsOptionsImpl stats_options; + const std::string long_string(stats_options.maxNameLength() + 1, 'A'); + + ScopePtr scope = store.createScope("scope."); + Counter& counter = scope->counter(long_string); + EXPECT_EQ(absl::StrCat("scope.", long_string), counter.name()); +} + /** * Test stats macros. @see stats_macros.h */ @@ -477,22 +488,26 @@ TEST(TagProducerTest, CheckConstructor) { "No regex specified for tag specifier and no default regex for name: 'test_extractor'"); } -// Validate truncation behavior of RawStatData. -TEST(RawStatDataTest, Truncate) { - HeapRawStatDataAllocator alloc; - const std::string long_string(RawStatData::maxNameLength() + 1, 'A'); - RawStatData* stat{}; - EXPECT_LOG_CONTAINS("warning", "is too long with", stat = alloc.alloc(long_string)); +// No truncation occurs in the implementation of HeapStatData. +TEST(RawStatDataTest, HeapNoTruncate) { + Stats::StatsOptionsImpl stats_options; + HeapStatDataAllocator alloc; //(/*stats_options*/); + const std::string long_string(stats_options.maxNameLength() + 1, 'A'); + HeapStatData* stat{}; + EXPECT_NO_LOGS(stat = alloc.alloc(long_string)); + EXPECT_EQ(stat->key(), long_string); alloc.free(*stat); } +// Note: a similar test using RawStatData* is in test/server/hot_restart_impl_test.cc. TEST(RawStatDataTest, HeapAlloc) { - HeapRawStatDataAllocator alloc; - RawStatData* stat_1 = alloc.alloc("ref_name"); + Stats::StatsOptionsImpl stats_options; + HeapStatDataAllocator alloc; //(stats_options); + HeapStatData* stat_1 = alloc.alloc("ref_name"); ASSERT_NE(stat_1, nullptr); - RawStatData* stat_2 = alloc.alloc("ref_name"); + HeapStatData* stat_2 = alloc.alloc("ref_name"); ASSERT_NE(stat_2, nullptr); - RawStatData* stat_3 = alloc.alloc("not_ref_name"); + HeapStatData* stat_3 = alloc.alloc("not_ref_name"); ASSERT_NE(stat_3, nullptr); EXPECT_EQ(stat_1, stat_2); EXPECT_NE(stat_1, stat_3); diff --git a/test/common/stats/thread_local_store_test.cc b/test/common/stats/thread_local_store_test.cc index 37affaa4486b9..a32cc1bd4bb55 100644 --- a/test/common/stats/thread_local_store_test.cc +++ b/test/common/stats/thread_local_store_test.cc @@ -7,8 +7,10 @@ #include "common/stats/thread_local_store.h" #include "test/mocks/event/mocks.h" +#include "test/mocks/server/mocks.h" #include "test/mocks/stats/mocks.h" #include "test/mocks/thread_local/mocks.h" +#include "test/test_common/logging.h" #include "test/test_common/utility.h" #include "absl/strings/str_split.h" @@ -25,47 +27,34 @@ using testing::_; namespace Envoy { namespace Stats { -class StatsThreadLocalStoreTest : public testing::Test, public RawStatDataAllocator { +class StatsThreadLocalStoreTest : public testing::Test { public: - StatsThreadLocalStoreTest() { - ON_CALL(*this, alloc(_)).WillByDefault(Invoke([this](const std::string& name) -> RawStatData* { - return alloc_.alloc(name); - })); - - ON_CALL(*this, free(_)).WillByDefault(Invoke([this](RawStatData& data) -> void { - return alloc_.free(data); - })); + void SetUp() override { + alloc_ = std::make_unique(options_); + resetStoreWithAlloc(*alloc_); + } - EXPECT_CALL(*this, alloc("stats.overflow")); - store_.reset(new ThreadLocalStoreImpl(*this)); + void resetStoreWithAlloc(StatDataAllocator& alloc) { + store_ = std::make_unique(options_, alloc); store_->addSink(sink_); } - MOCK_METHOD1(alloc, RawStatData*(const std::string& name)); - MOCK_METHOD1(free, void(RawStatData& data)); - NiceMock main_thread_dispatcher_; NiceMock tls_; - TestAllocator alloc_; + StatsOptionsImpl options_; + std::unique_ptr alloc_; MockSink sink_; std::unique_ptr store_; }; -class HistogramTest : public testing::Test, public RawStatDataAllocator { +class HistogramTest : public testing::Test { public: - typedef std::map NameHistogramMap; + typedef std::map NameHistogramMap; - void SetUp() override { - ON_CALL(*this, alloc(_)).WillByDefault(Invoke([this](const std::string& name) -> RawStatData* { - return alloc_.alloc(name); - })); - - ON_CALL(*this, free(_)).WillByDefault(Invoke([this](RawStatData& data) -> void { - return alloc_.free(data); - })); + HistogramTest() : alloc_(options_) {} - EXPECT_CALL(*this, alloc("stats.overflow")); - store_.reset(new ThreadLocalStoreImpl(*this)); + void SetUp() override { + store_ = std::make_unique(options_, alloc_); store_->addSink(sink_); store_->initializeThreading(main_thread_dispatcher_, tls_); } @@ -74,12 +63,12 @@ class HistogramTest : public testing::Test, public RawStatDataAllocator { store_->shutdownThreading(); tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(alloc_, free(_)); } NameHistogramMap makeHistogramMap(const std::vector& hist_list) { NameHistogramMap name_histogram_map; - for (const Stats::ParentHistogramSharedPtr& histogram : hist_list) { + for (const ParentHistogramSharedPtr& histogram : hist_list) { // Exclude the scope part of the name. const std::vector& split_vector = absl::StrSplit(histogram->name(), '.'); name_histogram_map.insert(std::make_pair(split_vector.back(), histogram)); @@ -110,12 +99,12 @@ class HistogramTest : public testing::Test, public RawStatDataAllocator { HistogramStatisticsImpl h2_interval_statistics(hist2_interval); NameHistogramMap name_histogram_map = makeHistogramMap(histogram_list); - const Stats::ParentHistogramSharedPtr& h1 = name_histogram_map["h1"]; + const ParentHistogramSharedPtr& h1 = name_histogram_map["h1"]; EXPECT_EQ(h1->cumulativeStatistics().summary(), h1_cumulative_statistics.summary()); EXPECT_EQ(h1->intervalStatistics().summary(), h1_interval_statistics.summary()); if (histogram_list.size() > 1) { - const Stats::ParentHistogramSharedPtr& h2 = name_histogram_map["h2"]; + const ParentHistogramSharedPtr& h2 = name_histogram_map["h2"]; EXPECT_EQ(h2->cumulativeStatistics().summary(), h2_cumulative_statistics.summary()); EXPECT_EQ(h2->intervalStatistics().summary(), h2_interval_statistics.summary()); } @@ -157,7 +146,8 @@ class HistogramTest : public testing::Test, public RawStatDataAllocator { NiceMock main_thread_dispatcher_; NiceMock tls_; - TestAllocator alloc_; + StatsOptionsImpl options_; + MockedTestAllocator alloc_; MockSink sink_; std::unique_ptr store_; InSequence s; @@ -167,7 +157,7 @@ class HistogramTest : public testing::Test, public RawStatDataAllocator { TEST_F(StatsThreadLocalStoreTest, NoTls) { InSequence s; - EXPECT_CALL(*this, alloc(_)).Times(2); + EXPECT_CALL(*alloc_, alloc(_)).Times(2); Counter& c1 = store_->counter("c1"); EXPECT_EQ(&c1, &store_->counter("c1")); @@ -191,7 +181,7 @@ TEST_F(StatsThreadLocalStoreTest, NoTls) { EXPECT_EQ(2L, store_->gauges().front().use_count()); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(3); + EXPECT_CALL(*alloc_, free(_)).Times(3); store_->shutdownThreading(); } @@ -200,7 +190,7 @@ TEST_F(StatsThreadLocalStoreTest, Tls) { InSequence s; store_->initializeThreading(main_thread_dispatcher_, tls_); - EXPECT_CALL(*this, alloc(_)).Times(2); + EXPECT_CALL(*alloc_, alloc(_)).Times(2); Counter& c1 = store_->counter("c1"); EXPECT_EQ(&c1, &store_->counter("c1")); @@ -229,7 +219,7 @@ TEST_F(StatsThreadLocalStoreTest, Tls) { EXPECT_EQ(2L, store_->gauges().front().use_count()); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(3); + EXPECT_CALL(*alloc_, free(_)).Times(3); } TEST_F(StatsThreadLocalStoreTest, BasicScope) { @@ -237,7 +227,7 @@ TEST_F(StatsThreadLocalStoreTest, BasicScope) { store_->initializeThreading(main_thread_dispatcher_, tls_); ScopePtr scope1 = store_->createScope("scope1."); - EXPECT_CALL(*this, alloc(_)).Times(4); + EXPECT_CALL(*alloc_, alloc(_)).Times(4); Counter& c1 = store_->counter("c1"); Counter& c2 = scope1->counter("c2"); EXPECT_EQ("c1", c1.name()); @@ -263,7 +253,7 @@ TEST_F(StatsThreadLocalStoreTest, BasicScope) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(5); + EXPECT_CALL(*alloc_, free(_)).Times(5); } // Validate that we sanitize away bad characters in the stats prefix. @@ -272,7 +262,7 @@ TEST_F(StatsThreadLocalStoreTest, SanitizePrefix) { store_->initializeThreading(main_thread_dispatcher_, tls_); ScopePtr scope1 = store_->createScope(std::string("scope1:\0:foo.", 13)); - EXPECT_CALL(*this, alloc(_)); + EXPECT_CALL(*alloc_, alloc(_)); Counter& c1 = scope1->counter("c1"); EXPECT_EQ("scope1___foo.c1", c1.name()); @@ -280,7 +270,7 @@ TEST_F(StatsThreadLocalStoreTest, SanitizePrefix) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(2); + EXPECT_CALL(*alloc_, free(_)).Times(2); } TEST_F(StatsThreadLocalStoreTest, ScopeDelete) { @@ -288,7 +278,7 @@ TEST_F(StatsThreadLocalStoreTest, ScopeDelete) { store_->initializeThreading(main_thread_dispatcher_, tls_); ScopePtr scope1 = store_->createScope("scope1."); - EXPECT_CALL(*this, alloc(_)); + EXPECT_CALL(*alloc_, alloc(_)); scope1->counter("c1"); EXPECT_EQ(2UL, store_->counters().size()); CounterSharedPtr c1 = store_->counters().front(); @@ -303,7 +293,7 @@ TEST_F(StatsThreadLocalStoreTest, ScopeDelete) { store_->source().clearCache(); EXPECT_EQ(1UL, store_->source().cachedCounters().size()); - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(*alloc_, free(_)); EXPECT_EQ(1L, c1.use_count()); c1.reset(); @@ -311,7 +301,7 @@ TEST_F(StatsThreadLocalStoreTest, ScopeDelete) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(*alloc_, free(_)); } TEST_F(StatsThreadLocalStoreTest, NestedScopes) { @@ -319,12 +309,12 @@ TEST_F(StatsThreadLocalStoreTest, NestedScopes) { store_->initializeThreading(main_thread_dispatcher_, tls_); ScopePtr scope1 = store_->createScope("scope1."); - EXPECT_CALL(*this, alloc(_)); + EXPECT_CALL(*alloc_, alloc(_)); Counter& c1 = scope1->counter("foo.bar"); EXPECT_EQ("scope1.foo.bar", c1.name()); ScopePtr scope2 = scope1->createScope("foo."); - EXPECT_CALL(*this, alloc(_)); + EXPECT_CALL(*alloc_, alloc(_)); Counter& c2 = scope2->counter("bar"); EXPECT_NE(&c1, &c2); EXPECT_EQ("scope1.foo.bar", c2.name()); @@ -334,7 +324,7 @@ TEST_F(StatsThreadLocalStoreTest, NestedScopes) { EXPECT_EQ(1UL, c1.value()); EXPECT_EQ(c1.value(), c2.value()); - EXPECT_CALL(*this, alloc(_)); + EXPECT_CALL(*alloc_, alloc(_)); Gauge& g1 = scope2->gauge("some_gauge"); EXPECT_EQ("scope1.foo.some_gauge", g1.name()); @@ -342,7 +332,7 @@ TEST_F(StatsThreadLocalStoreTest, NestedScopes) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(4); + EXPECT_CALL(*alloc_, free(_)).Times(4); } TEST_F(StatsThreadLocalStoreTest, OverlappingScopes) { @@ -355,7 +345,7 @@ TEST_F(StatsThreadLocalStoreTest, OverlappingScopes) { ScopePtr scope2 = store_->createScope("scope1."); // We will call alloc twice, but they should point to the same backing storage. - EXPECT_CALL(*this, alloc(_)).Times(2); + EXPECT_CALL(*alloc_, alloc(_)).Times(2); Counter& c1 = scope1->counter("c"); Counter& c2 = scope2->counter("c"); EXPECT_NE(&c1, &c2); @@ -370,7 +360,7 @@ TEST_F(StatsThreadLocalStoreTest, OverlappingScopes) { EXPECT_EQ(2UL, store_->counters().size()); // Gauges should work the same way. - EXPECT_CALL(*this, alloc(_)).Times(2); + EXPECT_CALL(*alloc_, alloc(_)).Times(2); Gauge& g1 = scope1->gauge("g"); Gauge& g2 = scope2->gauge("g"); EXPECT_NE(&g1, &g2); @@ -383,7 +373,7 @@ TEST_F(StatsThreadLocalStoreTest, OverlappingScopes) { EXPECT_EQ(1UL, store_->gauges().size()); // Deleting scope 1 will call free but will be reference counted. It still leaves scope 2 valid. - EXPECT_CALL(*this, free(_)).Times(2); + EXPECT_CALL(*alloc_, free(_)).Times(2); scope1.reset(); c2.inc(); EXPECT_EQ(3UL, c2.value()); @@ -396,14 +386,14 @@ TEST_F(StatsThreadLocalStoreTest, OverlappingScopes) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(3); + EXPECT_CALL(*alloc_, free(_)).Times(3); } TEST_F(StatsThreadLocalStoreTest, AllocFailed) { InSequence s; store_->initializeThreading(main_thread_dispatcher_, tls_); - EXPECT_CALL(*this, alloc("foo")).WillOnce(Return(nullptr)); + EXPECT_CALL(*alloc_, alloc(absl::string_view("foo"))).WillOnce(Return(nullptr)); Counter& c1 = store_->counter("foo"); EXPECT_EQ(1UL, store_->counter("stats.overflow").value()); @@ -414,14 +404,89 @@ TEST_F(StatsThreadLocalStoreTest, AllocFailed) { tls_.shutdownThread(); // Includes overflow but not the failsafe stat which we allocated from the heap. - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(*alloc_, free(_)); +} + +TEST_F(StatsThreadLocalStoreTest, HotRestartTruncation) { + InSequence s; + store_->initializeThreading(main_thread_dispatcher_, tls_); + + // First, with a successful RawStatData allocation: + const uint64_t max_name_length = options_.maxNameLength(); + const std::string name_1(max_name_length + 1, 'A'); + + EXPECT_CALL(*alloc_, alloc(_)); + EXPECT_LOG_CONTAINS("warning", "is too long with", store_->counter(name_1)); + + // The stats did not overflow yet. + EXPECT_EQ(0UL, store_->counter("stats.overflow").value()); + + // The name will be truncated, so we won't be able to find it with the entire name. + EXPECT_EQ(nullptr, TestUtility::findCounter(*store_, name_1).get()); + + // But we can find it based on the expected truncation. + EXPECT_NE(nullptr, TestUtility::findCounter(*store_, name_1.substr(0, max_name_length)).get()); + + // The same should be true with heap allocation, which occurs when the default + // allocator fails. + const std::string name_2(max_name_length + 1, 'B'); + EXPECT_CALL(*alloc_, alloc(_)).WillOnce(Return(nullptr)); + store_->counter(name_2); + + // Same deal: the name will be truncated, so we won't be able to find it with the entire name. + EXPECT_EQ(nullptr, TestUtility::findCounter(*store_, name_1).get()); + + // But we can find it based on the expected truncation. + EXPECT_NE(nullptr, TestUtility::findCounter(*store_, name_1.substr(0, max_name_length)).get()); + + // Now the stats have overflowed. + EXPECT_EQ(1UL, store_->counter("stats.overflow").value()); + + store_->shutdownThreading(); + tls_.shutdownThread(); + + // Includes overflow, and the first raw-allocated stat, but not the failsafe stat which we + // allocated from the heap. + EXPECT_CALL(*alloc_, free(_)).Times(2); +} + +class HeapStatsThreadLocalStoreTest : public StatsThreadLocalStoreTest { +public: + void SetUp() override { + resetStoreWithAlloc(heap_alloc_); + // Note: we do not call StatsThreadLocalStoreTest::SetUp here as that + // sets up a thread_local_store with raw stat alloc. + } + void TearDown() override { + store_.reset(); // delete before the allocator. + } + + HeapStatDataAllocator heap_alloc_; +}; + +TEST_F(HeapStatsThreadLocalStoreTest, NonHotRestartNoTruncation) { + InSequence s; + store_->initializeThreading(main_thread_dispatcher_, tls_); + + // Allocate a stat greater than the max name length. + const uint64_t max_name_length = options_.maxNameLength(); + const std::string name_1(max_name_length + 1, 'A'); + + store_->counter(name_1); + + // This works fine, and we can find it by its long name because heap-stats do not + // get truncsated. + EXPECT_NE(nullptr, TestUtility::findCounter(*store_, name_1).get()); + + store_->shutdownThreading(); + tls_.shutdownThread(); } TEST_F(StatsThreadLocalStoreTest, ShuttingDown) { InSequence s; store_->initializeThreading(main_thread_dispatcher_, tls_); - EXPECT_CALL(*this, alloc(_)).Times(4); + EXPECT_CALL(*alloc_, alloc(_)).Times(4); store_->counter("c1"); store_->gauge("g1"); store_->shutdownThreading(); @@ -437,7 +502,7 @@ TEST_F(StatsThreadLocalStoreTest, ShuttingDown) { tls_.shutdownThread(); // Includes overflow stat. - EXPECT_CALL(*this, free(_)).Times(5); + EXPECT_CALL(*alloc_, free(_)).Times(5); } TEST_F(StatsThreadLocalStoreTest, MergeDuringShutDown) { @@ -460,7 +525,7 @@ TEST_F(StatsThreadLocalStoreTest, MergeDuringShutDown) { tls_.shutdownThread(); - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(*alloc_, free(_)); } // Histogram tests @@ -607,7 +672,7 @@ TEST_F(HistogramTest, BasicHistogramUsed) { // Merge histograms again and validate that both h1 and h2 are used. store_->mergeHistograms([]() -> void {}); - for (const Stats::ParentHistogramSharedPtr& histogram : store_->histograms()) { + for (const ParentHistogramSharedPtr& histogram : store_->histograms()) { EXPECT_TRUE(histogram->used()); } } diff --git a/test/common/tcp/BUILD b/test/common/tcp/BUILD new file mode 100644 index 0000000000000..097d23b9ecb6e --- /dev/null +++ b/test/common/tcp/BUILD @@ -0,0 +1,27 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_test", + "envoy_package", +) + +envoy_package() + +envoy_cc_test( + name = "conn_pool_test", + srcs = ["conn_pool_test.cc"], + deps = [ + "//source/common/event:dispatcher_lib", + "//source/common/network:utility_lib", + "//source/common/tcp:conn_pool_lib", + "//source/common/upstream:upstream_includes", + "//source/common/upstream:upstream_lib", + "//test/common/upstream:utility_lib", + "//test/mocks/event:event_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/runtime:runtime_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:utility_lib", + ], +) diff --git a/test/common/tcp/conn_pool_test.cc b/test/common/tcp/conn_pool_test.cc new file mode 100644 index 0000000000000..a6a080f83910d --- /dev/null +++ b/test/common/tcp/conn_pool_test.cc @@ -0,0 +1,754 @@ +#include +#include + +#include "common/event/dispatcher_impl.h" +#include "common/network/utility.h" +#include "common/tcp/conn_pool.h" +#include "common/upstream/upstream_impl.h" + +#include "test/common/upstream/utility.h" +#include "test/mocks/event/mocks.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/runtime/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::DoAll; +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::Property; +using testing::Return; +using testing::ReturnRef; +using testing::SaveArg; +using testing::_; + +namespace Envoy { +namespace Tcp { + +/** + * Mock callbacks used for conn pool testing. + */ +struct ConnPoolCallbacks : public Tcp::ConnectionPool::Callbacks { + void onPoolReady(ConnectionPool::ConnectionDataPtr&& conn, + Upstream::HostDescriptionConstSharedPtr host) override { + conn_data_ = std::move(conn); + host_ = host; + pool_ready_.ready(); + } + + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) override { + reason_ = reason; + host_ = host; + pool_failure_.ready(); + } + + ReadyWatcher pool_failure_; + ReadyWatcher pool_ready_; + ConnectionPool::ConnectionDataPtr conn_data_{}; + absl::optional reason_; + Upstream::HostDescriptionConstSharedPtr host_; +}; + +/** + * A test version of ConnPoolImpl that allows for mocking. + */ +class ConnPoolImplForTest : public ConnPoolImpl { +public: + ConnPoolImplForTest(Event::MockDispatcher& dispatcher, + Upstream::ClusterInfoConstSharedPtr cluster, + NiceMock* upstream_ready_timer) + : ConnPoolImpl(dispatcher, Upstream::makeTestHost(cluster, "tcp://127.0.0.1:9000"), + Upstream::ResourcePriority::Default, nullptr), + mock_dispatcher_(dispatcher), mock_upstream_ready_timer_(upstream_ready_timer) {} + + ~ConnPoolImplForTest() { + EXPECT_EQ(0U, ready_conns_.size()); + EXPECT_EQ(0U, busy_conns_.size()); + EXPECT_EQ(0U, pending_requests_.size()); + } + + MOCK_METHOD0(onConnReleasedForTest, void()); + MOCK_METHOD0(onConnDestroyedForTest, void()); + + struct TestConnection { + Network::MockClientConnection* connection_; + Event::MockTimer* connect_timer_; + Network::ReadFilterSharedPtr filter_; + }; + + void expectConnCreate() { + test_conns_.emplace_back(); + TestConnection& test_conn = test_conns_.back(); + test_conn.connection_ = new NiceMock(); + test_conn.connect_timer_ = new NiceMock(&mock_dispatcher_); + + EXPECT_CALL(mock_dispatcher_, createClientConnection_(_, _, _, _)) + .WillOnce(Return(test_conn.connection_)); + EXPECT_CALL(*test_conn.connection_, addReadFilter(_)) + .WillOnce(Invoke( + [&](Network::ReadFilterSharedPtr filter) -> void { test_conn.filter_ = filter; })); + EXPECT_CALL(*test_conn.connection_, connect()); + EXPECT_CALL(*test_conn.connect_timer_, enableTimer(_)); + } + + void expectEnableUpstreamReady() { + EXPECT_FALSE(upstream_ready_enabled_); + EXPECT_CALL(*mock_upstream_ready_timer_, enableTimer(_)).Times(1).RetiresOnSaturation(); + } + + void expectAndRunUpstreamReady() { + EXPECT_TRUE(upstream_ready_enabled_); + mock_upstream_ready_timer_->callback_(); + EXPECT_FALSE(upstream_ready_enabled_); + } + + Event::MockDispatcher& mock_dispatcher_; + NiceMock* mock_upstream_ready_timer_; + std::vector test_conns_; + +protected: + void onConnReleased(ConnPoolImpl::ActiveConn& conn) override { + for (auto i = test_conns_.begin(); i != test_conns_.end(); i++) { + if (conn.conn_.get() == i->connection_) { + onConnReleasedForTest(); + break; + } + } + + ConnPoolImpl::onConnReleased(conn); + } + + void onConnDestroyed(ConnPoolImpl::ActiveConn& conn) override { + for (auto i = test_conns_.begin(); i != test_conns_.end(); i++) { + if (conn.conn_.get() == i->connection_) { + onConnDestroyedForTest(); + test_conns_.erase(i); + break; + } + } + + ConnPoolImpl::onConnDestroyed(conn); + } +}; + +/** + * Test fixture for connection pool tests. + */ +class TcpConnPoolImplTest : public testing::Test { +public: + TcpConnPoolImplTest() + : upstream_ready_timer_(new NiceMock(&dispatcher_)), + conn_pool_(dispatcher_, cluster_, upstream_ready_timer_) {} + + ~TcpConnPoolImplTest() { + // Make sure all gauges are 0. + for (const Stats::GaugeSharedPtr& gauge : cluster_->stats_store_.gauges()) { + EXPECT_EQ(0U, gauge->value()); + } + } + + NiceMock dispatcher_; + std::shared_ptr cluster_{new NiceMock()}; + NiceMock* upstream_ready_timer_; + ConnPoolImplForTest conn_pool_; + NiceMock runtime_; +}; + +/** + * Test fixture for connection pool destructor tests. + */ +class TcpConnPoolImplDestructorTest : public testing::Test { +public: + TcpConnPoolImplDestructorTest() + : upstream_ready_timer_(new NiceMock(&dispatcher_)), + conn_pool_{new ConnPoolImpl(dispatcher_, + Upstream::makeTestHost(cluster_, "tcp://127.0.0.1:9000"), + Upstream::ResourcePriority::Default, nullptr)} {} + + ~TcpConnPoolImplDestructorTest() {} + + void prepareConn() { + connection_ = new NiceMock(); + connect_timer_ = new NiceMock(&dispatcher_); + EXPECT_CALL(dispatcher_, createClientConnection_(_, _, _, _)).WillOnce(Return(connection_)); + EXPECT_CALL(*connect_timer_, enableTimer(_)); + + callbacks_ = std::make_unique(); + ConnectionPool::Cancellable* handle = conn_pool_->newConnection(*callbacks_); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(*connect_timer_, disableTimer()); + EXPECT_CALL(callbacks_->pool_ready_, ready()); + connection_->raiseEvent(Network::ConnectionEvent::Connected); + } + + NiceMock dispatcher_; + std::shared_ptr cluster_{new NiceMock()}; + NiceMock* upstream_ready_timer_; + NiceMock* connect_timer_; + NiceMock* connection_; + std::unique_ptr conn_pool_; + std::unique_ptr callbacks_; +}; + +/** + * Helper for dealing with an active test connection. + */ +struct ActiveTestConn { + enum class Type { Pending, CreateConnection, Immediate }; + + ActiveTestConn(TcpConnPoolImplTest& parent, size_t conn_index, Type type) + : parent_(parent), conn_index_(conn_index) { + if (type == Type::CreateConnection) { + parent.conn_pool_.expectConnCreate(); + } + + if (type == Type::Immediate) { + expectNewConn(); + } + handle_ = parent.conn_pool_.newConnection(callbacks_); + + if (type == Type::Immediate) { + EXPECT_EQ(nullptr, handle_); + verifyConn(); + } else { + EXPECT_NE(nullptr, handle_); + } + + if (type == Type::CreateConnection) { + EXPECT_CALL(*parent_.conn_pool_.test_conns_[conn_index_].connect_timer_, disableTimer()); + expectNewConn(); + parent.conn_pool_.test_conns_[conn_index_].connection_->raiseEvent( + Network::ConnectionEvent::Connected); + verifyConn(); + } + } + + void expectNewConn() { EXPECT_CALL(callbacks_.pool_ready_, ready()); } + + void releaseConn() { callbacks_.conn_data_.reset(); } + + void verifyConn() { + EXPECT_EQ(&callbacks_.conn_data_->connection(), + parent_.conn_pool_.test_conns_[conn_index_].connection_); + } + + TcpConnPoolImplTest& parent_; + size_t conn_index_; + Tcp::ConnectionPool::Cancellable* handle_{}; + ConnPoolCallbacks callbacks_; +}; + +/** + * Verify that connections are drained when requested. + */ +TEST_F(TcpConnPoolImplTest, DrainConnections) { + cluster_->resource_manager_.reset( + new Upstream::ResourceManagerImpl(runtime_, "fake_key", 2, 1024, 1024, 1)); + InSequence s; + + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + ActiveTestConn c2(*this, 1, ActiveTestConn::Type::CreateConnection); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + c1.releaseConn(); + + // This will destroy the ready connection and set requests remaining to 1 on the busy connection. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.drainConnections(); + dispatcher_.clearDeferredDeleteList(); + + // This will destroy the busy connection when the response finishes. + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + c2.releaseConn(); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test all timing stats are set. + */ +TEST_F(TcpConnPoolImplTest, VerifyTimingStats) { + EXPECT_CALL(cluster_->stats_store_, + deliverHistogramToSinks(Property(&Stats::Metric::name, "upstream_cx_connect_ms"), _)); + EXPECT_CALL(cluster_->stats_store_, + deliverHistogramToSinks(Property(&Stats::Metric::name, "upstream_cx_length_ms"), _)); + + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + c1.releaseConn(); + + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test that buffer limits are set. + */ +TEST_F(TcpConnPoolImplTest, VerifyBufferLimits) { + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + EXPECT_CALL(*cluster_, perConnectionBufferLimitBytes()).WillOnce(Return(8192)); + EXPECT_CALL(*conn_pool_.test_conns_.back().connection_, setBufferLimits(8192)); + + EXPECT_CALL(callbacks.pool_failure_, ready()); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplTest, UpstreamCallbacks) { + Buffer::OwnedImpl buffer; + + InSequence s; + ConnectionPool::MockUpstreamCallbacks callbacks; + + // Create connection, set UpstreamCallbacks + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + c1.callbacks_.conn_data_->addUpstreamCallbacks(callbacks); + + // Expect invocation when connection's ReadFilter::onData is invoked + EXPECT_CALL(callbacks, onUpstreamData(_, _)); + EXPECT_EQ(Network::FilterStatus::StopIteration, + conn_pool_.test_conns_[0].filter_->onData(buffer, false)); + + EXPECT_CALL(callbacks, onAboveWriteBufferHighWatermark()); + for (auto* cb : conn_pool_.test_conns_[0].connection_->callbacks_) { + cb->onAboveWriteBufferHighWatermark(); + } + + EXPECT_CALL(callbacks, onBelowWriteBufferLowWatermark()); + for (auto* cb : conn_pool_.test_conns_[0].connection_->callbacks_) { + cb->onBelowWriteBufferLowWatermark(); + } + + // Shutdown normally. + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + c1.releaseConn(); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplTest, UpstreamCallbacksCloseEvent) { + Buffer::OwnedImpl buffer; + + InSequence s; + ConnectionPool::MockUpstreamCallbacks callbacks; + + // Create connection, set UpstreamCallbacks + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + c1.callbacks_.conn_data_->addUpstreamCallbacks(callbacks); + + EXPECT_CALL(callbacks, onEvent(Network::ConnectionEvent::RemoteClose)); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplTest, NoUpstreamCallbacks) { + Buffer::OwnedImpl buffer; + + InSequence s; + + // Create connection. + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + + // Trigger connection's ReadFilter::onData -- connection pool closes connection. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + EXPECT_EQ(Network::FilterStatus::StopIteration, + conn_pool_.test_conns_[0].filter_->onData(buffer, false)); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Tests a request that generates a new connection, completes, and then a second request that uses + * the same connection. + */ +TEST_F(TcpConnPoolImplTest, MultipleRequestAndResponse) { + InSequence s; + + // Request 1 should kick off a new connection. + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + c1.releaseConn(); + + // Request 2 should not. + ActiveTestConn c2(*this, 0, ActiveTestConn::Type::Immediate); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + c2.releaseConn(); + + // Cause the connection to go away. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test when we overflow max pending requests. + */ +TEST_F(TcpConnPoolImplTest, MaxPendingRequests) { + cluster_->resource_manager_.reset( + new Upstream::ResourceManagerImpl(runtime_, "fake_key", 1, 1, 1024, 1)); + + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + ConnPoolCallbacks callbacks2; + EXPECT_CALL(callbacks2.pool_failure_, ready()); + Tcp::ConnectionPool::Cancellable* handle2 = conn_pool_.newConnection(callbacks2); + EXPECT_EQ(nullptr, handle2); + + handle->cancel(); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(ConnectionPool::PoolFailureReason::Overflow, callbacks2.reason_); + + EXPECT_EQ(1U, cluster_->stats_.upstream_rq_pending_overflow_.value()); +} + +/** + * Tests a connection failure before a request is bound which should result in the pending request + * getting purged. + */ +TEST_F(TcpConnPoolImplTest, RemoteConnectFailure) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(callbacks.pool_failure_, ready()); + EXPECT_CALL(*conn_pool_.test_conns_[0].connect_timer_, disableTimer()); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(ConnectionPool::PoolFailureReason::RemoteConnectionFailure, callbacks.reason_); + + EXPECT_EQ(1U, cluster_->stats_.upstream_cx_connect_fail_.value()); + EXPECT_EQ(1U, cluster_->stats_.upstream_rq_pending_failure_eject_.value()); +} + +/** + * Tests a connection failure before a request is bound which should result in the pending request + * getting purged. + */ +TEST_F(TcpConnPoolImplTest, LocalConnectFailure) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(callbacks.pool_failure_, ready()); + EXPECT_CALL(*conn_pool_.test_conns_[0].connect_timer_, disableTimer()); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::LocalClose); + dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(ConnectionPool::PoolFailureReason::LocalConnectionFailure, callbacks.reason_); + + EXPECT_EQ(1U, cluster_->stats_.upstream_cx_connect_fail_.value()); + EXPECT_EQ(1U, cluster_->stats_.upstream_rq_pending_failure_eject_.value()); +} + +/** + * Tests a connect timeout. Also test that we can add a new request during ejection processing. + */ +TEST_F(TcpConnPoolImplTest, ConnectTimeout) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks1; + conn_pool_.expectConnCreate(); + EXPECT_NE(nullptr, conn_pool_.newConnection(callbacks1)); + + ConnPoolCallbacks callbacks2; + EXPECT_CALL(callbacks1.pool_failure_, ready()).WillOnce(Invoke([&]() -> void { + conn_pool_.expectConnCreate(); + EXPECT_NE(nullptr, conn_pool_.newConnection(callbacks2)); + })); + + conn_pool_.test_conns_[0].connect_timer_->callback_(); + + EXPECT_CALL(callbacks2.pool_failure_, ready()); + conn_pool_.test_conns_[1].connect_timer_->callback_(); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()).Times(2); + dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(ConnectionPool::PoolFailureReason::Timeout, callbacks1.reason_); + EXPECT_EQ(ConnectionPool::PoolFailureReason::Timeout, callbacks2.reason_); + + EXPECT_EQ(2U, cluster_->stats_.upstream_cx_connect_fail_.value()); + EXPECT_EQ(2U, cluster_->stats_.upstream_cx_connect_timeout_.value()); +} + +/** + * Test cancelling before the request is bound to a connection. + */ +TEST_F(TcpConnPoolImplTest, CancelBeforeBound) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + handle->cancel(); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + // Cause the connection to go away. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test an upstream disconnection while there is a bound request. + */ +TEST_F(TcpConnPoolImplTest, DisconnectWhileBound) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(callbacks.pool_ready_, ready()); + + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + // Kill the connection while it has an active request. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplTest, DisconnectWhilePending) { + InSequence s; + + cluster_->resource_manager_.reset( + new Upstream::ResourceManagerImpl(runtime_, "fake_key", 1, 1024, 1024, 1)); + + // First request connected. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(*conn_pool_.test_conns_[0].connect_timer_, disableTimer()); + EXPECT_CALL(callbacks.pool_ready_, ready()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + // Second request pending. + ConnPoolCallbacks callbacks2; + ConnectionPool::Cancellable* handle2 = conn_pool_.newConnection(callbacks2); + EXPECT_NE(nullptr, handle2); + + // Connection closed, triggering new connection for pending request. + conn_pool_.expectConnCreate(); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + dispatcher_.clearDeferredDeleteList(); + + // test_conns_[0] was replaced with a new connection + EXPECT_CALL(*conn_pool_.test_conns_[0].connect_timer_, disableTimer()); + EXPECT_CALL(callbacks2.pool_ready_, ready()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + callbacks2.conn_data_.reset(); + + // Disconnect + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test that we correctly handle reaching max connections. + */ +TEST_F(TcpConnPoolImplTest, MaxConnections) { + InSequence s; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + + EXPECT_NE(nullptr, handle); + + // Request 2 should not kick off a new connection. + ConnPoolCallbacks callbacks2; + handle = conn_pool_.newConnection(callbacks2); + EXPECT_EQ(1U, cluster_->stats_.upstream_cx_overflow_.value()); + + EXPECT_NE(nullptr, handle); + + // Connect event will bind to request 1. + EXPECT_CALL(callbacks.pool_ready_, ready()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + // Finishing request 1 will immediately bind to request 2. + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + conn_pool_.expectEnableUpstreamReady(); + EXPECT_CALL(callbacks2.pool_ready_, ready()); + callbacks.conn_data_.reset(); + + conn_pool_.expectAndRunUpstreamReady(); + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + callbacks2.conn_data_.reset(); + + // Cause the connection to go away. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +/** + * Test when we reach max requests per connection. + */ +TEST_F(TcpConnPoolImplTest, MaxRequestsPerConnection) { + InSequence s; + + cluster_->max_requests_per_connection_ = 1; + + // Request 1 should kick off a new connection. + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + + EXPECT_NE(nullptr, handle); + + EXPECT_CALL(callbacks.pool_ready_, ready()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + callbacks.conn_data_.reset(); + dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(0U, cluster_->stats_.upstream_cx_destroy_with_active_rq_.value()); + EXPECT_EQ(1U, cluster_->stats_.upstream_cx_max_requests_.value()); +} + +TEST_F(TcpConnPoolImplTest, ConcurrentConnections) { + InSequence s; + + cluster_->resource_manager_.reset( + new Upstream::ResourceManagerImpl(runtime_, "fake_key", 2, 1024, 1024, 1)); + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + ActiveTestConn c2(*this, 1, ActiveTestConn::Type::CreateConnection); + ActiveTestConn c3(*this, 0, ActiveTestConn::Type::Pending); + + // Finish c1, which gets c3 going. + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + conn_pool_.expectEnableUpstreamReady(); + c3.expectNewConn(); + c1.releaseConn(); + + conn_pool_.expectAndRunUpstreamReady(); + EXPECT_CALL(conn_pool_, onConnReleasedForTest()).Times(2); + c2.releaseConn(); + c3.releaseConn(); + + // Disconnect both connections. + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()).Times(2); + conn_pool_.test_conns_[1].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplTest, DrainCallback) { + InSequence s; + ReadyWatcher drained; + + EXPECT_CALL(drained, ready()); + conn_pool_.addDrainedCallback([&]() -> void { drained.ready(); }); + + ActiveTestConn c1(*this, 0, ActiveTestConn::Type::CreateConnection); + ActiveTestConn c2(*this, 0, ActiveTestConn::Type::Pending); + c2.handle_->cancel(); + + EXPECT_CALL(conn_pool_, onConnReleasedForTest()); + EXPECT_CALL(drained, ready()); + c1.releaseConn(); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); + dispatcher_.clearDeferredDeleteList(); +} + +// Test draining a connection pool that has a pending connection. +TEST_F(TcpConnPoolImplTest, DrainWhileConnecting) { + InSequence s; + ReadyWatcher drained; + + ConnPoolCallbacks callbacks; + conn_pool_.expectConnCreate(); + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(callbacks); + EXPECT_NE(nullptr, handle); + + conn_pool_.addDrainedCallback([&]() -> void { drained.ready(); }); + handle->cancel(); + EXPECT_CALL(*conn_pool_.test_conns_[0].connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(drained, ready()); + conn_pool_.test_conns_[0].connection_->raiseEvent(Network::ConnectionEvent::Connected); + + EXPECT_CALL(conn_pool_, onConnDestroyedForTest()); + dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(TcpConnPoolImplDestructorTest, TestBusyConnectionsAreClosed) { + prepareConn(); + + EXPECT_CALL(*connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(dispatcher_, clearDeferredDeleteList()); + conn_pool_.reset(); +} + +TEST_F(TcpConnPoolImplDestructorTest, TestReadyConnectionsAreClosed) { + prepareConn(); + + // Transition connection to ready list + callbacks_->conn_data_.reset(); + + EXPECT_CALL(*connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(dispatcher_, clearDeferredDeleteList()); + conn_pool_.reset(); +} + +} // namespace Tcp +} // namespace Envoy diff --git a/test/common/tcp_proxy/tcp_proxy_test.cc b/test/common/tcp_proxy/tcp_proxy_test.cc index b8c61ed9f7c70..63ed7b470986c 100644 --- a/test/common/tcp_proxy/tcp_proxy_test.cc +++ b/test/common/tcp_proxy/tcp_proxy_test.cc @@ -19,6 +19,7 @@ #include "test/mocks/network/mocks.h" #include "test/mocks/runtime/mocks.h" #include "test/mocks/server/mocks.h" +#include "test/mocks/tcp/mocks.h" #include "test/mocks/upstream/host.h" #include "test/mocks/upstream/mocks.h" #include "test/test_common/printers.h" @@ -26,6 +27,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +using testing::Invoke; using testing::MatchesRegex; using testing::NiceMock; using testing::Return; @@ -315,7 +317,7 @@ TEST(ConfigTest, EmptyRouteConfig) { TEST(ConfigTest, AccessLogConfig) { envoy::config::filter::network::tcp_proxy::v2::TcpProxy config; envoy::config::filter::accesslog::v2::AccessLog* log = config.mutable_access_log()->Add(); - log->set_name(Extensions::AccessLoggers::AccessLogNames::get().FILE); + log->set_name(Extensions::AccessLoggers::AccessLogNames::get().File); { envoy::config::accesslog::v2::FileAccessLog file_access_log; file_access_log.set_path("some_path"); @@ -325,7 +327,7 @@ TEST(ConfigTest, AccessLogConfig) { } log = config.mutable_access_log()->Add(); - log->set_name(Extensions::AccessLoggers::AccessLogNames::get().FILE); + log->set_name(Extensions::AccessLoggers::AccessLogNames::get().File); { envoy::config::accesslog::v2::FileAccessLog file_access_log; file_access_log.set_path("another path"); @@ -365,7 +367,7 @@ class TcpProxyTest : public testing::Test { envoy::config::filter::network::tcp_proxy::v2::TcpProxy config = defaultConfig(); envoy::config::filter::accesslog::v2::AccessLog* access_log = config.mutable_access_log()->Add(); - access_log->set_name(Extensions::AccessLoggers::AccessLogNames::get().FILE); + access_log->set_name(Extensions::AccessLoggers::AccessLogNames::get().File); envoy::config::accesslog::v2::FileAccessLog file_access_log; file_access_log.set_path("unused"); file_access_log.set_format(access_log_format); @@ -380,53 +382,50 @@ class TcpProxyTest : public testing::Test { upstream_local_address_ = Network::Utility::resolveUrl("tcp://2.2.2.2:50000"); upstream_remote_address_ = Network::Utility::resolveUrl("tcp://127.0.0.1:80"); if (connections >= 1) { - { - testing::InSequence sequence; - for (uint32_t i = 0; i < connections; i++) { - connect_timers_.push_back( - new NiceMock(&filter_callbacks_.connection_.dispatcher_)); - EXPECT_CALL(*connect_timers_.at(i), enableTimer(_)); - } - } - for (uint32_t i = 0; i < connections; i++) { - upstream_connections_.push_back(new NiceMock()); + upstream_connections_.push_back( + std::make_unique>()); + upstream_connection_data_.push_back( + std::make_unique>()); + ON_CALL(*upstream_connection_data_.back(), connection()) + .WillByDefault(ReturnRef(*upstream_connections_.back())); upstream_hosts_.push_back(std::make_shared>()); - conn_infos_.push_back(Upstream::MockHost::MockCreateConnectionData()); - conn_infos_.at(i).connection_ = upstream_connections_.back(); - conn_infos_.at(i).host_description_ = upstream_hosts_.back(); + conn_pool_handles_.push_back( + std::make_unique>()); ON_CALL(*upstream_hosts_.at(i), cluster()) .WillByDefault(ReturnPointee( factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_)); ON_CALL(*upstream_hosts_.at(i), address()).WillByDefault(Return(upstream_remote_address_)); upstream_connections_.at(i)->local_address_ = upstream_local_address_; - EXPECT_CALL(*upstream_connections_.at(i), addReadFilter(_)) - .WillOnce(SaveArg<0>(&upstream_read_filter_)); EXPECT_CALL(*upstream_connections_.at(i), dispatcher()) .WillRepeatedly(ReturnRef(filter_callbacks_.connection_.dispatcher_)); - EXPECT_CALL(*upstream_connections_.at(i), enableHalfClose(true)); } } { testing::InSequence sequence; for (uint32_t i = 0; i < connections; i++) { - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnForCluster_("fake_cluster", _)) - .WillOnce(Return(conn_infos_.at(i))) + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillOnce(Return(&conn_pool_)) + .RetiresOnSaturation(); + EXPECT_CALL(conn_pool_, newConnection(_)) + .WillOnce(Invoke( + [=](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { + conn_pool_callbacks_.push_back(&cb); + return conn_pool_handles_.at(i).get(); + })) .RetiresOnSaturation(); } - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnForCluster_("fake_cluster", _)) - .WillRepeatedly(Return(Upstream::MockHost::MockCreateConnectionData())); + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)) + .WillRepeatedly(Return(nullptr)); } filter_.reset(new Filter(config_, factory_context_.cluster_manager_)); EXPECT_CALL(filter_callbacks_.connection_, readDisable(true)); EXPECT_CALL(filter_callbacks_.connection_, enableHalfClose(true)); filter_->initializeReadFilterCallbacks(filter_callbacks_); - EXPECT_EQ(connections >= 1 ? Network::FilterStatus::Continue - : Network::FilterStatus::StopIteration, - filter_->onNewConnection()); + EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onNewConnection()); EXPECT_EQ(absl::optional(), filter_->computeHashKey()); EXPECT_EQ(&filter_callbacks_.connection_, filter_->downstreamConnection()); @@ -436,19 +435,37 @@ class TcpProxyTest : public testing::Test { void setup(uint32_t connections) { setup(connections, defaultConfig()); } void raiseEventUpstreamConnected(uint32_t conn_index) { - EXPECT_CALL(*connect_timers_.at(conn_index), disableTimer()); EXPECT_CALL(filter_callbacks_.connection_, readDisable(false)); - upstream_connections_.at(conn_index)->raiseEvent(Network::ConnectionEvent::Connected); + EXPECT_CALL(*upstream_connection_data_.at(conn_index), addUpstreamCallbacks(_)) + .WillOnce(Invoke([=](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { + upstream_callbacks_ = &cb; + + // Simulate TCP conn pool upstream callbacks. This is safe because the TCP proxy never + // releases a connection so all events go to the same UpstreamCallbacks instance. + upstream_connections_.at(conn_index)->addConnectionCallbacks(cb); + })); + EXPECT_CALL(*upstream_connections_.at(conn_index), enableHalfClose(true)); + conn_pool_callbacks_.at(conn_index) + ->onPoolReady(std::move(upstream_connection_data_.at(conn_index)), + upstream_hosts_.at(conn_index)); + } + + void raiseEventUpstreamConnectFailed(uint32_t conn_index, + Tcp::ConnectionPool::PoolFailureReason reason) { + conn_pool_callbacks_.at(conn_index)->onPoolFailure(reason, upstream_hosts_.at(conn_index)); } ConfigSharedPtr config_; NiceMock filter_callbacks_; NiceMock factory_context_; std::vector>> upstream_hosts_{}; - std::vector*> upstream_connections_{}; - std::vector conn_infos_; - Network::ReadFilterSharedPtr upstream_read_filter_; - std::vector*> connect_timers_; + std::vector>> upstream_connections_{}; + std::vector>> + upstream_connection_data_{}; + std::vector conn_pool_callbacks_; + std::vector>> conn_pool_handles_; + NiceMock conn_pool_; + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks_; std::unique_ptr filter_; StringViewSaver access_log_data_; Network::Address::InstanceConstSharedPtr upstream_local_address_; @@ -462,67 +479,65 @@ TEST_F(TcpProxyTest, HalfCloseProxy) { EXPECT_CALL(filter_callbacks_.connection_, close(_)).Times(0); EXPECT_CALL(*upstream_connections_.at(0), close(_)).Times(0); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), true)); filter_->onData(buffer, true); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), true)); - upstream_read_filter_->onData(response, true); + upstream_callbacks_->onUpstreamData(response, true); EXPECT_CALL(filter_callbacks_.connection_, close(_)); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(0))); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); } // Test that downstream is closed after an upstream LocalClose. TEST_F(TcpProxyTest, UpstreamLocalDisconnect) { setup(1); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), false)); filter_->onData(buffer, false); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), _)); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); EXPECT_CALL(filter_callbacks_.connection_, close(_)); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::LocalClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose); } // Test that downstream is closed after an upstream RemoteClose. TEST_F(TcpProxyTest, UpstreamRemoteDisconnect) { setup(1); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), false)); filter_->onData(buffer, false); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), _)); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); } // Test that reconnect is attempted after a local connect failure TEST_F(TcpProxyTest, ConnectAttemptsUpstreamLocalFail) { envoy::config::filter::network::tcp_proxy::v2::TcpProxy config = defaultConfig(); config.mutable_max_connect_attempts()->set_value(2); + setup(2, config); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(0))); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::LocalClose); + raiseEventUpstreamConnectFailed(0, + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure); raiseEventUpstreamConnected(1); EXPECT_EQ(0U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ @@ -536,9 +551,8 @@ TEST_F(TcpProxyTest, ConnectAttemptsUpstreamRemoteFail) { config.mutable_max_connect_attempts()->set_value(2); setup(2, config); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(0))); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + raiseEventUpstreamConnectFailed(0, + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); raiseEventUpstreamConnected(1); EXPECT_EQ(0U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ @@ -552,8 +566,7 @@ TEST_F(TcpProxyTest, ConnectAttemptsUpstreamTimeout) { config.mutable_max_connect_attempts()->set_value(2); setup(2, config); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::NoFlush)); - connect_timers_.at(0)->callback_(); + raiseEventUpstreamConnectFailed(0, Tcp::ConnectionPool::PoolFailureReason::Timeout); raiseEventUpstreamConnected(1); EXPECT_EQ(0U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ @@ -567,38 +580,21 @@ TEST_F(TcpProxyTest, ConnectAttemptsLimit) { config.mutable_max_connect_attempts()->set_value(3); setup(3, config); - { - testing::InSequence sequence; - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::NoFlush)); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(0))); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(1))); - EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, - deferredDelete_(upstream_connections_.at(2))); - EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); - } + EXPECT_CALL(upstream_hosts_.at(0)->outlier_detector_, + putResult(Upstream::Outlier::Result::TIMEOUT)); + EXPECT_CALL(upstream_hosts_.at(1)->outlier_detector_, + putResult(Upstream::Outlier::Result::CONNECT_FAILED)); + EXPECT_CALL(upstream_hosts_.at(2)->outlier_detector_, + putResult(Upstream::Outlier::Result::CONNECT_FAILED)); - // Try both failure modes - connect_timers_.at(0)->callback_(); - upstream_connections_.at(1)->raiseEvent(Network::ConnectionEvent::RemoteClose); - upstream_connections_.at(2)->raiseEvent(Network::ConnectionEvent::RemoteClose); + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); - EXPECT_EQ(1U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_connect_timeout") - .value()); - EXPECT_EQ(2U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_connect_fail") - .value()); - EXPECT_EQ(1U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_connect_attempts_exceeded") - .value()); - EXPECT_EQ(0U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_overflow") - .value()); - EXPECT_EQ(0U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_no_successful_host") - .value()); + // Try both failure modes + raiseEventUpstreamConnectFailed(0, Tcp::ConnectionPool::PoolFailureReason::Timeout); + raiseEventUpstreamConnectFailed(1, + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + raiseEventUpstreamConnectFailed(2, + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); } // Test that the tcp proxy sends the correct notifications to the outlier detector @@ -609,11 +605,12 @@ TEST_F(TcpProxyTest, OutlierDetection) { EXPECT_CALL(upstream_hosts_.at(0)->outlier_detector_, putResult(Upstream::Outlier::Result::TIMEOUT)); - connect_timers_.at(0)->callback_(); + raiseEventUpstreamConnectFailed(0, Tcp::ConnectionPool::PoolFailureReason::Timeout); EXPECT_CALL(upstream_hosts_.at(1)->outlier_detector_, putResult(Upstream::Outlier::Result::CONNECT_FAILED)); - upstream_connections_.at(1)->raiseEvent(Network::ConnectionEvent::RemoteClose); + raiseEventUpstreamConnectFailed(1, + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); EXPECT_CALL(upstream_hosts_.at(2)->outlier_detector_, putResult(Upstream::Outlier::Result::SUCCESS)); @@ -623,21 +620,21 @@ TEST_F(TcpProxyTest, OutlierDetection) { TEST_F(TcpProxyTest, UpstreamDisconnectDownstreamFlowControl) { setup(1); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), _)); filter_->onData(buffer, false); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), _)); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); EXPECT_CALL(*upstream_connections_.at(0), readDisable(true)); filter_callbacks_.connection_.runHighWatermarkCallbacks(); EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); filter_callbacks_.connection_.runLowWatermarkCallbacks(); } @@ -645,15 +642,15 @@ TEST_F(TcpProxyTest, UpstreamDisconnectDownstreamFlowControl) { TEST_F(TcpProxyTest, DownstreamDisconnectRemote) { setup(1); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), _)); filter_->onData(buffer, false); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), _)); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::FlushWrite)); filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::RemoteClose); @@ -662,15 +659,15 @@ TEST_F(TcpProxyTest, DownstreamDisconnectRemote) { TEST_F(TcpProxyTest, DownstreamDisconnectLocal) { setup(1); + raiseEventUpstreamConnected(0); + Buffer::OwnedImpl buffer("hello"); EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), _)); filter_->onData(buffer, false); - raiseEventUpstreamConnected(0); - Buffer::OwnedImpl response("world"); EXPECT_CALL(filter_callbacks_.connection_, write(BufferEqual(&response), _)); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::NoFlush)); filter_callbacks_.connection_.raiseEvent(Network::ConnectionEvent::LocalClose); @@ -679,16 +676,8 @@ TEST_F(TcpProxyTest, DownstreamDisconnectLocal) { TEST_F(TcpProxyTest, UpstreamConnectTimeout) { setup(1, accessLogConfig("%RESPONSE_FLAGS%")); - Buffer::OwnedImpl buffer("hello"); - EXPECT_CALL(*upstream_connections_.at(0), write(BufferEqual(&buffer), _)); - filter_->onData(buffer, false); - EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::NoFlush)); - connect_timers_.at(0)->callback_(); - EXPECT_EQ(1U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_connect_timeout") - .value()); + raiseEventUpstreamConnectFailed(0, Tcp::ConnectionPool::PoolFailureReason::Timeout); filter_.reset(); EXPECT_EQ(access_log_data_, "UF"); @@ -745,15 +734,9 @@ TEST_F(TcpProxyTest, DisconnectBeforeData) { TEST_F(TcpProxyTest, UpstreamConnectFailure) { setup(1, accessLogConfig("%RESPONSE_FLAGS%")); - Buffer::OwnedImpl buffer("hello"); - filter_->onData(buffer, false); - EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); - EXPECT_CALL(*connect_timers_.at(0), disableTimer()); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_connect_fail") - .value()); + raiseEventUpstreamConnectFailed(0, + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); filter_.reset(); EXPECT_EQ(access_log_data_, "UF"); @@ -771,10 +754,6 @@ TEST_F(TcpProxyTest, UpstreamConnectionLimit) { filter_->initializeReadFilterCallbacks(filter_callbacks_); filter_->onNewConnection(); - EXPECT_EQ(1U, factory_context_.cluster_manager_.thread_local_cluster_.cluster_.info_->stats_store_ - .counter("upstream_cx_overflow") - .value()); - filter_.reset(); EXPECT_EQ(access_log_data_, "UO"); } @@ -796,7 +775,7 @@ TEST_F(TcpProxyTest, IdleTimeout) { buffer.add("hello2"); EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(1000))); - upstream_read_filter_->onData(buffer, false); + upstream_callbacks_->onUpstreamData(buffer, false); EXPECT_CALL(*idle_timer, enableTimer(std::chrono::milliseconds(1000))); filter_callbacks_.connection_.raiseBytesSentCallbacks(1); @@ -835,12 +814,13 @@ TEST_F(TcpProxyTest, IdleTimerDisabledUpstreamClose) { raiseEventUpstreamConnected(0); EXPECT_CALL(*idle_timer, disableTimer()); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); } // Test that access log fields %UPSTREAM_HOST% and %UPSTREAM_CLUSTER% are correctly logged. TEST_F(TcpProxyTest, AccessLogUpstreamHost) { setup(1, accessLogConfig("%UPSTREAM_HOST% %UPSTREAM_CLUSTER%")); + raiseEventUpstreamConnected(0); filter_.reset(); EXPECT_EQ(access_log_data_, "127.0.0.1:80 fake_cluster"); } @@ -848,6 +828,7 @@ TEST_F(TcpProxyTest, AccessLogUpstreamHost) { // Test that access log field %UPSTREAM_LOCAL_ADDRESS% is correctly logged. TEST_F(TcpProxyTest, AccessLogUpstreamLocalAddress) { setup(1, accessLogConfig("%UPSTREAM_LOCAL_ADDRESS%")); + raiseEventUpstreamConnected(0); filter_.reset(); EXPECT_EQ(access_log_data_, "2.2.2.2:50000"); } @@ -874,10 +855,10 @@ TEST_F(TcpProxyTest, AccessLogBytesRxTxDuration) { Buffer::OwnedImpl buffer("a"); filter_->onData(buffer, false); Buffer::OwnedImpl response("bb"); - upstream_read_filter_->onData(response, false); + upstream_callbacks_->onUpstreamData(response, false); std::this_thread::sleep_for(std::chrono::milliseconds(1)); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::RemoteClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose); filter_.reset(); EXPECT_THAT(access_log_data_, @@ -890,7 +871,8 @@ TEST_F(TcpProxyTest, UpstreamFlushNoTimeout) { setup(1); raiseEventUpstreamConnected(0); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::FlushWrite)) + EXPECT_CALL(*upstream_connections_.at(0), + close(Network::ConnectionCloseType::FlushWrite)) .WillOnce(Return()); // Cancel default action of raising LocalClose EXPECT_CALL(*upstream_connections_.at(0), state()) .WillOnce(Return(Network::Connection::State::Closing)); @@ -903,7 +885,7 @@ TEST_F(TcpProxyTest, UpstreamFlushNoTimeout) { upstream_connections_.at(0)->raiseBytesSentCallbacks(1); // Simulate flush complete. - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::LocalClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, config_->stats().upstream_flush_total_.value()); EXPECT_EQ(0U, config_->stats().upstream_flush_active_.value()); } @@ -920,7 +902,8 @@ TEST_F(TcpProxyTest, UpstreamFlushTimeoutConfigured) { EXPECT_CALL(*idle_timer, enableTimer(_)); raiseEventUpstreamConnected(0); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::FlushWrite)) + EXPECT_CALL(*upstream_connections_.at(0), + close(Network::ConnectionCloseType::FlushWrite)) .WillOnce(Return()); // Cancel default action of raising LocalClose EXPECT_CALL(*upstream_connections_.at(0), state()) .WillOnce(Return(Network::Connection::State::Closing)); @@ -934,7 +917,7 @@ TEST_F(TcpProxyTest, UpstreamFlushTimeoutConfigured) { // Simulate flush complete. EXPECT_CALL(*idle_timer, disableTimer()); - upstream_connections_.at(0)->raiseEvent(Network::ConnectionEvent::LocalClose); + upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose); EXPECT_EQ(1U, config_->stats().upstream_flush_total_.value()); EXPECT_EQ(0U, config_->stats().upstream_flush_active_.value()); EXPECT_EQ(0U, config_->stats().idle_timeout_.value()); @@ -951,7 +934,8 @@ TEST_F(TcpProxyTest, UpstreamFlushTimeoutExpired) { EXPECT_CALL(*idle_timer, enableTimer(_)); raiseEventUpstreamConnected(0); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::FlushWrite)) + EXPECT_CALL(*upstream_connections_.at(0), + close(Network::ConnectionCloseType::FlushWrite)) .WillOnce(Return()); // Cancel default action of raising LocalClose EXPECT_CALL(*upstream_connections_.at(0), state()) .WillOnce(Return(Network::Connection::State::Closing)); @@ -973,7 +957,8 @@ TEST_F(TcpProxyTest, UpstreamFlushReceiveUpstreamData) { setup(1); raiseEventUpstreamConnected(0); - EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::FlushWrite)) + EXPECT_CALL(*upstream_connections_.at(0), + close(Network::ConnectionCloseType::FlushWrite)) .WillOnce(Return()); // Cancel default action of raising LocalClose EXPECT_CALL(*upstream_connections_.at(0), state()) .WillOnce(Return(Network::Connection::State::Closing)); @@ -985,7 +970,7 @@ TEST_F(TcpProxyTest, UpstreamFlushReceiveUpstreamData) { // Send some bytes; no timeout configured so this should be a no-op (not a crash). Buffer::OwnedImpl buffer("a"); EXPECT_CALL(*upstream_connections_.at(0), close(Network::ConnectionCloseType::NoFlush)); - upstream_read_filter_->onData(buffer, false); + upstream_callbacks_->onUpstreamData(buffer, false); } class TcpProxyRoutingTest : public testing::Test { @@ -1050,7 +1035,7 @@ TEST_F(TcpProxyRoutingTest, RoutableConnection) { connection_.local_address_ = std::make_shared("1.2.3.4", 9999); // Expect filter to try to open a connection to specified cluster. - EXPECT_CALL(factory_context_.cluster_manager_, tcpConnForCluster_("fake_cluster", _)); + EXPECT_CALL(factory_context_.cluster_manager_, tcpConnPoolForCluster("fake_cluster", _, _)); filter_->onNewConnection(); diff --git a/test/common/tracing/http_tracer_impl_test.cc b/test/common/tracing/http_tracer_impl_test.cc index 82d3516575f3c..1b6e2e61c6707 100644 --- a/test/common/tracing/http_tracer_impl_test.cc +++ b/test/common/tracing/http_tracer_impl_test.cc @@ -263,7 +263,7 @@ TEST(HttpConnManFinalizerImpl, SpanPopulatedFailureResponse) { absl::optional response_code(503); EXPECT_CALL(request_info, responseCode()).WillRepeatedly(ReturnPointee(&response_code)); EXPECT_CALL(request_info, bytesSent()).WillOnce(Return(100)); - ON_CALL(request_info, getResponseFlag(RequestInfo::ResponseFlag::UpstreamRequestTimeout)) + ON_CALL(request_info, hasResponseFlag(RequestInfo::ResponseFlag::UpstreamRequestTimeout)) .WillByDefault(Return(true)); EXPECT_CALL(request_info, upstreamHost()).WillOnce(Return(nullptr)); diff --git a/test/common/upstream/BUILD b/test/common/upstream/BUILD index 6dc40b5cb8397..91090e047bb4e 100644 --- a/test/common/upstream/BUILD +++ b/test/common/upstream/BUILD @@ -30,6 +30,7 @@ envoy_cc_test( srcs = ["cluster_manager_impl_test.cc"], deps = [ ":utility_lib", + "//include/envoy/stats:stats_interface", "//include/envoy/upstream:upstream_interface", "//source/common/config:bootstrap_json_lib", "//source/common/config:utility_lib", @@ -48,6 +49,7 @@ envoy_cc_test( "//test/mocks/runtime:runtime_mocks", "//test/mocks/secret:secret_mocks", "//test/mocks/server:server_mocks", + "//test/mocks/tcp:tcp_mocks", "//test/mocks/thread_local:thread_local_mocks", "//test/mocks/upstream:upstream_mocks", "//test/test_common:threadsafe_singleton_injector_lib", @@ -93,6 +95,7 @@ envoy_cc_test( "//source/common/upstream:health_checker_lib", "//source/common/upstream:upstream_lib", "//test/common/http:common_lib", + "//test/mocks/access_log:access_log_mocks", "//test/mocks/network:network_mocks", "//test/mocks/runtime:runtime_mocks", "//test/mocks/upstream:upstream_mocks", @@ -158,6 +161,24 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "hds_test", + srcs = ["hds_test.cc"], + deps = [ + "//source/common/ssl:context_lib", + "//source/common/stats:stats_lib", + "//source/common/upstream:health_discovery_service_lib", + "//test/mocks/access_log:access_log_mocks", + "//test/mocks/event:event_mocks", + "//test/mocks/grpc:grpc_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:utility_lib", + "@envoy_api//envoy/api/v2/endpoint:load_report_cc", + "@envoy_api//envoy/service/discovery/v2:hds_cc", + ], +) + envoy_cc_test( name = "logical_dns_cluster_test", srcs = ["logical_dns_cluster_test.cc"], @@ -169,6 +190,7 @@ envoy_cc_test( "//source/common/upstream:upstream_lib", "//source/extensions/transport_sockets/raw_buffer:config", "//test/mocks:common_lib", + "//test/mocks/local_info:local_info_mocks", "//test/mocks/network:network_mocks", "//test/mocks/runtime:runtime_mocks", "//test/mocks/ssl:ssl_mocks", @@ -323,6 +345,7 @@ envoy_cc_test( "//source/common/upstream:upstream_lib", "//source/extensions/transport_sockets/raw_buffer:config", "//test/mocks:common_lib", + "//test/mocks/local_info:local_info_mocks", "//test/mocks/network:network_mocks", "//test/mocks/runtime:runtime_mocks", "//test/mocks/ssl:ssl_mocks", @@ -335,9 +358,11 @@ envoy_cc_test_library( name = "utility_lib", hdrs = ["utility.h"], deps = [ + "//include/envoy/stats:stats_interface", "//source/common/config:cds_json_lib", "//source/common/json:json_loader_lib", "//source/common/network:utility_lib", + "//source/common/stats:stats_lib", "//source/common/upstream:upstream_includes", "//source/common/upstream:upstream_lib", ], diff --git a/test/common/upstream/cluster_manager_impl_test.cc b/test/common/upstream/cluster_manager_impl_test.cc index a424ca180f1c1..6507286e8ad05 100644 --- a/test/common/upstream/cluster_manager_impl_test.cc +++ b/test/common/upstream/cluster_manager_impl_test.cc @@ -3,6 +3,7 @@ #include "envoy/admin/v2alpha/config_dump.pb.h" #include "envoy/network/listen_socket.h" +#include "envoy/stats/stats.h" #include "envoy/upstream/upstream.h" #include "common/config/bootstrap_json.h" @@ -22,6 +23,7 @@ #include "test/mocks/runtime/mocks.h" #include "test/mocks/secret/mocks.h" #include "test/mocks/server/mocks.h" +#include "test/mocks/tcp/mocks.h" #include "test/mocks/thread_local/mocks.h" #include "test/mocks/upstream/mocks.h" #include "test/test_common/threadsafe_singleton_injector.h" @@ -51,13 +53,14 @@ namespace { class TestClusterManagerFactory : public ClusterManagerFactory { public: TestClusterManagerFactory() { - ON_CALL(*this, clusterFromProto_(_, _, _, _)) + ON_CALL(*this, clusterFromProto_(_, _, _, _, _)) .WillByDefault(Invoke([&](const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, + AccessLog::AccessLogManager& log_manager, bool added_via_api) -> ClusterSharedPtr { - return ClusterImplBase::create(cluster, cm, stats_, tls_, dns_resolver_, - ssl_context_manager_, runtime_, random_, dispatcher_, - local_info_, outlier_event_logger, added_via_api); + return ClusterImplBase::create( + cluster, cm, stats_, tls_, dns_resolver_, ssl_context_manager_, runtime_, random_, + dispatcher_, log_manager, local_info_, outlier_event_logger, added_via_api); })); } @@ -67,10 +70,17 @@ class TestClusterManagerFactory : public ClusterManagerFactory { return Http::ConnectionPool::InstancePtr{allocateConnPool_(host)}; } + Tcp::ConnectionPool::InstancePtr + allocateTcpConnPool(Event::Dispatcher&, HostConstSharedPtr host, ResourcePriority, + const Network::ConnectionSocket::OptionsSharedPtr&) override { + return Tcp::ConnectionPool::InstancePtr{allocateTcpConnPool_(host)}; + } + ClusterSharedPtr clusterFromProto(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, + AccessLog::AccessLogManager& log_manager, bool added_via_api) override { - return clusterFromProto_(cluster, cm, outlier_event_logger, added_via_api); + return clusterFromProto_(cluster, cm, outlier_event_logger, log_manager, added_via_api); } CdsApiPtr createCds(const envoy::api::v2::core::ConfigSource&, @@ -97,10 +107,11 @@ class TestClusterManagerFactory : public ClusterManagerFactory { const LocalInfo::LocalInfo& local_info, AccessLog::AccessLogManager& log_manager, Server::Admin& admin)); MOCK_METHOD1(allocateConnPool_, Http::ConnectionPool::Instance*(HostConstSharedPtr host)); - MOCK_METHOD4(clusterFromProto_, + MOCK_METHOD1(allocateTcpConnPool_, Tcp::ConnectionPool::Instance*(HostConstSharedPtr host)); + MOCK_METHOD5(clusterFromProto_, ClusterSharedPtr(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, - bool added_via_api)); + AccessLog::AccessLogManager& log_manager, bool added_via_api)); MOCK_METHOD0(createCds_, CdsApi*()); Stats::IsolatedStoreImpl stats_; @@ -111,10 +122,50 @@ class TestClusterManagerFactory : public ClusterManagerFactory { NiceMock random_; Ssl::ContextManagerImpl ssl_context_manager_{runtime_}; NiceMock dispatcher_; - LocalInfo::MockLocalInfo local_info_; + NiceMock local_info_; Secret::MockSecretManager secret_manager_; }; +// Helper to intercept calls to postThreadLocalClusterUpdate. +class MockLocalClusterUpdate { +public: + MOCK_METHOD3(post, void(uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed)); +}; + +// Override postThreadLocalClusterUpdate so we can test that merged updates calls +// it with the right values at the right times. +class TestClusterManagerImpl : public ClusterManagerImpl { +public: + TestClusterManagerImpl(const envoy::config::bootstrap::v2::Bootstrap& bootstrap, + ClusterManagerFactory& factory, Stats::Store& stats, + ThreadLocal::Instance& tls, Runtime::Loader& runtime, + Runtime::RandomGenerator& random, const LocalInfo::LocalInfo& local_info, + AccessLog::AccessLogManager& log_manager, + Event::Dispatcher& main_thread_dispatcher, Server::Admin& admin, + SystemTimeSource& system_time_source, + MonotonicTimeSource& monotonic_time_source, + MockLocalClusterUpdate& local_cluster_update) + : ClusterManagerImpl(bootstrap, factory, stats, tls, runtime, random, local_info, log_manager, + main_thread_dispatcher, admin, system_time_source, + monotonic_time_source), + local_cluster_update_(local_cluster_update) {} + +protected: + void postThreadLocalClusterUpdate(const Cluster&, uint32_t priority, + const HostVector& hosts_added, + const HostVector& hosts_removed) override { + local_cluster_update_.post(priority, hosts_added, hosts_removed); + } + MockLocalClusterUpdate& local_cluster_update_; +}; + +envoy::config::bootstrap::v2::Bootstrap parseBootstrapFromV2Yaml(const std::string& yaml) { + envoy::config::bootstrap::v2::Bootstrap bootstrap; + MessageUtil::loadFromYaml(yaml, bootstrap); + return bootstrap; +} + class ClusterManagerImplTest : public testing::Test { public: void create(const envoy::config::bootstrap::v2::Bootstrap& bootstrap) { @@ -124,6 +175,39 @@ class ClusterManagerImplTest : public testing::Test { monotonic_time_source_)); } + void createWithLocalClusterUpdate(const bool enable_merge_window = true) { + std::string yaml = R"EOF( + static_resources: + clusters: + - name: cluster_1 + connect_timeout: 0.250s + type: STATIC + lb_policy: ROUND_ROBIN + hosts: + - socket_address: + address: "127.0.0.1" + port_value: 11001 + - socket_address: + address: "127.0.0.1" + port_value: 11002 + )EOF"; + const std::string merge_window = R"EOF( + common_lb_config: + update_merge_window: 3s + )EOF"; + + if (enable_merge_window) { + yaml += merge_window; + } + + const auto& bootstrap = parseBootstrapFromV2Yaml(yaml); + + cluster_manager_.reset(new TestClusterManagerImpl( + bootstrap, factory_, factory_.stats_, factory_.tls_, factory_.runtime_, factory_.random_, + factory_.local_info_, log_manager_, factory_.dispatcher_, admin_, system_time_source_, + monotonic_time_source_, local_cluster_update_)); + } + void checkStats(uint64_t added, uint64_t modified, uint64_t removed, uint64_t active, uint64_t warming) { EXPECT_EQ(added, factory_.stats_.counter("cluster_manager.cluster_added").value()); @@ -143,24 +227,33 @@ class ClusterManagerImplTest : public testing::Test { EXPECT_EQ(expected_clusters_config_dump.DebugString(), clusters_config_dump.DebugString()); } + envoy::api::v2::core::Metadata buildMetadata(const std::string& version) const { + envoy::api::v2::core::Metadata metadata; + + if (version != "") { + Envoy::Config::Metadata::mutableMetadataValue( + metadata, Config::MetadataFilters::get().ENVOY_LB, "version") + .set_string_value(version); + } + + return metadata; + } + NiceMock factory_; std::unique_ptr cluster_manager_; AccessLog::MockAccessLogManager log_manager_; NiceMock admin_; NiceMock system_time_source_; NiceMock monotonic_time_source_; + MockLocalClusterUpdate local_cluster_update_; }; envoy::config::bootstrap::v2::Bootstrap parseBootstrapFromJson(const std::string& json_string) { envoy::config::bootstrap::v2::Bootstrap bootstrap; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Config::BootstrapJson::translateClusterManagerBootstrap(*json_object_ptr, bootstrap); - return bootstrap; -} - -envoy::config::bootstrap::v2::Bootstrap parseBootstrapFromV2Yaml(const std::string& yaml) { - envoy::config::bootstrap::v2::Bootstrap bootstrap; - MessageUtil::loadFromYaml(yaml, bootstrap); + Stats::StatsOptionsImpl stats_options; + Config::BootstrapJson::translateClusterManagerBootstrap(*json_object_ptr, bootstrap, + stats_options); return bootstrap; } @@ -180,6 +273,27 @@ TEST_F(ClusterManagerImplTest, MultipleProtocolClusterFail) { "'protocol_selection' values"); } +TEST_F(ClusterManagerImplTest, MultipleHealthCheckFail) { + const std::string yaml = R"EOF( + static_resources: + clusters: + - name: service_google + connect_timeout: 0.25s + health_checks: + - timeout: 1s + interval: 1s + http_health_check: + path: "/blah" + - timeout: 1s + interval: 1s + http_health_check: + path: "/" + )EOF"; + + EXPECT_THROW_WITH_MESSAGE(create(parseBootstrapFromV2Yaml(yaml)), EnvoyException, + "Multiple health checks not supported"); +} + TEST_F(ClusterManagerImplTest, MultipleProtocolCluster) { EXPECT_CALL(system_time_source_, currentTime()) .WillRepeatedly(Return(SystemTime(std::chrono::milliseconds(1234567891234)))); @@ -497,7 +611,7 @@ class ClusterManagerImplThreadAwareLbTest : public ClusterManagerImplTest { cluster1->info_->lb_type_ = lb_type; InSequence s; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); ON_CALL(*cluster1, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); create(parseBootstrapFromJson(json)); @@ -600,6 +714,8 @@ TEST_F(ClusterManagerImplTest, UnknownCluster) { EXPECT_EQ(nullptr, cluster_manager_->get("hello")); EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster("hello", ResourcePriority::Default, Http::Protocol::Http2, nullptr)); + EXPECT_EQ(nullptr, + cluster_manager_->tcpConnPoolForCluster("hello", ResourcePriority::Default, nullptr)); EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello", nullptr), EnvoyException); EXPECT_THROW(cluster_manager_->httpAsyncClientForCluster("hello"), EnvoyException); factory_.tls_.shutdownThread(); @@ -680,11 +796,11 @@ TEST_F(ClusterManagerImplTest, InitializeOrder) { // This part tests static init. InSequence s; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cds_cluster)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cds_cluster)); ON_CALL(*cds_cluster, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); ON_CALL(*cluster1, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster2)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster2)); ON_CALL(*cluster2, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Secondary)); EXPECT_CALL(factory_, createCds_()).WillOnce(Return(cds)); EXPECT_CALL(*cds, setInitializedCb(_)); @@ -711,16 +827,16 @@ TEST_F(ClusterManagerImplTest, InitializeOrder) { std::shared_ptr cluster5(new NiceMock()); cluster5->info_->name_ = "cluster5"; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster3)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster3)); ON_CALL(*cluster3, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Secondary)); cluster_manager_->addOrUpdateCluster(defaultStaticCluster("cluster3"), "version1"); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster4)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster4)); ON_CALL(*cluster4, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); EXPECT_CALL(*cluster4, initialize(_)); cluster_manager_->addOrUpdateCluster(defaultStaticCluster("cluster4"), "version2"); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster5)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster5)); ON_CALL(*cluster5, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Secondary)); cluster_manager_->addOrUpdateCluster(defaultStaticCluster("cluster5"), "version3"); @@ -835,7 +951,7 @@ TEST_F(ClusterManagerImplTest, DynamicRemoveWithLocalCluster) { std::shared_ptr foo(new NiceMock()); foo->info_->name_ = "foo"; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, false)).WillOnce(Return(foo)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, false)).WillOnce(Return(foo)); ON_CALL(*foo, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); EXPECT_CALL(*foo, initialize(_)); @@ -846,7 +962,7 @@ TEST_F(ClusterManagerImplTest, DynamicRemoveWithLocalCluster) { // cluster in its load balancer. std::shared_ptr cluster1(new NiceMock()); cluster1->info_->name_ = "cluster1"; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, true)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, true)).WillOnce(Return(cluster1)); ON_CALL(*cluster1, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); EXPECT_CALL(*cluster1, initialize(_)); cluster_manager_->addOrUpdateCluster(defaultStaticCluster("cluster1"), ""); @@ -892,7 +1008,7 @@ TEST_F(ClusterManagerImplTest, RemoveWarmingCluster) { cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); std::shared_ptr cluster1(new NiceMock()); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); EXPECT_CALL(*cluster1, initializePhase()).Times(0); EXPECT_CALL(*cluster1, initialize(_)); EXPECT_TRUE( @@ -940,7 +1056,7 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { cluster_manager_->addThreadLocalClusterUpdateCallbacks(*callbacks); std::shared_ptr cluster1(new NiceMock()); - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); EXPECT_CALL(*cluster1, initializePhase()).Times(0); EXPECT_CALL(*cluster1, initialize(_)); EXPECT_CALL(*callbacks, onClusterAddOrUpdate(_)).Times(1); @@ -962,7 +1078,7 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { std::shared_ptr cluster2(new NiceMock()); cluster2->prioritySet().getMockHostSet(0)->hosts_ = { makeTestHost(cluster2->info_, "tcp://127.0.0.1:80")}; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster2)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster2)); EXPECT_CALL(*cluster2, initializePhase()).Times(0); EXPECT_CALL(*cluster2, initialize(_)) .WillOnce(Invoke([cluster1](std::function initialize_callback) { @@ -979,6 +1095,11 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { EXPECT_EQ(cp, cluster_manager_->httpConnPoolForCluster("fake_cluster", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Tcp::ConnectionPool::MockInstance* cp2 = new Tcp::ConnectionPool::MockInstance(); + EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp2)); + EXPECT_EQ(cp2, cluster_manager_->tcpConnPoolForCluster("fake_cluster", ResourcePriority::Default, + nullptr)); + Network::MockClientConnection* connection = new Network::MockClientConnection(); ON_CALL(*cluster2->info_, features()) .WillByDefault(Return(ClusterInfo::Features::CLOSE_CONNECTIONS_ON_HOST_HEALTH_FAILURE)); @@ -989,10 +1110,12 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { auto conn_info = cluster_manager_->tcpConnForCluster("fake_cluster", nullptr); EXPECT_EQ(conn_info.connection_.get(), connection); - // Now remove it. This should drain the connection pool, but not affect + // Now remove the cluster. This should drain the connection pools, but not affect // tcp connections. Http::ConnectionPool::Instance::DrainedCb drained_cb; + Tcp::ConnectionPool::Instance::DrainedCb drained_cb2; EXPECT_CALL(*cp, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb)); + EXPECT_CALL(*cp2, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb2)); EXPECT_CALL(*callbacks, onClusterRemoval(_)).Times(1); EXPECT_TRUE(cluster_manager_->removeCluster("fake_cluster")); EXPECT_EQ(nullptr, cluster_manager_->get("fake_cluster")); @@ -1007,6 +1130,7 @@ TEST_F(ClusterManagerImplTest, DynamicAddRemove) { EXPECT_FALSE(cluster_manager_->removeCluster("foo")); drained_cb(); + drained_cb2(); checkStats(1 /*added*/, 1 /*modified*/, 1 /*removed*/, 0 /*active*/, 0 /*warming*/); @@ -1020,7 +1144,7 @@ TEST_F(ClusterManagerImplTest, addOrUpdateClusterStaticExists) { fmt::sprintf("{%s}", clustersJson({defaultStaticClusterJson("some_cluster")})); std::shared_ptr cluster1(new NiceMock()); InSequence s; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); ON_CALL(*cluster1, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); EXPECT_CALL(*cluster1, initialize(_)); @@ -1064,7 +1188,7 @@ TEST_F(ClusterManagerImplTest, CloseHttpConnectionsOnHealthFailure) { { InSequence s; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); EXPECT_CALL(health_checker, addHostCheckCompleteCb(_)); EXPECT_CALL(outlier_detector, addChangedStateCb(_)); EXPECT_CALL(*cluster1, initialize(_)) @@ -1104,6 +1228,66 @@ TEST_F(ClusterManagerImplTest, CloseHttpConnectionsOnHealthFailure) { EXPECT_TRUE(Mock::VerifyAndClearExpectations(cluster1.get())); } +// Test that we close all TCP connection pool connections when there is a host health failure. +TEST_F(ClusterManagerImplTest, CloseTcpConnectionPoolsOnHealthFailure) { + const std::string json = + fmt::sprintf("{%s}", clustersJson({defaultStaticClusterJson("some_cluster")})); + std::shared_ptr cluster1(new NiceMock()); + cluster1->info_->name_ = "some_cluster"; + HostSharedPtr test_host = makeTestHost(cluster1->info_, "tcp://127.0.0.1:80"); + cluster1->prioritySet().getMockHostSet(0)->hosts_ = {test_host}; + ON_CALL(*cluster1, initializePhase()).WillByDefault(Return(Cluster::InitializePhase::Primary)); + + MockHealthChecker health_checker; + ON_CALL(*cluster1, healthChecker()).WillByDefault(Return(&health_checker)); + + Outlier::MockDetector outlier_detector; + ON_CALL(*cluster1, outlierDetector()).WillByDefault(Return(&outlier_detector)); + + Tcp::ConnectionPool::MockInstance* cp1 = new Tcp::ConnectionPool::MockInstance(); + Tcp::ConnectionPool::MockInstance* cp2 = new Tcp::ConnectionPool::MockInstance(); + + { + InSequence s; + + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(health_checker, addHostCheckCompleteCb(_)); + EXPECT_CALL(outlier_detector, addChangedStateCb(_)); + EXPECT_CALL(*cluster1, initialize(_)) + .WillOnce(Invoke([cluster1](std::function initialize_callback) { + // Test inline init. + initialize_callback(); + })); + create(parseBootstrapFromJson(json)); + + EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp1)); + cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::Default, nullptr); + + outlier_detector.runCallbacks(test_host); + health_checker.runCallbacks(test_host, HealthTransition::Unchanged); + + EXPECT_CALL(*cp1, drainConnections()); + test_host->healthFlagSet(Host::HealthFlag::FAILED_OUTLIER_CHECK); + outlier_detector.runCallbacks(test_host); + + EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(cp2)); + cluster_manager_->tcpConnPoolForCluster("some_cluster", ResourcePriority::High, nullptr); + } + + // Order of these calls is implementation dependent, so can't sequence them! + EXPECT_CALL(*cp1, drainConnections()); + EXPECT_CALL(*cp2, drainConnections()); + test_host->healthFlagSet(Host::HealthFlag::FAILED_ACTIVE_HC); + health_checker.runCallbacks(test_host, HealthTransition::Changed); + + test_host->healthFlagClear(Host::HealthFlag::FAILED_OUTLIER_CHECK); + outlier_detector.runCallbacks(test_host); + test_host->healthFlagClear(Host::HealthFlag::FAILED_ACTIVE_HC); + health_checker.runCallbacks(test_host, HealthTransition::Changed); + + EXPECT_TRUE(Mock::VerifyAndClearExpectations(cluster1.get())); +} + // Test that we close all TCP connection pool connections when there is a host health failure, when // configured to do so. TEST_F(ClusterManagerImplTest, CloseTcpConnectionsOnHealthFailure) { @@ -1136,7 +1320,7 @@ TEST_F(ClusterManagerImplTest, CloseTcpConnectionsOnHealthFailure) { { InSequence s; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); EXPECT_CALL(health_checker, addHostCheckCompleteCb(_)); EXPECT_CALL(outlier_detector, addChangedStateCb(_)); EXPECT_CALL(*cluster1, initialize(_)) @@ -1208,7 +1392,7 @@ TEST_F(ClusterManagerImplTest, DoNotCloseTcpConnectionsOnHealthFailure) { Network::MockClientConnection* connection1 = new NiceMock(); Host::CreateConnectionData conn_info1; - EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _)).WillOnce(Return(cluster1)); + EXPECT_CALL(factory_, clusterFromProto_(_, _, _, _, _)).WillOnce(Return(cluster1)); EXPECT_CALL(health_checker, addHostCheckCompleteCb(_)); EXPECT_CALL(outlier_detector, addChangedStateCb(_)); EXPECT_CALL(*cluster1, initialize(_)) @@ -1261,8 +1445,10 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { // Test for no hosts returning the correct values before we have hosts. EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, + nullptr)); EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1", nullptr).connection_); - EXPECT_EQ(2UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); + EXPECT_EQ(3UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); // Set up for an initialize callback. ReadyWatcher initialized; @@ -1303,14 +1489,41 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { Http::ConnectionPool::Instance::DrainedCb drained_cb_high; EXPECT_CALL(*cp1_high, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb_high)); + EXPECT_CALL(factory_, allocateTcpConnPool_(_)) + .Times(4) + .WillRepeatedly(ReturnNew()); + + // This should provide us a CP for each of the above hosts. + Tcp::ConnectionPool::MockInstance* tcp1 = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2 = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp1_high = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp2_high = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); + + EXPECT_NE(tcp1, tcp2); + EXPECT_NE(tcp1_high, tcp2_high); + EXPECT_NE(tcp1, tcp1_high); + + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb; + EXPECT_CALL(*tcp1, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb)); + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb_high; + EXPECT_CALL(*tcp1_high, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb_high)); + // Remove the first host, this should lead to the first cp being drained. dns_timer_->callback_(); dns_callback(TestUtility::makeDnsResponse({"127.0.0.2"})); drained_cb(); drained_cb = nullptr; - EXPECT_CALL(factory_.tls_.dispatcher_, deferredDelete_(_)).Times(2); + tcp_drained_cb(); + tcp_drained_cb = nullptr; + EXPECT_CALL(factory_.tls_.dispatcher_, deferredDelete_(_)).Times(4); drained_cb_high(); drained_cb_high = nullptr; + tcp_drained_cb_high(); + tcp_drained_cb_high = nullptr; // Make sure we get back the same connection pool for the 2nd host as we did before the change. Http::ConnectionPool::MockInstance* cp3 = @@ -1322,6 +1535,13 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemove) { EXPECT_EQ(cp2, cp3); EXPECT_EQ(cp2_high, cp3_high); + Tcp::ConnectionPool::MockInstance* tcp3 = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp3_high = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::High, nullptr)); + EXPECT_EQ(tcp2, tcp3); + EXPECT_EQ(tcp2_high, tcp3_high); + // Now add and remove a host that we never have a conn pool to. This should not lead to any // drain callbacks, etc. dns_timer_->callback_(); @@ -1368,14 +1588,23 @@ TEST_F(ClusterManagerImplTest, DynamicHostRemoveDefaultPriority) { EXPECT_CALL(factory_, allocateConnPool_(_)) .WillOnce(ReturnNew()); + EXPECT_CALL(factory_, allocateTcpConnPool_(_)) + .WillOnce(ReturnNew()); + Http::ConnectionPool::MockInstance* cp = dynamic_cast(cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + // Immediate drain, since this can happen with the HTTP codecs. EXPECT_CALL(*cp, addDrainedCallback(_)) .WillOnce(Invoke([](Http::ConnectionPool::Instance::DrainedCb cb) { cb(); })); + EXPECT_CALL(*tcp, addDrainedCallback(_)) + .WillOnce(Invoke([](Tcp::ConnectionPool::Instance::DrainedCb cb) { cb(); })); + // Remove the first host, this should lead to the cp being drained, without // crash. dns_timer_->callback_(); @@ -1391,6 +1620,13 @@ class MockConnPoolWithDestroy : public Http::ConnectionPool::MockInstance { MOCK_METHOD0(onDestroy, void()); }; +class MockTcpConnPoolWithDestroy : public Tcp::ConnectionPool::MockInstance { +public: + ~MockTcpConnPoolWithDestroy() { onDestroy(); } + + MOCK_METHOD0(onDestroy, void()); +}; + // Regression test for https://github.com/envoyproxy/envoy/issues/3518. Make sure we handle a // drain callback during CP destroy. TEST_F(ClusterManagerImplTest, ConnPoolDestroyWithDraining) { @@ -1427,18 +1663,27 @@ TEST_F(ClusterManagerImplTest, ConnPoolDestroyWithDraining) { MockConnPoolWithDestroy* mock_cp = new MockConnPoolWithDestroy(); EXPECT_CALL(factory_, allocateConnPool_(_)).WillOnce(Return(mock_cp)); + MockTcpConnPoolWithDestroy* mock_tcp = new MockTcpConnPoolWithDestroy(); + EXPECT_CALL(factory_, allocateTcpConnPool_(_)).WillOnce(Return(mock_tcp)); + Http::ConnectionPool::MockInstance* cp = dynamic_cast(cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + Tcp::ConnectionPool::MockInstance* tcp = dynamic_cast( + cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, nullptr)); + // Remove the first host, this should lead to the cp being drained. Http::ConnectionPool::Instance::DrainedCb drained_cb; EXPECT_CALL(*cp, addDrainedCallback(_)).WillOnce(SaveArg<0>(&drained_cb)); + Tcp::ConnectionPool::Instance::DrainedCb tcp_drained_cb; + EXPECT_CALL(*tcp, addDrainedCallback(_)).WillOnce(SaveArg<0>(&tcp_drained_cb)); dns_timer_->callback_(); dns_callback(TestUtility::makeDnsResponse({})); // The drained callback might get called when the CP is being destroyed. EXPECT_CALL(*mock_cp, onDestroy()).WillOnce(Invoke(drained_cb)); + EXPECT_CALL(*mock_tcp, onDestroy()).WillOnce(Invoke(tcp_drained_cb)); factory_.tls_.shutdownThread(); } @@ -1468,12 +1713,188 @@ TEST_F(ClusterManagerImplTest, OriginalDstInitialization) { // Test for no hosts returning the correct values before we have hosts. EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster( "cluster_1", ResourcePriority::Default, Http::Protocol::Http11, nullptr)); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnPoolForCluster("cluster_1", ResourcePriority::Default, + nullptr)); EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1", nullptr).connection_); - EXPECT_EQ(2UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); + EXPECT_EQ(3UL, factory_.stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); factory_.tls_.shutdownThread(); } +// Tests that all the HC/weight/metadata changes are delivered in one go, as long as +// there's no hosts changes in between. +// Also tests that if hosts are added/removed between mergeable updates, delivery will +// happen and the scheduled update will be canceled. +TEST_F(ClusterManagerImplTest, MergedUpdates) { + createWithLocalClusterUpdate(); + + // Ensure we see the right set of added/removed hosts on every call. + EXPECT_CALL(local_cluster_update_, post(_, _, _)) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // 1st removal. + EXPECT_EQ(0, priority); + EXPECT_EQ(0, hosts_added.size()); + EXPECT_EQ(1, hosts_removed.size()); + })) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // Triggerd by the 2 HC updates, it's a merged update so no added/removed + // hosts. + EXPECT_EQ(0, priority); + EXPECT_EQ(0, hosts_added.size()); + EXPECT_EQ(0, hosts_removed.size()); + })) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // 1st removed host added back. + EXPECT_EQ(0, priority); + EXPECT_EQ(1, hosts_added.size()); + EXPECT_EQ(0, hosts_removed.size()); + })) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // 1st removed host removed again, plus the 3 HC/weight/metadata updates that were + // waiting for delivery. + EXPECT_EQ(0, priority); + EXPECT_EQ(0, hosts_added.size()); + EXPECT_EQ(1, hosts_removed.size()); + })); + + Event::MockTimer* timer = new NiceMock(&factory_.dispatcher_); + const Cluster& cluster = cluster_manager_->clusters().begin()->second; + HostVectorSharedPtr hosts( + new HostVector(cluster.prioritySet().hostSetsPerPriority()[0]->hosts())); + HostsPerLocalitySharedPtr hosts_per_locality = std::make_shared(); + HostVector hosts_added; + HostVector hosts_removed; + + // The first update should be applied immediately, since it's not mergeable. + hosts_removed.push_back((*hosts)[0]); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); + + // This calls should be merged, since there are not added/removed hosts. + hosts_removed.clear(); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); + + // Ensure the merged updates were applied. + timer->callback_(); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); + + // Add the host back, the update should be immediately applied. + hosts_removed.clear(); + hosts_added.push_back((*hosts)[0]); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + EXPECT_EQ(2, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); + + // Now emit 3 updates that should be scheduled: metadata, HC, and weight. + hosts_added.clear(); + + (*hosts)[0]->metadata(buildMetadata("v1")); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + + (*hosts)[0]->healthFlagSet(Host::HealthFlag::FAILED_EDS_HEALTH); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + + (*hosts)[0]->weight(100); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + + // Updates not delivered yet. + EXPECT_EQ(2, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); + + // Remove the host again, should cancel the scheduled update and be delivered immediately. + hosts_removed.push_back((*hosts)[0]); + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + + EXPECT_EQ(3, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); +} + +// Tests that mergeable updates outside of a window get applied immediately. +TEST_F(ClusterManagerImplTest, MergedUpdatesOutOfWindow) { + createWithLocalClusterUpdate(); + + // Ensure we see the right set of added/removed hosts on every call. + EXPECT_CALL(local_cluster_update_, post(_, _, _)) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // HC update, immediately delivered. + EXPECT_EQ(0, priority); + EXPECT_EQ(0, hosts_added.size()); + EXPECT_EQ(0, hosts_removed.size()); + })); + + const Cluster& cluster = cluster_manager_->clusters().begin()->second; + HostVectorSharedPtr hosts( + new HostVector(cluster.prioritySet().hostSetsPerPriority()[0]->hosts())); + HostsPerLocalitySharedPtr hosts_per_locality = std::make_shared(); + HostVector hosts_added; + HostVector hosts_removed; + + // The first update should be applied immediately, because even though it's mergeable + // it's outside a merge window. + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.update_out_of_merge_window").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); +} + +// Tests that mergeable updates outside of a window get applied immediately when +// merging is disabled, and that the counters are correct. +TEST_F(ClusterManagerImplTest, MergedUpdatesOutOfWindowDisabled) { + createWithLocalClusterUpdate(false); + + // Ensure we see the right set of added/removed hosts on every call. + EXPECT_CALL(local_cluster_update_, post(_, _, _)) + .WillOnce(Invoke([](uint32_t priority, const HostVector& hosts_added, + const HostVector& hosts_removed) -> void { + // HC update, immediately delivered. + EXPECT_EQ(0, priority); + EXPECT_EQ(0, hosts_added.size()); + EXPECT_EQ(0, hosts_removed.size()); + })); + + const Cluster& cluster = cluster_manager_->clusters().begin()->second; + HostVectorSharedPtr hosts( + new HostVector(cluster.prioritySet().hostSetsPerPriority()[0]->hosts())); + HostsPerLocalitySharedPtr hosts_per_locality = std::make_shared(); + HostVector hosts_added; + HostVector hosts_removed; + + // The first update should be applied immediately, because even though it's mergeable + // and outside a merge window, merging is disabled. + cluster.prioritySet().hostSetsPerPriority()[0]->updateHosts( + hosts, hosts, hosts_per_locality, hosts_per_locality, {}, hosts_added, hosts_removed); + EXPECT_EQ(1, factory_.stats_.counter("cluster_manager.cluster_updated").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.cluster_updated_via_merge").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_out_of_merge_window").value()); + EXPECT_EQ(0, factory_.stats_.counter("cluster_manager.update_merge_cancelled").value()); +} + class ClusterManagerInitHelperTest : public testing::Test { public: MOCK_METHOD1(onClusterInit, void(Cluster& cluster)); diff --git a/test/common/upstream/eds_test.cc b/test/common/upstream/eds_test.cc index ee7e6139444af..1d8555da76228 100644 --- a/test/common/upstream/eds_test.cc +++ b/test/common/upstream/eds_test.cc @@ -168,6 +168,9 @@ TEST_F(EdsTest, EndpointMetadata) { Config::MetadataFilters::get().ENVOY_LB, Config::MetadataEnvoyLbKeys::get().CANARY) .set_bool_value(true); + Config::Metadata::mutableMetadataValue(*canary->mutable_metadata(), + Config::MetadataFilters::get().ENVOY_LB, "version") + .set_string_value("v1"); bool initialized = false; cluster_->initialize([&initialized] { initialized = true; }); @@ -177,30 +180,46 @@ TEST_F(EdsTest, EndpointMetadata) { auto& hosts = cluster_->prioritySet().hostSetsPerPriority()[0]->hosts(); EXPECT_EQ(hosts.size(), 2); - EXPECT_EQ(hosts[0]->metadata().filter_metadata_size(), 2); - EXPECT_EQ(Config::Metadata::metadataValue(hosts[0]->metadata(), + EXPECT_EQ(hosts[0]->metadata()->filter_metadata_size(), 2); + EXPECT_EQ(Config::Metadata::metadataValue(*hosts[0]->metadata(), Config::MetadataFilters::get().ENVOY_LB, "string_key") .string_value(), std::string("string_value")); - EXPECT_EQ(Config::Metadata::metadataValue(hosts[0]->metadata(), "custom_namespace", "num_key") + EXPECT_EQ(Config::Metadata::metadataValue(*hosts[0]->metadata(), "custom_namespace", "num_key") .number_value(), 1.1); - EXPECT_FALSE(Config::Metadata::metadataValue(hosts[0]->metadata(), + EXPECT_FALSE(Config::Metadata::metadataValue(*hosts[0]->metadata(), Config::MetadataFilters::get().ENVOY_LB, Config::MetadataEnvoyLbKeys::get().CANARY) .bool_value()); EXPECT_FALSE(hosts[0]->canary()); - EXPECT_EQ(hosts[1]->metadata().filter_metadata_size(), 1); - EXPECT_TRUE(Config::Metadata::metadataValue(hosts[1]->metadata(), + EXPECT_EQ(hosts[1]->metadata()->filter_metadata_size(), 1); + EXPECT_TRUE(Config::Metadata::metadataValue(*hosts[1]->metadata(), Config::MetadataFilters::get().ENVOY_LB, Config::MetadataEnvoyLbKeys::get().CANARY) .bool_value()); EXPECT_TRUE(hosts[1]->canary()); + EXPECT_EQ(Config::Metadata::metadataValue(*hosts[1]->metadata(), + Config::MetadataFilters::get().ENVOY_LB, "version") + .string_value(), + "v1"); // We don't rebuild with the exact same config. VERBOSE_EXPECT_NO_THROW(cluster_->onConfigUpdate(resources, "")); EXPECT_EQ(1UL, stats_.counter("cluster.name.update_no_rebuild").value()); + + // New resources with Metadata updated. + Config::Metadata::mutableMetadataValue(*canary->mutable_metadata(), + Config::MetadataFilters::get().ENVOY_LB, "version") + .set_string_value("v2"); + VERBOSE_EXPECT_NO_THROW(cluster_->onConfigUpdate(resources, "")); + auto& nhosts = cluster_->prioritySet().hostSetsPerPriority()[0]->hosts(); + EXPECT_EQ(nhosts.size(), 2); + EXPECT_EQ(Config::Metadata::metadataValue(*nhosts[1]->metadata(), + Config::MetadataFilters::get().ENVOY_LB, "version") + .string_value(), + "v2"); } // Validate that onConfigUpdate() updates endpoint health status. diff --git a/test/common/upstream/hds_test.cc b/test/common/upstream/hds_test.cc new file mode 100644 index 0000000000000..b175e525433e3 --- /dev/null +++ b/test/common/upstream/hds_test.cc @@ -0,0 +1,287 @@ +#include "envoy/service/discovery/v2/hds.pb.h" + +#include "common/ssl/context_manager_impl.h" +#include "common/stats/stats_impl.h" +#include "common/upstream/health_discovery_service.h" + +#include "test/mocks/access_log/mocks.h" +#include "test/mocks/event/mocks.h" +#include "test/mocks/grpc/mocks.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::_; + +using ::testing::AtLeast; + +namespace Envoy { +namespace Upstream { + +// Friend class of HdsDelegate, making it easier to access private fields +class HdsDelegateFriend { +public: + // Allows access to private function processMessage + void processPrivateMessage( + HdsDelegate& hd, + std::unique_ptr&& message) { + hd.processMessage(std::move(message)); + }; +}; + +class HdsTest : public testing::Test { +public: + HdsTest() + : retry_timer_(new Event::MockTimer()), server_response_timer_(new Event::MockTimer()), + async_client_(new Grpc::MockAsyncClient()) { + node_.set_id("foo"); + } + + // Creates an HdsDelegate + void createHdsDelegate() { + InSequence s; + EXPECT_CALL(dispatcher_, createTimer_(_)).WillOnce(Invoke([this](Event::TimerCb timer_cb) { + retry_timer_cb_ = timer_cb; + return retry_timer_; + })); + EXPECT_CALL(dispatcher_, createTimer_(_)) + .Times(AtLeast(1)) + .WillOnce(Invoke([this](Event::TimerCb timer_cb) { + server_response_timer_cb_ = timer_cb; + return server_response_timer_; + })); + hds_delegate_.reset(new HdsDelegate(node_, stats_store_, Grpc::AsyncClientPtr(async_client_), + dispatcher_, runtime_, stats_store_, ssl_context_manager_, + secret_manager_, random_, test_factory_, log_manager_)); + } + + // Creates a HealthCheckSpecifier message that contains one endpoint and one + // healthcheck + envoy::service::discovery::v2::HealthCheckSpecifier* createSimpleMessage() { + envoy::service::discovery::v2::HealthCheckSpecifier* msg = + new envoy::service::discovery::v2::HealthCheckSpecifier; + msg->mutable_interval()->set_seconds(1); + + auto* health_check = msg->add_health_check(); + health_check->set_cluster_name("anna"); + health_check->add_health_checks()->mutable_timeout()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_interval()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_unhealthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_healthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_grpc_health_check(); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_use_http2(false); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_path("/healthcheck"); + + auto* socket_address = + health_check->add_endpoints()->add_endpoints()->mutable_address()->mutable_socket_address(); + socket_address->set_address("127.0.0.0"); + socket_address->set_port_value(1234); + + return msg; + } + + envoy::api::v2::core::Node node_; + Event::MockDispatcher dispatcher_; + Stats::IsolatedStoreImpl stats_store_; + MockClusterInfoFactory test_factory_; + + std::unique_ptr hds_delegate_; + HdsDelegateFriend hds_delegate_friend_; + + Event::MockTimer* retry_timer_; + Event::TimerCb retry_timer_cb_; + Event::MockTimer* server_response_timer_; + Event::TimerCb server_response_timer_cb_; + + std::shared_ptr cluster_info_{ + new NiceMock()}; + std::unique_ptr message; + Grpc::MockAsyncStream async_stream_; + Grpc::MockAsyncClient* async_client_; + Runtime::MockLoader runtime_; + Ssl::ContextManagerImpl ssl_context_manager_{runtime_}; + Secret::MockSecretManager secret_manager_; + NiceMock random_; + NiceMock log_manager_; +}; + +// Test if processMessage processes endpoints from a HealthCheckSpecifier +// message correctly +TEST_F(HdsTest, TestProcessMessageEndpoints) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + createHdsDelegate(); + + // Create Message + // - Cluster "anna0" with 3 endpoints + // - Cluster "anna1" with 3 endpoints + message.reset(new envoy::service::discovery::v2::HealthCheckSpecifier); + message->mutable_interval()->set_seconds(1); + + for (int i = 0; i < 2; i++) { + auto* health_check = message->add_health_check(); + health_check->set_cluster_name("anna" + std::to_string(i)); + for (int j = 0; j < 3; j++) { + auto* address = health_check->add_endpoints()->add_endpoints()->mutable_address(); + address->mutable_socket_address()->set_address("127.0.0." + std::to_string(i)); + address->mutable_socket_address()->set_port_value(1234 + j); + } + } + + // Process message + EXPECT_CALL(test_factory_, createClusterInfo(_, _, _, _, _, _, _)).Times(2); + hds_delegate_friend_.processPrivateMessage(*hds_delegate_, std::move(message)); + + // Check Correctness + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 3; j++) { + auto& host = + hds_delegate_->hdsClusters()[i]->prioritySet().hostSetsPerPriority()[0]->hosts()[j]; + EXPECT_EQ(host->address()->ip()->addressAsString(), "127.0.0." + std::to_string(i)); + EXPECT_EQ(host->address()->ip()->port(), 1234 + j); + } + } +} + +// Test if processMessage processes health checks from a HealthCheckSpecifier +// message correctly +TEST_F(HdsTest, TestProcessMessageHealthChecks) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + createHdsDelegate(); + + // Create Message + // - Cluster "minkowski0" with 2 health_checks + // - Cluster "minkowski1" with 3 health_checks + message.reset(new envoy::service::discovery::v2::HealthCheckSpecifier); + message->mutable_interval()->set_seconds(1); + + for (int i = 0; i < 2; i++) { + auto* health_check = message->add_health_check(); + health_check->set_cluster_name("minkowski" + std::to_string(i)); + for (int j = 0; j < i + 2; j++) { + auto hc = health_check->add_health_checks(); + hc->mutable_timeout()->set_seconds(i); + hc->mutable_interval()->set_seconds(j); + hc->mutable_unhealthy_threshold()->set_value(j + 1); + hc->mutable_healthy_threshold()->set_value(j + 1); + hc->mutable_grpc_health_check(); + hc->mutable_http_health_check()->set_use_http2(false); + hc->mutable_http_health_check()->set_path("/healthcheck"); + } + } + + // Process message + EXPECT_CALL(test_factory_, createClusterInfo(_, _, _, _, _, _, _)) + .WillRepeatedly(Return(cluster_info_)); + + hds_delegate_friend_.processPrivateMessage(*hds_delegate_, std::move(message)); + + // Check Correctness + EXPECT_EQ(hds_delegate_->hdsClusters()[0]->healthCheckers().size(), 2); + EXPECT_EQ(hds_delegate_->hdsClusters()[1]->healthCheckers().size(), 3); +} + +// Tests OnReceiveMessage given a minimal HealthCheckSpecifier message +TEST_F(HdsTest, TestMinimalOnReceiveMessage) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + createHdsDelegate(); + + // Create Message + message.reset(new envoy::service::discovery::v2::HealthCheckSpecifier); + message->mutable_interval()->set_seconds(1); + + // Process message + EXPECT_CALL(*server_response_timer_, enableTimer(_)).Times(AtLeast(1)); + hds_delegate_->onReceiveMessage(std::move(message)); +} + +// Tests that SendResponse responds to the server in a timely fashion +// given a minimal HealthCheckSpecifier message +TEST_F(HdsTest, TestMinimalSendResponse) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + createHdsDelegate(); + + // Create Message + message.reset(new envoy::service::discovery::v2::HealthCheckSpecifier); + message->mutable_interval()->set_seconds(1); + + // Process message and send 2 responses + EXPECT_CALL(*server_response_timer_, enableTimer(_)).Times(AtLeast(1)); + EXPECT_CALL(async_stream_, sendMessage(_, _)).Times(2); + hds_delegate_->onReceiveMessage(std::move(message)); + hds_delegate_->sendResponse(); + server_response_timer_cb_(); +} + +TEST_F(HdsTest, TestStreamConnectionFailure) { + EXPECT_CALL(*async_client_, start(_, _)) + .WillOnce(Return(nullptr)) + .WillOnce(Return(&async_stream_)); + EXPECT_CALL(*retry_timer_, enableTimer(_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + + // Test connection failure and retry + createHdsDelegate(); + retry_timer_cb_(); +} + +// TODO(lilika): Add unit tests for HdsDelegate::sendResponse() with healthy and +// unhealthy endpoints. + +// Tests that SendResponse responds to the server correctly given +// a HealthCheckSpecifier message that contains a single endpoint +// which times out +TEST_F(HdsTest, TestSendResponseOneEndpointTimeout) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + EXPECT_CALL(async_stream_, sendMessage(_, _)); + createHdsDelegate(); + + // Create Message + message.reset(createSimpleMessage()); + + Network::MockClientConnection* connection_ = new NiceMock(); + EXPECT_CALL(dispatcher_, createClientConnection_(_, _, _, _)).WillRepeatedly(Return(connection_)); + EXPECT_CALL(*server_response_timer_, enableTimer(_)).Times(2); + EXPECT_CALL(async_stream_, sendMessage(_, false)); + EXPECT_CALL(test_factory_, createClusterInfo(_, _, _, _, _, _, _)) + .WillOnce(Return(cluster_info_)); + EXPECT_CALL(*connection_, setBufferLimits(_)); + EXPECT_CALL(dispatcher_, deferredDelete_(_)); + // Process message + hds_delegate_->onReceiveMessage(std::move(message)); + connection_->raiseEvent(Network::ConnectionEvent::Connected); + + // Send Response + auto msg = hds_delegate_->sendResponse(); + + // Correctness + EXPECT_EQ(msg.endpoint_health_response().endpoints_health(0).health_status(), + envoy::api::v2::core::HealthStatus::UNHEALTHY); + EXPECT_EQ(msg.endpoint_health_response() + .endpoints_health(0) + .endpoint() + .address() + .socket_address() + .address(), + "127.0.0.0"); + EXPECT_EQ(msg.endpoint_health_response() + .endpoints_health(0) + .endpoint() + .address() + .socket_address() + .port_value(), + 1234); +} + +} // namespace Upstream +} // namespace Envoy diff --git a/test/common/upstream/health_checker_impl_test.cc b/test/common/upstream/health_checker_impl_test.cc index dc041e7e9cb91..3606f7a9fa2f0 100644 --- a/test/common/upstream/health_checker_impl_test.cc +++ b/test/common/upstream/health_checker_impl_test.cc @@ -16,6 +16,7 @@ #include "test/common/http/common.h" #include "test/common/upstream/utility.h" +#include "test/mocks/access_log/mocks.h" #include "test/mocks/network/mocks.h" #include "test/mocks/runtime/mocks.h" #include "test/mocks/upstream/mocks.h" @@ -58,9 +59,10 @@ TEST(HealthCheckerFactoryTest, GrpcHealthCheckHTTP2NotConfiguredException) { Runtime::MockLoader runtime; Runtime::MockRandomGenerator random; Event::MockDispatcher dispatcher; + AccessLog::MockAccessLogManager log_manager; EXPECT_THROW_WITH_MESSAGE(HealthCheckerFactory::create(createGrpcHealthCheckConfig(), cluster, - runtime, random, dispatcher), + runtime, random, dispatcher, log_manager), EnvoyException, "fake_cluster cluster must support HTTP/2 for gRPC healthchecking"); } @@ -74,10 +76,11 @@ TEST(HealthCheckerFactoryTest, createGrpc) { Runtime::MockLoader runtime; Runtime::MockRandomGenerator random; Event::MockDispatcher dispatcher; + AccessLog::MockAccessLogManager log_manager; EXPECT_NE(nullptr, dynamic_cast( HealthCheckerFactory::create(createGrpcHealthCheckConfig(), cluster, - runtime, random, dispatcher) + runtime, random, dispatcher, log_manager) .get())); } @@ -114,7 +117,8 @@ class HttpHealthCheckerImplTest : public testing::Test { const envoy::api::v2::endpoint::Endpoint::HealthCheckConfig> HostWithHealthCheckMap; - HttpHealthCheckerImplTest() : cluster_(new NiceMock()) {} + HttpHealthCheckerImplTest() + : cluster_(new NiceMock()), event_logger_(new MockHealthCheckEventLogger()) {} void setupNoServiceValidationHCWithHttp2() { const std::string yaml = R"EOF( @@ -131,7 +135,30 @@ class HttpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); + health_checker_->addHostCheckCompleteCb( + [this](HostSharedPtr host, HealthTransition changed_state) -> void { + onHostStatus(host, changed_state); + }); + } + + void setupIntervalJitterPercent() { + const std::string yaml = R"EOF( + timeout: 1s + interval: 1s + no_traffic_interval: 5s + interval_jitter_percent: 40 + unhealthy_threshold: 2 + healthy_threshold: 2 + http_health_check: + service_name: locations + path: /healthcheck + )EOF"; + + health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -152,7 +179,8 @@ class HttpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -174,7 +202,8 @@ class HttpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV1Json(json), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -198,7 +227,8 @@ class HttpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -220,7 +250,8 @@ class HttpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV1Json(json), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -242,7 +273,8 @@ class HttpHealthCheckerImplTest : public testing::Test { host); health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -290,10 +322,29 @@ class HttpHealthCheckerImplTest : public testing::Test { key: user-agent value: CoolEnvoy/HC append: false + - header: + key: x-protocol + value: "%PROTOCOL%" + - header: + key: x-upstream-metadata + value: "%UPSTREAM_METADATA([\"namespace\", \"key\"])%" + - header: + key: x-downstream-remote-address-without-port + value: "%DOWNSTREAM_REMOTE_ADDRESS_WITHOUT_PORT%" + - header: + key: x-downstream-local-address + value: "%DOWNSTREAM_LOCAL_ADDRESS%" + - header: + key: x-downstream-local-address-without-port + value: "%DOWNSTREAM_LOCAL_ADDRESS_WITHOUT_PORT%" + - header: + key: x-start-time + value: "%START_TIME(%s.%9f)%" )EOF"; health_checker_.reset(new TestHttpHealthCheckerImpl(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -390,6 +441,7 @@ class HttpHealthCheckerImplTest : public testing::Test { std::shared_ptr health_checker_; NiceMock runtime_; NiceMock random_; + MockHealthCheckEventLogger* event_logger_{}; std::list connection_index_{}; std::list codec_index_{}; const HostWithHealthCheckMap health_checker_map_{}; @@ -416,6 +468,94 @@ TEST_F(HttpHealthCheckerImplTest, Success) { EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthy()); } +TEST_F(HttpHealthCheckerImplTest, SuccessIntervalJitter) { + setupNoServiceValidationHC(); + EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Unchanged)).Times(testing::AnyNumber()); + + cluster_->prioritySet().getMockHostSet(0)->hosts_ = { + makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; + expectSessionCreate(); + expectStreamCreate(0); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + health_checker_->start(); + + EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthy()); + + for (int i = 0; i < 50000; i += 239) { + EXPECT_CALL(random_, random()).WillOnce(Return(i)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + expectStreamCreate(0); + test_sessions_[0]->interval_timer_->callback_(); + // the jitter is 1000ms here + EXPECT_CALL(*test_sessions_[0]->interval_timer_, + enableTimer(std::chrono::milliseconds(5000 + i % 1000))); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + } +} + +TEST_F(HttpHealthCheckerImplTest, SuccessIntervalJitterPercentNoTraffic) { + setupIntervalJitterPercent(); + EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Unchanged)).Times(testing::AnyNumber()); + + cluster_->prioritySet().getMockHostSet(0)->hosts_ = { + makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; + expectSessionCreate(); + expectStreamCreate(0); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + health_checker_->start(); + + EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthy()); + + for (int i = 0; i < 50000; i += 239) { + EXPECT_CALL(random_, random()).WillOnce(Return(i)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + expectStreamCreate(0); + test_sessions_[0]->interval_timer_->callback_(); + // the jitter is 40% of 5000, so should be 2000 + EXPECT_CALL(*test_sessions_[0]->interval_timer_, + enableTimer(std::chrono::milliseconds(5000 + i % 2000))); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + } +} + +TEST_F(HttpHealthCheckerImplTest, SuccessIntervalJitterPercent) { + setupIntervalJitterPercent(); + EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Unchanged)).Times(testing::AnyNumber()); + + cluster_->prioritySet().getMockHostSet(0)->hosts_ = { + makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; + cluster_->info_->stats().upstream_cx_total_.inc(); + expectSessionCreate(); + expectStreamCreate(0); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + health_checker_->start(); + + EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthy()); + + for (int i = 0; i < 50000; i += 239) { + EXPECT_CALL(random_, random()).WillOnce(Return(i)); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + expectStreamCreate(0); + test_sessions_[0]->interval_timer_->callback_(); + // the jitter is 40% of 1000, so should be 400 + EXPECT_CALL(*test_sessions_[0]->interval_timer_, + enableTimer(std::chrono::milliseconds(1000 + i % 400))); + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); + respond(0, "200", false, true, true); + } +} + TEST_F(HttpHealthCheckerImplTest, SuccessWithSpurious100Continue) { setupNoServiceValidationHC(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Unchanged)).Times(1); @@ -575,15 +715,28 @@ TEST_F(HttpHealthCheckerImplTest, SuccessServiceCheckWithCustomHostValue) { } TEST_F(HttpHealthCheckerImplTest, SuccessServiceCheckWithAdditionalHeaders) { - const Http::LowerCaseString headerOk("x-envoy-ok"); - const Http::LowerCaseString headerCool("x-envoy-cool"); - const Http::LowerCaseString headerAwesome("x-envoy-awesome"); - - const std::string valueOk = "ok"; - const std::string valueCool = "cool"; - const std::string valueAwesome = "awesome"; - - const std::string valueUserAgent = "CoolEnvoy/HC"; + const Http::LowerCaseString header_ok("x-envoy-ok"); + const Http::LowerCaseString header_cool("x-envoy-cool"); + const Http::LowerCaseString header_awesome("x-envoy-awesome"); + const Http::LowerCaseString upstream_metadata("x-upstream-metadata"); + const Http::LowerCaseString protocol("x-protocol"); + const Http::LowerCaseString downstream_remote_address_without_port( + "x-downstream-remote-address-without-port"); + const Http::LowerCaseString downstream_local_address("x-downstream-local-address"); + const Http::LowerCaseString downstream_local_address_without_port( + "x-downstream-local-address-without-port"); + const Http::LowerCaseString start_time("x-start-time"); + + const std::string value_ok = "ok"; + const std::string value_cool = "cool"; + const std::string value_awesome = "awesome"; + + const std::string value_user_agent = "CoolEnvoy/HC"; + const std::string value_upstream_metadata = "value"; + const std::string value_protocol = "HTTP/1.1"; + const std::string value_downstream_remote_address_without_port = "127.0.0.1"; + const std::string value_downstream_local_address = "127.0.0.1:0"; + const std::string value_downstream_local_address_without_port = "127.0.0.1"; setupServiceValidationWithAdditionalHeaders(); // requires non-empty `service_name` in config. @@ -591,20 +744,39 @@ TEST_F(HttpHealthCheckerImplTest, SuccessServiceCheckWithAdditionalHeaders) { .WillOnce(Return(true)); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Unchanged)).Times(1); - + auto metadata = TestUtility::parseYaml( + R"EOF( + filter_metadata: + namespace: + key: value + )EOF"); + + std::string current_start_time; cluster_->prioritySet().getMockHostSet(0)->hosts_ = { - makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; + makeTestHost(cluster_->info_, "tcp://127.0.0.1:80", metadata)}; cluster_->info_->stats().upstream_cx_total_.inc(); expectSessionCreate(); expectStreamCreate(0); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); EXPECT_CALL(test_sessions_[0]->request_encoder_, encodeHeaders(_, true)) - .WillOnce(Invoke([&](const Http::HeaderMap& headers, bool) { - EXPECT_EQ(headers.get(headerOk)->value().c_str(), valueOk); - EXPECT_EQ(headers.get(headerCool)->value().c_str(), valueCool); - EXPECT_EQ(headers.get(headerAwesome)->value().c_str(), valueAwesome); - - EXPECT_EQ(headers.UserAgent()->value().c_str(), valueUserAgent); + .WillRepeatedly(Invoke([&](const Http::HeaderMap& headers, bool) { + EXPECT_EQ(headers.get(header_ok)->value().c_str(), value_ok); + EXPECT_EQ(headers.get(header_cool)->value().c_str(), value_cool); + EXPECT_EQ(headers.get(header_awesome)->value().c_str(), value_awesome); + + EXPECT_EQ(headers.UserAgent()->value().c_str(), value_user_agent); + EXPECT_EQ(headers.get(upstream_metadata)->value().c_str(), value_upstream_metadata); + + EXPECT_EQ(headers.get(protocol)->value().c_str(), value_protocol); + EXPECT_EQ(headers.get(downstream_remote_address_without_port)->value().c_str(), + value_downstream_remote_address_without_port); + EXPECT_EQ(headers.get(downstream_local_address)->value().c_str(), + value_downstream_local_address); + EXPECT_EQ(headers.get(downstream_local_address_without_port)->value().c_str(), + value_downstream_local_address_without_port); + + EXPECT_NE(headers.get(start_time)->value().c_str(), current_start_time); + current_start_time = headers.get(start_time)->value().c_str(); })); health_checker_->start(); @@ -616,6 +788,10 @@ TEST_F(HttpHealthCheckerImplTest, SuccessServiceCheckWithAdditionalHeaders) { absl::optional health_checked_cluster("locations-production-iad"); respond(0, "200", false, true, false, health_checked_cluster); EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthy()); + + EXPECT_CALL(*test_sessions_[0]->timeout_timer_, enableTimer(_)); + expectStreamCreate(0); + test_sessions_[0]->interval_timer_->callback_(); } TEST_F(HttpHealthCheckerImplTest, ServiceDoesNotMatchFail) { @@ -624,6 +800,7 @@ TEST_F(HttpHealthCheckerImplTest, ServiceDoesNotMatchFail) { .WillOnce(Return(true)); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)).Times(1); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); cluster_->prioritySet().getMockHostSet(0)->hosts_ = { makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; @@ -651,6 +828,7 @@ TEST_F(HttpHealthCheckerImplTest, ServiceNotPresentInResponseFail) { .WillOnce(Return(true)); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)).Times(1); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); cluster_->prioritySet().getMockHostSet(0)->hosts_ = { makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; @@ -736,6 +914,7 @@ TEST_F(HttpHealthCheckerImplTest, SuccessStartFailedFailFirstServiceCheck) { test_sessions_[0]->interval_timer_->callback_(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "200", false, false, false, health_checked_cluster); @@ -772,6 +951,7 @@ TEST_F(HttpHealthCheckerImplTest, SuccessStartFailedSuccessFirst) { // Test fast success immediately moves us to healthy. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)).Times(1); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, true)); EXPECT_CALL(runtime_.snapshot_, getInteger("health_check.max_interval", _)).WillOnce(Return(500)); EXPECT_CALL(runtime_.snapshot_, getInteger("health_check.min_interval", _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(500))); @@ -817,6 +997,7 @@ TEST_F(HttpHealthCheckerImplTest, SuccessStartFailedFailFirst) { test_sessions_[0]->interval_timer_->callback_(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "200", false); @@ -833,6 +1014,7 @@ TEST_F(HttpHealthCheckerImplTest, HttpFail) { health_checker_->start(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "503", false); @@ -857,6 +1039,7 @@ TEST_F(HttpHealthCheckerImplTest, HttpFail) { test_sessions_[0]->interval_timer_->callback_(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "200", false); @@ -886,6 +1069,7 @@ TEST_F(HttpHealthCheckerImplTest, Disconnect) { EXPECT_CALL(*this, onHostStatus(cluster_->prioritySet().getMockHostSet(0)->hosts_[0], HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); test_sessions_[0]->client_connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); @@ -916,6 +1100,7 @@ TEST_F(HttpHealthCheckerImplTest, Timeout) { test_sessions_[0]->interval_timer_->callback_(); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(_)); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); test_sessions_[0]->client_connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); @@ -1011,6 +1196,7 @@ TEST_F(HttpHealthCheckerImplTest, HealthCheckIntervals) { // ignored and health state changes immediately. Since the threshold is ignored, next health // check respects "unhealthy_interval". EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(2000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "503", false); @@ -1067,6 +1253,7 @@ TEST_F(HttpHealthCheckerImplTest, HealthCheckIntervals) { // After the healthy threshold is reached, health state should change while checks should respect // the default interval. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(1000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "200", false); @@ -1117,6 +1304,7 @@ TEST_F(HttpHealthCheckerImplTest, HealthCheckIntervals) { // Subsequent failing checks should respect unhealthy_interval. As the unhealthy threshold is // reached, health state should also change. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(2000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); test_sessions_[0]->timeout_timer_->callback_(); @@ -1165,6 +1353,7 @@ TEST_F(HttpHealthCheckerImplTest, HealthCheckIntervals) { // After the healthy threshold is reached, health state should change while checks should respect // the default interval. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(1000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respond(0, "200", false); @@ -1384,7 +1573,8 @@ class ProdHttpHealthCheckerTest : public HttpHealthCheckerImplTest { )EOF"; health_checker_.reset(new TestProdHttpHealthChecker(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -1405,7 +1595,8 @@ class ProdHttpHealthCheckerTest : public HttpHealthCheckerImplTest { )EOF"; health_checker_.reset(new TestProdHttpHealthChecker(*cluster_, parseHealthCheckFromV2Yaml(yaml), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -1490,7 +1681,8 @@ TEST(TcpHealthCheckMatcher, match) { class TcpHealthCheckerImplTest : public testing::Test { public: - TcpHealthCheckerImplTest() : cluster_(new NiceMock()) {} + TcpHealthCheckerImplTest() + : cluster_(new NiceMock()), event_logger_(new MockHealthCheckEventLogger()) {} void setupData() { std::string json = R"EOF( @@ -1510,7 +1702,8 @@ class TcpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TcpHealthCheckerImpl(*cluster_, parseHealthCheckFromV1Json(json), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); } void setupNoData() { @@ -1527,7 +1720,8 @@ class TcpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TcpHealthCheckerImpl(*cluster_, parseHealthCheckFromV1Json(json), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); } void setupDataDontReuseConnection() { @@ -1549,7 +1743,8 @@ class TcpHealthCheckerImplTest : public testing::Test { )EOF"; health_checker_.reset(new TcpHealthCheckerImpl(*cluster_, parseHealthCheckFromV1Json(json), - dispatcher_, runtime_, random_)); + dispatcher_, runtime_, random_, + HealthCheckEventLoggerPtr(event_logger_))); } void expectSessionCreate() { @@ -1566,6 +1761,7 @@ class TcpHealthCheckerImplTest : public testing::Test { std::shared_ptr cluster_; NiceMock dispatcher_; std::shared_ptr health_checker_; + MockHealthCheckEventLogger* event_logger_{}; Network::MockClientConnection* connection_{}; Event::MockTimer* timeout_timer_{}; Event::MockTimer* interval_timer_{}; @@ -1660,6 +1856,7 @@ TEST_F(TcpHealthCheckerImplTest, Timeout) { connection_->raiseEvent(Network::ConnectionEvent::Connected); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*timeout_timer_, disableTimer()); EXPECT_CALL(*interval_timer_, enableTimer(_)); connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); @@ -1737,6 +1934,7 @@ TEST_F(TcpHealthCheckerImplTest, TimeoutWithoutReusingConnection) { connection_->raiseEvent(Network::ConnectionEvent::Connected); // Expected flow when a healthcheck times out. + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*timeout_timer_, disableTimer()); EXPECT_CALL(*interval_timer_, enableTimer(_)); connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); @@ -1782,6 +1980,7 @@ TEST_F(TcpHealthCheckerImplTest, PassiveFailure) { expectClientCreate(); EXPECT_CALL(*connection_, write(_, _)).Times(0); EXPECT_CALL(*timeout_timer_, enableTimer(_)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); health_checker_->start(); // Do multiple passive failures. This will not reset the active HC timers. @@ -1947,15 +2146,17 @@ class GrpcHealthCheckerImplTestBase { std::vector> trailers; }; - GrpcHealthCheckerImplTestBase() : cluster_(new NiceMock()) { + GrpcHealthCheckerImplTestBase() + : cluster_(new NiceMock()), event_logger_(new MockHealthCheckEventLogger()) { EXPECT_CALL(*cluster_->info_, features()) .WillRepeatedly(Return(Upstream::ClusterInfo::Features::HTTP2)); } void setupHC() { const auto config = createGrpcHealthCheckConfig(); - health_checker_.reset( - new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, random_)); + health_checker_.reset(new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, + random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -1965,8 +2166,9 @@ class GrpcHealthCheckerImplTestBase { void setupHCWithUnhealthyThreshold(int value) { auto config = createGrpcHealthCheckConfig(); config.mutable_unhealthy_threshold()->set_value(value); - health_checker_.reset( - new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, random_)); + health_checker_.reset(new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, + random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -1976,8 +2178,9 @@ class GrpcHealthCheckerImplTestBase { void setupServiceNameHC() { auto config = createGrpcHealthCheckConfig(); config.mutable_grpc_health_check()->set_service_name("service"); - health_checker_.reset( - new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, random_)); + health_checker_.reset(new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, + random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -1987,8 +2190,9 @@ class GrpcHealthCheckerImplTestBase { void setupNoReuseConnectionHC() { auto config = createGrpcHealthCheckConfig(); config.mutable_reuse_connection()->set_value(false); - health_checker_.reset( - new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, random_)); + health_checker_.reset(new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, + random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -2005,8 +2209,9 @@ class GrpcHealthCheckerImplTestBase { config.mutable_interval_jitter()->set_seconds(0); config.mutable_unhealthy_threshold()->set_value(3); config.mutable_healthy_threshold()->set_value(3); - health_checker_.reset( - new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, random_)); + health_checker_.reset(new TestGrpcHealthCheckerImpl(*cluster_, config, dispatcher_, runtime_, + random_, + HealthCheckEventLoggerPtr(event_logger_))); health_checker_->addHostCheckCompleteCb( [this](HostSharedPtr host, HealthTransition changed_state) -> void { onHostStatus(host, changed_state); @@ -2163,6 +2368,7 @@ class GrpcHealthCheckerImplTestBase { std::shared_ptr health_checker_; NiceMock runtime_; NiceMock random_; + MockHealthCheckEventLogger* event_logger_{}; std::list connection_index_{}; std::list codec_index_{}; }; @@ -2338,6 +2544,7 @@ TEST_F(GrpcHealthCheckerImplTest, SuccessStartFailedSuccessFirst) { expectHealthcheckStop(0, 500); // Fast success immediately moves us to healthy. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, true)); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::SERVING); expectHostHealthy(true); } @@ -2378,6 +2585,7 @@ TEST_F(GrpcHealthCheckerImplTest, SuccessStartFailedFailFirst) { expectHealthcheckStop(0); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::SERVING); expectHostHealthy(true); } @@ -2395,6 +2603,7 @@ TEST_F(GrpcHealthCheckerImplTest, GrpcHealthFail) { // Explicit healthcheck failure immediately renders host unhealthy. expectHealthcheckStop(0); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::NOT_SERVING); expectHostHealthy(false); @@ -2414,6 +2623,7 @@ TEST_F(GrpcHealthCheckerImplTest, GrpcHealthFail) { expectHealthcheckStop(0); // Host should has become healthy. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::SERVING); expectHostHealthy(true); } @@ -2441,6 +2651,7 @@ TEST_F(GrpcHealthCheckerImplTest, Disconnect) { expectHealthcheckStop(0); // Now, host should be unhealthy. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); test_sessions_[0]->client_connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); expectHostHealthy(false); } @@ -2466,6 +2677,7 @@ TEST_F(GrpcHealthCheckerImplTest, Timeout) { expectHealthcheckStop(0); EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); // Close connection. Timeouts and connection closes counts together. test_sessions_[0]->client_connection_->raiseEvent(Network::ConnectionEvent::RemoteClose); expectHostHealthy(false); @@ -2537,6 +2749,7 @@ TEST_F(GrpcHealthCheckerImplTest, HealthCheckIntervals) { // ignored and health state changes immediately. Since the threshold is ignored, next health // check respects "unhealthy_interval". EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(2000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::NOT_SERVING); @@ -2593,6 +2806,7 @@ TEST_F(GrpcHealthCheckerImplTest, HealthCheckIntervals) { // After the healthy threshold is reached, health state should change while checks should respect // the default interval. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(1000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::SERVING); @@ -2639,6 +2853,7 @@ TEST_F(GrpcHealthCheckerImplTest, HealthCheckIntervals) { // Subsequent failing checks should respect unhealthy_interval. As the unhealthy threshold is // reached, health state should also change. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(2000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); test_sessions_[0]->timeout_timer_->callback_(); @@ -2683,6 +2898,7 @@ TEST_F(GrpcHealthCheckerImplTest, HealthCheckIntervals) { // After the healthy threshold is reached, health state should change while checks should respect // the default interval. EXPECT_CALL(*this, onHostStatus(_, HealthTransition::Changed)); + EXPECT_CALL(*event_logger_, logAddHealthy(_, _, false)); EXPECT_CALL(*test_sessions_[0]->interval_timer_, enableTimer(std::chrono::milliseconds(1000))); EXPECT_CALL(*test_sessions_[0]->timeout_timer_, disableTimer()); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::SERVING); @@ -2761,6 +2977,7 @@ TEST_F(GrpcHealthCheckerImplTest, DontReuseConnectionBetweenChecks) { TEST_F(GrpcHealthCheckerImplTest, GrpcFailUnknown) { setupHC(); expectSingleHealthcheck(HealthTransition::Changed); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); respondServiceStatus(0, grpc::health::v1::HealthCheckResponse::UNKNOWN); EXPECT_TRUE(cluster_->prioritySet().getMockHostSet(0)->hosts_[0]->healthFlagGet( @@ -2774,6 +2991,7 @@ TEST_F(GrpcHealthCheckerImplTest, GoAwayProbeInProgress) { // is reached. setupHCWithUnhealthyThreshold(1); expectSingleHealthcheck(HealthTransition::Changed); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); test_sessions_[0]->codec_client_->raiseGoAway(); @@ -2910,6 +3128,7 @@ INSTANTIATE_TEST_CASE_P( TEST_P(BadResponseGrpcHealthCheckerImplTest, GrpcBadResponse) { setupHC(); expectSingleHealthcheck(HealthTransition::Changed); + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); ResponseSpec spec = GetParam(); respondResponseSpec(0, std::move(spec)); @@ -2936,6 +3155,34 @@ TEST(Printer, HealthTransitionPrinter) { EXPECT_EQ("Unchanged", unchanged.str()); } +TEST(HealthCheckEventLoggerImplTest, All) { + AccessLog::MockAccessLogManager log_manager; + std::shared_ptr file(new Filesystem::MockFile()); + EXPECT_CALL(log_manager, createAccessLog("foo")).WillOnce(Return(file)); + + std::shared_ptr host(new NiceMock()); + NiceMock cluster; + ON_CALL(*host, cluster()).WillByDefault(ReturnRef(cluster)); + + HealthCheckEventLoggerImpl event_logger(log_manager, "foo"); + + EXPECT_CALL(*file, write(absl::string_view{ + "{\"health_checker_type\":\"HTTP\",\"host\":{\"socket_address\":{" + "\"protocol\":\"TCP\",\"address\":\"10.0.0.1\",\"resolver_name\":\"\"," + "\"ipv4_compat\":false,\"port_value\":443}},\"cluster_name\":\"fake_" + "cluster\",\"eject_unhealthy_event\":{\"failure_type\":\"ACTIVE\"}}\n"})); + event_logger.logEjectUnhealthy(envoy::data::core::v2alpha::HealthCheckerType::HTTP, host, + envoy::data::core::v2alpha::HealthCheckFailureType::ACTIVE); + + EXPECT_CALL(*file, write(absl::string_view{ + "{\"health_checker_type\":\"HTTP\",\"host\":{\"socket_address\":{" + "\"protocol\":\"TCP\",\"address\":\"10.0.0.1\",\"resolver_name\":\"\"," + "\"ipv4_compat\":false,\"port_value\":443}},\"cluster_name\":\"fake_" + "cluster\",\"add_healthy_event\":{\"first_check\":false}}\n"})); + + event_logger.logAddHealthy(envoy::data::core::v2alpha::HealthCheckerType::HTTP, host, false); +} + } // namespace } // namespace Upstream } // namespace Envoy diff --git a/test/common/upstream/load_stats_reporter_test.cc b/test/common/upstream/load_stats_reporter_test.cc index e56e600993a29..a2bd231f9adad 100644 --- a/test/common/upstream/load_stats_reporter_test.cc +++ b/test/common/upstream/load_stats_reporter_test.cc @@ -28,7 +28,7 @@ class LoadStatsReporterTest : public testing::Test { LoadStatsReporterTest() : retry_timer_(new Event::MockTimer()), response_timer_(new Event::MockTimer()), async_client_(new Grpc::MockAsyncClient()) { - node_.set_id("foo"); + node_.set_id("baz"); } void createLoadStatsReporter() { @@ -42,7 +42,7 @@ class LoadStatsReporterTest : public testing::Test { return response_timer_; })); load_stats_reporter_.reset(new LoadStatsReporter( - node_, cm_, stats_store_, Grpc::AsyncClientPtr(async_client_), dispatcher_)); + node_, cm_, stats_store_, Grpc::AsyncClientPtr(async_client_), dispatcher_, time_source_)); } void expectSendMessage( @@ -76,6 +76,7 @@ class LoadStatsReporterTest : public testing::Test { Event::TimerCb response_timer_cb_; Grpc::MockAsyncStream async_stream_; Grpc::MockAsyncClient* async_client_; + MockMonotonicTimeSource time_source_; }; // Validate that stream creation results in a timer based retry. @@ -92,17 +93,132 @@ TEST_F(LoadStatsReporterTest, TestPubSub) { EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); EXPECT_CALL(async_stream_, sendMessage(_, _)); createLoadStatsReporter(); + EXPECT_CALL(time_source_, currentTime()); deliverLoadStatsResponse({"foo"}); EXPECT_CALL(async_stream_, sendMessage(_, _)); EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); response_timer_cb_(); + EXPECT_CALL(time_source_, currentTime()); + deliverLoadStatsResponse({"bar"}); + EXPECT_CALL(async_stream_, sendMessage(_, _)); EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); response_timer_cb_(); +} + +// Validate treatment of existing clusters across updates. +TEST_F(LoadStatsReporterTest, ExistingClusters) { + EXPECT_CALL(*async_client_, start(_, _)).WillOnce(Return(&async_stream_)); + // Initially, we have no clusters to report on. + expectSendMessage({}); + createLoadStatsReporter(); + EXPECT_CALL(time_source_, currentTime()) + .WillOnce(Return(MonotonicTime(std::chrono::microseconds(3)))); + // Start reporting on foo. + NiceMock foo_cluster; + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(2); + NiceMock bar_cluster; + MockClusterManager::ClusterInfoMap cluster_info{{"foo", foo_cluster}, {"bar", bar_cluster}}; + ON_CALL(cm_, clusters()).WillByDefault(Return(cluster_info)); + deliverLoadStatsResponse({"foo"}); + // Initial stats report for foo on timer tick. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(5); + EXPECT_CALL(time_source_, currentTime()) + .WillOnce(Return(MonotonicTime(std::chrono::microseconds(4)))); + { + envoy::api::v2::endpoint::ClusterStats foo_cluster_stats; + foo_cluster_stats.set_cluster_name("foo"); + foo_cluster_stats.set_total_dropped_requests(5); + foo_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(1)); + expectSendMessage({foo_cluster_stats}); + } + EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); + response_timer_cb_(); + // Some traffic on foo/bar in between previous request and next response. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + + // Start reporting on bar. + EXPECT_CALL(time_source_, currentTime()) + .WillOnce(Return(MonotonicTime(std::chrono::microseconds(6)))); + deliverLoadStatsResponse({"foo", "bar"}); + // Stats report foo/bar on timer tick. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + EXPECT_CALL(time_source_, currentTime()) + .Times(2) + .WillRepeatedly(Return(MonotonicTime(std::chrono::microseconds(28)))); + { + envoy::api::v2::endpoint::ClusterStats foo_cluster_stats; + foo_cluster_stats.set_cluster_name("foo"); + foo_cluster_stats.set_total_dropped_requests(2); + foo_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(24)); + envoy::api::v2::endpoint::ClusterStats bar_cluster_stats; + bar_cluster_stats.set_cluster_name("bar"); + bar_cluster_stats.set_total_dropped_requests(1); + bar_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(22)); + expectSendMessage({bar_cluster_stats, foo_cluster_stats}); + } + EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); + response_timer_cb_(); + + // Some traffic on foo/bar in between previous request and next response. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + + // Stop reporting on foo. deliverLoadStatsResponse({"bar"}); + // Stats report for bar on timer tick. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(5); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(5); + EXPECT_CALL(time_source_, currentTime()) + .WillOnce(Return(MonotonicTime(std::chrono::microseconds(33)))); + { + envoy::api::v2::endpoint::ClusterStats bar_cluster_stats; + bar_cluster_stats.set_cluster_name("bar"); + bar_cluster_stats.set_total_dropped_requests(6); + bar_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(5)); + expectSendMessage({bar_cluster_stats}); + } + EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); + response_timer_cb_(); + + // Some traffic on foo/bar in between previous request and next response. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + + // Start tracking foo again, we should forget earlier history for foo. + EXPECT_CALL(time_source_, currentTime()) + .WillOnce(Return(MonotonicTime(std::chrono::microseconds(43)))); + deliverLoadStatsResponse({"foo", "bar"}); + // Stats report foo/bar on timer tick. + foo_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + bar_cluster.info_->load_report_stats_.upstream_rq_dropped_.add(1); + EXPECT_CALL(time_source_, currentTime()) + .Times(2) + .WillRepeatedly(Return(MonotonicTime(std::chrono::microseconds(47)))); + { + envoy::api::v2::endpoint::ClusterStats foo_cluster_stats; + foo_cluster_stats.set_cluster_name("foo"); + foo_cluster_stats.set_total_dropped_requests(1); + foo_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(4)); + envoy::api::v2::endpoint::ClusterStats bar_cluster_stats; + bar_cluster_stats.set_cluster_name("bar"); + bar_cluster_stats.set_total_dropped_requests(2); + bar_cluster_stats.mutable_load_report_interval()->MergeFrom( + Protobuf::util::TimeUtil::MicrosecondsToDuration(14)); + expectSendMessage({bar_cluster_stats, foo_cluster_stats}); + } + EXPECT_CALL(*response_timer_, enableTimer(std::chrono::milliseconds(42000))); + response_timer_cb_(); } // Validate that the client can recover from a remote stream closure via retry. diff --git a/test/common/upstream/logical_dns_cluster_test.cc b/test/common/upstream/logical_dns_cluster_test.cc index 6f09675f48533..80a5c7b705fc5 100644 --- a/test/common/upstream/logical_dns_cluster_test.cc +++ b/test/common/upstream/logical_dns_cluster_test.cc @@ -9,6 +9,7 @@ #include "test/common/upstream/utility.h" #include "test/mocks/common.h" +#include "test/mocks/local_info/mocks.h" #include "test/mocks/network/mocks.h" #include "test/mocks/runtime/mocks.h" #include "test/mocks/ssl/mocks.h" @@ -26,14 +27,16 @@ using testing::_; namespace Envoy { namespace Upstream { +enum class ConfigType { V2_YAML, V1_JSON }; + class LogicalDnsClusterTest : public testing::Test { public: - void setup(const std::string& json) { + void setupFromV1Json(const std::string& json) { resolve_timer_ = new Event::MockTimer(&dispatcher_); NiceMock cm; cluster_.reset(new LogicalDnsCluster(parseClusterFromJson(json), runtime_, stats_store_, - ssl_context_manager_, dns_resolver_, tls_, cm, dispatcher_, - false)); + ssl_context_manager_, local_info_, dns_resolver_, tls_, cm, + dispatcher_, false)); cluster_->prioritySet().addMemberUpdateCb( [&](uint32_t, const HostVector&, const HostVector&) -> void { membership_updated_.ready(); @@ -41,8 +44,22 @@ class LogicalDnsClusterTest : public testing::Test { cluster_->initialize([&]() -> void { initialized_.ready(); }); } - void expectResolve(Network::DnsLookupFamily dns_lookup_family) { - EXPECT_CALL(*dns_resolver_, resolve("foo.bar.com", dns_lookup_family, _)) + void setupFromV2Yaml(const std::string& yaml) { + resolve_timer_ = new Event::MockTimer(&dispatcher_); + NiceMock cm; + cluster_.reset(new LogicalDnsCluster(parseClusterFromV2Yaml(yaml), runtime_, stats_store_, + ssl_context_manager_, local_info_, dns_resolver_, tls_, cm, + dispatcher_, false)); + cluster_->prioritySet().addMemberUpdateCb( + [&](uint32_t, const HostVector&, const HostVector&) -> void { + membership_updated_.ready(); + }); + cluster_->initialize([&]() -> void { initialized_.ready(); }); + } + + void expectResolve(Network::DnsLookupFamily dns_lookup_family, + const std::string& expected_address) { + EXPECT_CALL(*dns_resolver_, resolve(expected_address, dns_lookup_family, _)) .WillOnce(Invoke([&](const std::string&, Network::DnsLookupFamily, Network::DnsResolver::ResolveCb cb) -> Network::ActiveDnsQuery* { dns_callback_ = cb; @@ -50,6 +67,102 @@ class LogicalDnsClusterTest : public testing::Test { })); } + void testBasicSetup(const std::string& config, const std::string& expected_address, + ConfigType config_type = ConfigType::V2_YAML) { + expectResolve(Network::DnsLookupFamily::V4Only, expected_address); + if (config_type == ConfigType::V1_JSON) { + setupFromV1Json(config); + } else { + setupFromV2Yaml(config); + } + + EXPECT_CALL(membership_updated_, ready()); + EXPECT_CALL(initialized_, ready()); + EXPECT_CALL(*resolve_timer_, enableTimer(std::chrono::milliseconds(4000))); + dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2"})); + + EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts().size()); + EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ(1UL, + cluster_->prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); + EXPECT_EQ( + 1UL, + cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); + EXPECT_EQ(cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0], + cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHosts()[0]); + HostSharedPtr logical_host = cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]; + + EXPECT_CALL(dispatcher_, + createClientConnection_( + PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) + .WillOnce(Return(new NiceMock())); + logical_host->createConnection(dispatcher_, nullptr); + logical_host->outlierDetector().putHttpResponseCode(200); + + expectResolve(Network::DnsLookupFamily::V4Only, expected_address); + resolve_timer_->callback_(); + + // Should not cause any changes. + EXPECT_CALL(*resolve_timer_, enableTimer(_)); + dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2", "127.0.0.3"})); + + EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); + EXPECT_CALL(dispatcher_, + createClientConnection_( + PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) + .WillOnce(Return(new NiceMock())); + Host::CreateConnectionData data = logical_host->createConnection(dispatcher_, nullptr); + EXPECT_FALSE(data.host_description_->canary()); + EXPECT_EQ(&cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]->cluster(), + &data.host_description_->cluster()); + EXPECT_EQ(&cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]->stats(), + &data.host_description_->stats()); + EXPECT_EQ("127.0.0.1:443", data.host_description_->address()->asString()); + EXPECT_EQ("", data.host_description_->locality().region()); + EXPECT_EQ("", data.host_description_->locality().zone()); + EXPECT_EQ("", data.host_description_->locality().sub_zone()); + EXPECT_EQ("foo.bar.com", data.host_description_->hostname()); + EXPECT_TRUE(TestUtility::protoEqual(envoy::api::v2::core::Metadata::default_instance(), + *data.host_description_->metadata())); + data.host_description_->outlierDetector().putHttpResponseCode(200); + data.host_description_->healthChecker().setUnhealthy(); + + expectResolve(Network::DnsLookupFamily::V4Only, expected_address); + resolve_timer_->callback_(); + + // Should cause a change. + EXPECT_CALL(*resolve_timer_, enableTimer(_)); + dns_callback_(TestUtility::makeDnsResponse({"127.0.0.3", "127.0.0.1", "127.0.0.2"})); + + EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); + EXPECT_CALL(dispatcher_, + createClientConnection_( + PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) + .WillOnce(Return(new NiceMock())); + logical_host->createConnection(dispatcher_, nullptr); + + expectResolve(Network::DnsLookupFamily::V4Only, expected_address); + resolve_timer_->callback_(); + + // Empty should not cause any change. + EXPECT_CALL(*resolve_timer_, enableTimer(_)); + dns_callback_({}); + + EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); + EXPECT_CALL(dispatcher_, + createClientConnection_( + PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) + .WillOnce(Return(new NiceMock())); + logical_host->createConnection(dispatcher_, nullptr); + + // Make sure we cancel. + EXPECT_CALL(active_dns_query_, cancel()); + expectResolve(Network::DnsLookupFamily::V4Only, expected_address); + resolve_timer_->callback_(); + + tls_.shutdownThread(); + } + Stats::IsolatedStoreImpl stats_store_; Ssl::MockContextManager ssl_context_manager_; std::shared_ptr> dns_resolver_{ @@ -63,6 +176,7 @@ class LogicalDnsClusterTest : public testing::Test { ReadyWatcher initialized_; NiceMock runtime_; NiceMock dispatcher_; + NiceMock local_info_; }; typedef std::tuple> @@ -127,7 +241,7 @@ TEST_P(LogicalDnsParamTest, ImmediateResolve) { cb(TestUtility::makeDnsResponse(std::get<2>(GetParam()))); return nullptr; })); - setup(json); + setupFromV1Json(json); EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts().size()); EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); EXPECT_EQ("foo.bar.com", @@ -137,7 +251,7 @@ TEST_P(LogicalDnsParamTest, ImmediateResolve) { } TEST_F(LogicalDnsClusterTest, BadConfig) { - const std::string json = R"EOF( + const std::string multiple_hosts_json = R"EOF( { "name": "name", "connect_timeout_ms": 250, @@ -147,7 +261,72 @@ TEST_F(LogicalDnsClusterTest, BadConfig) { } )EOF"; - EXPECT_THROW(setup(json), EnvoyException); + EXPECT_THROW_WITH_MESSAGE(setupFromV1Json(multiple_hosts_json), EnvoyException, + "LOGICAL_DNS clusters must have a single host"); + + const std::string multiple_lb_endpoints_yaml = R"EOF( + name: name + type: LOGICAL_DNS + dns_refresh_rate: 4s + connect_timeout: 0.25s + lb_policy: ROUND_ROBIN + dns_lookup_family: V4_ONLY + load_assignment: + cluster_name: name + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: foo.bar.com + port_value: 443 + health_check_config: + port_value: 8000 + - endpoint: + address: + socket_address: + address: hello.world.com + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; + + EXPECT_THROW_WITH_MESSAGE( + setupFromV2Yaml(multiple_lb_endpoints_yaml), EnvoyException, + "LOGICAL_DNS clusters must have a single locality_lb_endpoint and a single lb_endpoint"); + + const std::string multiple_endpoints_yaml = R"EOF( + name: name + type: LOGICAL_DNS + dns_refresh_rate: 4s + connect_timeout: 0.25s + lb_policy: ROUND_ROBIN + dns_lookup_family: V4_ONLY + load_assignment: + cluster_name: name + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: foo.bar.com + port_value: 443 + health_check_config: + port_value: 8000 + + - lb_endpoints: + - endpoint: + address: + socket_address: + address: hello.world.com + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; + + EXPECT_THROW_WITH_MESSAGE( + setupFromV2Yaml(multiple_endpoints_yaml), EnvoyException, + "LOGICAL_DNS clusters must have a single locality_lb_endpoint and a single lb_endpoint"); } TEST_F(LogicalDnsClusterTest, Basic) { @@ -162,93 +341,46 @@ TEST_F(LogicalDnsClusterTest, Basic) { } )EOF"; - expectResolve(Network::DnsLookupFamily::V4Only); - setup(json); - - EXPECT_CALL(membership_updated_, ready()); - EXPECT_CALL(initialized_, ready()); - EXPECT_CALL(*resolve_timer_, enableTimer(std::chrono::milliseconds(4000))); - dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2"})); + const std::string basic_yaml_hosts = R"EOF( + name: name + type: LOGICAL_DNS + dns_refresh_rate: 4s + connect_timeout: 0.25s + lb_policy: ROUND_ROBIN + # Since the following expectResolve() requires Network::DnsLookupFamily::V4Only we need to set + # dns_lookup_family to V4_ONLY explicitly for v2 .yaml config. + dns_lookup_family: V4_ONLY + hosts: + - socket_address: + address: foo.bar.com + port_value: 443 + )EOF"; - EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts().size()); - EXPECT_EQ(1UL, cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); - EXPECT_EQ(0UL, cluster_->prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); - EXPECT_EQ( - 0UL, - cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); - EXPECT_EQ(cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0], - cluster_->prioritySet().hostSetsPerPriority()[0]->healthyHosts()[0]); - HostSharedPtr logical_host = cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]; - - EXPECT_CALL(dispatcher_, - createClientConnection_( - PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) - .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); - logical_host->outlierDetector().putHttpResponseCode(200); - - expectResolve(Network::DnsLookupFamily::V4Only); - resolve_timer_->callback_(); - - // Should not cause any changes. - EXPECT_CALL(*resolve_timer_, enableTimer(_)); - dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2", "127.0.0.3"})); - - EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); - EXPECT_CALL(dispatcher_, - createClientConnection_( - PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.1:443")), _, _, _)) - .WillOnce(Return(new NiceMock())); - Host::CreateConnectionData data = logical_host->createConnection(dispatcher_, nullptr); - EXPECT_FALSE(data.host_description_->canary()); - EXPECT_EQ(&cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]->cluster(), - &data.host_description_->cluster()); - EXPECT_EQ(&cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]->stats(), - &data.host_description_->stats()); - EXPECT_EQ("127.0.0.1:443", data.host_description_->address()->asString()); - EXPECT_EQ("", data.host_description_->locality().region()); - EXPECT_EQ("", data.host_description_->locality().zone()); - EXPECT_EQ("", data.host_description_->locality().sub_zone()); - EXPECT_EQ("foo.bar.com", data.host_description_->hostname()); - EXPECT_EQ(&envoy::api::v2::core::Metadata::default_instance(), - &data.host_description_->metadata()); - data.host_description_->outlierDetector().putHttpResponseCode(200); - data.host_description_->healthChecker().setUnhealthy(); - - expectResolve(Network::DnsLookupFamily::V4Only); - resolve_timer_->callback_(); - - // Should cause a change. - EXPECT_CALL(*resolve_timer_, enableTimer(_)); - dns_callback_(TestUtility::makeDnsResponse({"127.0.0.3", "127.0.0.1", "127.0.0.2"})); - - EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); - EXPECT_CALL(dispatcher_, - createClientConnection_( - PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) - .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); - - expectResolve(Network::DnsLookupFamily::V4Only); - resolve_timer_->callback_(); - - // Empty should not cause any change. - EXPECT_CALL(*resolve_timer_, enableTimer(_)); - dns_callback_({}); - - EXPECT_EQ(logical_host, cluster_->prioritySet().hostSetsPerPriority()[0]->hosts()[0]); - EXPECT_CALL(dispatcher_, - createClientConnection_( - PointeesEq(Network::Utility::resolveUrl("tcp://127.0.0.3:443")), _, _, _)) - .WillOnce(Return(new NiceMock())); - logical_host->createConnection(dispatcher_, nullptr); - - // Make sure we cancel. - EXPECT_CALL(active_dns_query_, cancel()); - expectResolve(Network::DnsLookupFamily::V4Only); - resolve_timer_->callback_(); + const std::string basic_yaml_load_assignment = R"EOF( + name: name + type: LOGICAL_DNS + dns_refresh_rate: 4s + connect_timeout: 0.25s + lb_policy: ROUND_ROBIN + # Since the following expectResolve() requires Network::DnsLookupFamily::V4Only we need to set + # dns_lookup_family to V4_ONLY explicitly for v2 .yaml config. + dns_lookup_family: V4_ONLY + load_assignment: + cluster_name: name + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: foo.bar.com + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; - tls_.shutdownThread(); + testBasicSetup(json, "foo.bar.com", ConfigType::V1_JSON); + testBasicSetup(basic_yaml_hosts, "foo.bar.com"); + testBasicSetup(basic_yaml_load_assignment, "foo.bar.com"); } } // namespace Upstream diff --git a/test/common/upstream/subset_lb_test.cc b/test/common/upstream/subset_lb_test.cc index 2ca3b5fb26b57..4da2cfbb0811c 100644 --- a/test/common/upstream/subset_lb_test.cc +++ b/test/common/upstream/subset_lb_test.cc @@ -134,6 +134,31 @@ class SubsetLoadBalancerTest : public testing::TestWithParam { host_set.healthy_hosts_per_locality_ = host_set.hosts_per_locality_; } + void configureWeightedHostSet(const HostURLMetadataMap& first_locality_host_metadata, + const HostURLMetadataMap& second_locality_host_metadata, + MockHostSet& host_set, LocalityWeights locality_weights) { + HostVector first_locality; + HostVector all_hosts; + for (const auto& it : first_locality_host_metadata) { + auto host = makeHost(it.first, it.second); + first_locality.emplace_back(host); + all_hosts.emplace_back(host); + } + + HostVector second_locality; + for (const auto& it : second_locality_host_metadata) { + auto host = makeHost(it.first, it.second); + second_locality.emplace_back(host); + all_hosts.emplace_back(host); + } + + host_set.hosts_ = all_hosts; + host_set.hosts_per_locality_ = makeHostsPerLocality({first_locality, second_locality}); + host_set.healthy_hosts_ = host_set.hosts_; + host_set.healthy_hosts_per_locality_ = host_set.hosts_per_locality_; + host_set.locality_weights_ = std::make_shared(locality_weights); + } + void init(const HostURLMetadataMap& host_metadata) { HostURLMetadataMap failover; init(host_metadata, failover); @@ -323,6 +348,25 @@ class SubsetLoadBalancerTest : public testing::TestWithParam { EXPECT_EQ(added_host, lb_->chooseHost(nullptr)); } + envoy::api::v2::core::Metadata buildMetadata(const std::string& version, + bool is_default = false) const { + envoy::api::v2::core::Metadata metadata; + + if (version != "") { + Envoy::Config::Metadata::mutableMetadataValue( + metadata, Config::MetadataFilters::get().ENVOY_LB, "version") + .set_string_value(version); + } + + if (is_default) { + Envoy::Config::Metadata::mutableMetadataValue( + metadata, Config::MetadataFilters::get().ENVOY_LB, "default") + .set_string_value("true"); + } + + return metadata; + } + LoadBalancerType lb_type_{LoadBalancerType::RoundRobin}; NiceMock priority_set_; MockHostSet& host_set_ = *priority_set_.getMockHostSet(0); @@ -554,6 +598,174 @@ TEST_P(SubsetLoadBalancerTest, UpdateFailover) { EXPECT_FALSE(nullptr == lb_->chooseHost(&context_10).get()); } +TEST_P(SubsetLoadBalancerTest, OnlyMetadataChanged) { + TestLoadBalancerContext context_10({{"version", "1.0"}}); + TestLoadBalancerContext context_12({{"version", "1.2"}}); + TestLoadBalancerContext context_13({{"version", "1.3"}}); + TestLoadBalancerContext context_default({{"default", "true"}}); + const ProtobufWkt::Struct default_subset = makeDefaultSubset({{"default", "true"}}); + + EXPECT_CALL(subset_info_, defaultSubset()).WillRepeatedly(ReturnRef(default_subset)); + EXPECT_CALL(subset_info_, fallbackPolicy()) + .WillRepeatedly(Return(envoy::api::v2::Cluster::LbSubsetConfig::DEFAULT_SUBSET)); + + std::vector> subset_keys = {{"version"}, {"default"}}; + EXPECT_CALL(subset_info_, subsetKeys()).WillRepeatedly(ReturnRef(subset_keys)); + + // Add hosts initial hosts. + init({{"tcp://127.0.0.1:8000", {{"version", "1.2"}}}, + {"tcp://127.0.0.1:8001", {{"version", "1.0"}, {"default", "true"}}}}); + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(3U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(0U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); + + // Swap the default version. + host_set_.hosts_[0]->metadata(buildMetadata("1.2", true)); + host_set_.hosts_[1]->metadata(buildMetadata("1.0")); + + host_set_.runCallbacks({}, {}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(3U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(0U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_13)); + + // Bump 1.0 to 1.3, one subset should be removed. + host_set_.hosts_[1]->metadata(buildMetadata("1.3")); + + // No hosts added nor removed, so we bypass modifyHosts(). + host_set_.runCallbacks({}, {}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(4U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(1U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_10)); + + // Rollback from 1.3 to 1.0. + host_set_.hosts_[1]->metadata(buildMetadata("1.0")); + + host_set_.runCallbacks({}, {}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(5U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(2U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_13)); + + // Make 1.0 default again. + host_set_.hosts_[1]->metadata(buildMetadata("1.0", true)); + host_set_.hosts_[0]->metadata(buildMetadata("1.2")); + + host_set_.runCallbacks({}, {}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(5U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(2U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); +} + +TEST_P(SubsetLoadBalancerTest, MetadataChangedHostsAddedRemoved) { + TestLoadBalancerContext context_10({{"version", "1.0"}}); + TestLoadBalancerContext context_12({{"version", "1.2"}}); + TestLoadBalancerContext context_13({{"version", "1.3"}}); + TestLoadBalancerContext context_14({{"version", "1.4"}}); + TestLoadBalancerContext context_default({{"default", "true"}}); + const ProtobufWkt::Struct default_subset = makeDefaultSubset({{"default", "true"}}); + + EXPECT_CALL(subset_info_, defaultSubset()).WillRepeatedly(ReturnRef(default_subset)); + EXPECT_CALL(subset_info_, fallbackPolicy()) + .WillRepeatedly(Return(envoy::api::v2::Cluster::LbSubsetConfig::DEFAULT_SUBSET)); + + std::vector> subset_keys = {{"version"}, {"default"}}; + EXPECT_CALL(subset_info_, subsetKeys()).WillRepeatedly(ReturnRef(subset_keys)); + + // Add hosts initial hosts. + init({{"tcp://127.0.0.1:8000", {{"version", "1.2"}}}, + {"tcp://127.0.0.1:8001", {{"version", "1.0"}, {"default", "true"}}}}); + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(3U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(0U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); + + // Swap the default version. + host_set_.hosts_[0]->metadata(buildMetadata("1.2", true)); + host_set_.hosts_[1]->metadata(buildMetadata("1.0")); + + // Add a new host. + modifyHosts({makeHost("tcp://127.0.0.1:8002", {{"version", "1.3"}})}, {}); + + EXPECT_EQ(4U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(4U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(0U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[2], lb_->chooseHost(&context_13)); + + // Swap default again and remove the previous one. + host_set_.hosts_[0]->metadata(buildMetadata("1.2")); + host_set_.hosts_[1]->metadata(buildMetadata("1.0", true)); + + modifyHosts({}, {host_set_.hosts_[2]}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(4U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(1U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); + + // Swap the default version once more, this time adding a new host and removing + // the current default version. + host_set_.hosts_[0]->metadata(buildMetadata("1.2", true)); + host_set_.hosts_[1]->metadata(buildMetadata("1.0")); + + modifyHosts({makeHost("tcp://127.0.0.1:8003", {{"version", "1.4"}})}, {host_set_.hosts_[1]}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(5U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(2U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_13)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_14)); + + // Make 1.4 default, without hosts being added/removed. + host_set_.hosts_[0]->metadata(buildMetadata("1.2")); + host_set_.hosts_[1]->metadata(buildMetadata("1.4", true)); + + host_set_.runCallbacks({}, {}); + + EXPECT_EQ(3U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(5U, stats_.lb_subsets_created_.value()); + EXPECT_EQ(2U, stats_.lb_subsets_removed_.value()); + EXPECT_EQ(host_set_.hosts_[0], lb_->chooseHost(&context_12)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_10)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_default)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_13)); + EXPECT_EQ(host_set_.hosts_[1], lb_->chooseHost(&context_14)); +} + TEST_P(SubsetLoadBalancerTest, UpdateRemovingLastSubsetHost) { EXPECT_CALL(subset_info_, fallbackPolicy()) .WillRepeatedly(Return(envoy::api::v2::Cluster::LbSubsetConfig::ANY_ENDPOINT)); @@ -1166,6 +1378,81 @@ TEST_F(SubsetLoadBalancerTest, DescribeMetadata) { tester.test("", {}); } +TEST_F(SubsetLoadBalancerTest, DisabledLocalityWeightAwareness) { + EXPECT_CALL(subset_info_, isEnabled()).WillRepeatedly(Return(true)); + + // We configure a weighted host set that heavily favors the second locality. + configureWeightedHostSet( + { + {"tcp://127.0.0.1:80", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:81", {{"version", "1.1"}}}, + }, + { + {"tcp://127.0.0.1:82", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:83", {{"version", "1.1"}}}, + {"tcp://127.0.0.1:84", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:85", {{"version", "1.1"}}}, + }, + host_set_, {1, 100}); + + lb_.reset(new SubsetLoadBalancer(lb_type_, priority_set_, nullptr, stats_, runtime_, random_, + subset_info_, ring_hash_lb_config_, common_config_)); + + TestLoadBalancerContext context({{"version", "1.1"}}); + + // Since we don't respect locality weights, the first locality is selected. + EXPECT_CALL(random_, random()).WillOnce(Return(0)); + EXPECT_EQ(host_set_.healthy_hosts_per_locality_->get()[0][0], lb_->chooseHost(&context)); +} + +TEST_F(SubsetLoadBalancerTest, EnabledLocalityWeightAwareness) { + EXPECT_CALL(subset_info_, isEnabled()).WillRepeatedly(Return(true)); + EXPECT_CALL(subset_info_, localityWeightAware()).WillRepeatedly(Return(true)); + + // We configure a weighted host set that heavily favors the second locality. + configureWeightedHostSet( + { + {"tcp://127.0.0.1:80", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:81", {{"version", "1.1"}}}, + }, + { + {"tcp://127.0.0.1:82", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:83", {{"version", "1.1"}}}, + {"tcp://127.0.0.1:84", {{"version", "1.0"}}}, + {"tcp://127.0.0.1:85", {{"version", "1.1"}}}, + }, + host_set_, {1, 100}); + + lb_.reset(new SubsetLoadBalancer(lb_type_, priority_set_, nullptr, stats_, runtime_, random_, + subset_info_, ring_hash_lb_config_, common_config_)); + + TestLoadBalancerContext context({{"version", "1.1"}}); + + // Since we respect locality weights, the second locality is selected. + EXPECT_CALL(random_, random()).WillOnce(Return(0)); + EXPECT_EQ(host_set_.healthy_hosts_per_locality_->get()[1][0], lb_->chooseHost(&context)); +} + +TEST_P(SubsetLoadBalancerTest, GaugesUpdatedOnDestroy) { + EXPECT_CALL(subset_info_, fallbackPolicy()) + .WillRepeatedly(Return(envoy::api::v2::Cluster::LbSubsetConfig::ANY_ENDPOINT)); + + std::vector> subset_keys = {{"version"}}; + EXPECT_CALL(subset_info_, subsetKeys()).WillRepeatedly(ReturnRef(subset_keys)); + + init({ + {"tcp://127.0.0.1:80", {{"version", "1.0"}}}, + }); + + EXPECT_EQ(1U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(0U, stats_.lb_subsets_removed_.value()); + + lb_ = nullptr; + + EXPECT_EQ(0U, stats_.lb_subsets_active_.value()); + EXPECT_EQ(1U, stats_.lb_subsets_removed_.value()); +} + INSTANTIATE_TEST_CASE_P(UpdateOrderings, SubsetLoadBalancerTest, testing::ValuesIn({REMOVES_FIRST, SIMULTANEOUS})); diff --git a/test/common/upstream/upstream_impl_test.cc b/test/common/upstream/upstream_impl_test.cc index 17888a9006304..f4b43c393d2d6 100644 --- a/test/common/upstream/upstream_impl_test.cc +++ b/test/common/upstream/upstream_impl_test.cc @@ -17,6 +17,7 @@ #include "test/common/upstream/utility.h" #include "test/mocks/common.h" +#include "test/mocks/local_info/mocks.h" #include "test/mocks/network/mocks.h" #include "test/mocks/runtime/mocks.h" #include "test/mocks/ssl/mocks.h" @@ -119,6 +120,7 @@ TEST_P(StrictDnsParamTest, ImmediateResolve) { auto dns_resolver = std::make_shared>(); NiceMock dispatcher; NiceMock runtime; + NiceMock local_info; ReadyWatcher initialized; const std::string json = R"EOF( @@ -141,7 +143,7 @@ TEST_P(StrictDnsParamTest, ImmediateResolve) { })); NiceMock cm; StrictDnsClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); cluster.initialize([&]() -> void { initialized.ready(); }); EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->hosts().size()); EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); @@ -155,6 +157,7 @@ TEST(StrictDnsClusterImplTest, ZeroHostsHealthChecker) { NiceMock dispatcher; NiceMock runtime; NiceMock cm; + NiceMock local_info; ReadyWatcher initialized; const std::string yaml = R"EOF( @@ -167,7 +170,7 @@ TEST(StrictDnsClusterImplTest, ZeroHostsHealthChecker) { ResolverData resolver(*dns_resolver, dispatcher); StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); std::shared_ptr health_checker(new MockHealthChecker()); EXPECT_CALL(*health_checker, start()); EXPECT_CALL(*health_checker, addHostCheckCompleteCb(_)); @@ -188,6 +191,7 @@ TEST(StrictDnsClusterImplTest, Basic) { auto dns_resolver = std::make_shared>(); NiceMock dispatcher; NiceMock runtime; + NiceMock local_info; // gmock matches in LIFO order which is why these are swapped. ResolverData resolver2(*dns_resolver, dispatcher); @@ -225,7 +229,7 @@ TEST(StrictDnsClusterImplTest, Basic) { NiceMock cm; StrictDnsClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.default.max_connections", 43)); EXPECT_EQ(43U, cluster.info()->resourceManager(ResourcePriority::Default).connections().max()); EXPECT_CALL(runtime.snapshot_, @@ -302,8 +306,8 @@ TEST(StrictDnsClusterImplTest, Basic) { ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); - EXPECT_EQ(0UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); - EXPECT_EQ(0UL, + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); for (const HostSharedPtr& host : cluster.prioritySet().hostSetsPerPriority()[0]->hosts()) { @@ -329,6 +333,7 @@ TEST(StrictDnsClusterImplTest, HostRemovalActiveHealthSkipped) { NiceMock dispatcher; NiceMock runtime; NiceMock cm; + NiceMock local_info; const std::string yaml = R"EOF( name: name @@ -341,7 +346,7 @@ TEST(StrictDnsClusterImplTest, HostRemovalActiveHealthSkipped) { ResolverData resolver(*dns_resolver, dispatcher); StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); std::shared_ptr health_checker(new MockHealthChecker()); EXPECT_CALL(*health_checker, start()); EXPECT_CALL(*health_checker, addHostCheckCompleteCb(_)); @@ -372,6 +377,298 @@ TEST(StrictDnsClusterImplTest, HostRemovalActiveHealthSkipped) { EXPECT_EQ(1UL, hosts.size()); } +TEST(StrictDnsClusterImplTest, LoadAssignmentBasic) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + auto dns_resolver = std::make_shared>(); + NiceMock dispatcher; + NiceMock runtime; + NiceMock local_info; + + // gmock matches in LIFO order which is why these are swapped. + ResolverData resolver2(*dns_resolver, dispatcher); + ResolverData resolver1(*dns_resolver, dispatcher); + + const std::string yaml = R"EOF( + name: name + type: STRICT_DNS + + dns_lookup_family: V4_ONLY + connect_timeout: 0.25s + dns_refresh_rate: 4s + + lb_policy: ROUND_ROBIN + + circuit_breakers: + thresholds: + - priority: DEFAULT + max_connections: 43 + max_pending_requests: 57 + max_requests: 50 + max_retries: 10 + - priority: HIGH + max_connections: 1 + max_pending_requests: 2 + max_requests: 3 + max_retries: 4 + + max_requests_per_connection: 3 + + http2_protocol_options: + hpack_table_size: 0 + + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: localhost1 + port_value: 11001 + health_check_config: + port_value: 8000 + - endpoint: + address: + socket_address: + address: localhost2 + port_value: 11002 + health_check_config: + port_value: 8000 + )EOF"; + + NiceMock cm; + StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, dns_resolver, cm, dispatcher, false); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.default.max_connections", 43)); + EXPECT_EQ(43U, cluster.info()->resourceManager(ResourcePriority::Default).connections().max()); + EXPECT_CALL(runtime.snapshot_, + getInteger("circuit_breakers.name.default.max_pending_requests", 57)); + EXPECT_EQ(57U, + cluster.info()->resourceManager(ResourcePriority::Default).pendingRequests().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.default.max_requests", 50)); + EXPECT_EQ(50U, cluster.info()->resourceManager(ResourcePriority::Default).requests().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.default.max_retries", 10)); + EXPECT_EQ(10U, cluster.info()->resourceManager(ResourcePriority::Default).retries().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.high.max_connections", 1)); + EXPECT_EQ(1U, cluster.info()->resourceManager(ResourcePriority::High).connections().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.high.max_pending_requests", 2)); + EXPECT_EQ(2U, cluster.info()->resourceManager(ResourcePriority::High).pendingRequests().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.high.max_requests", 3)); + EXPECT_EQ(3U, cluster.info()->resourceManager(ResourcePriority::High).requests().max()); + EXPECT_CALL(runtime.snapshot_, getInteger("circuit_breakers.name.high.max_retries", 4)); + EXPECT_EQ(4U, cluster.info()->resourceManager(ResourcePriority::High).retries().max()); + EXPECT_EQ(3U, cluster.info()->maxRequestsPerConnection()); + EXPECT_EQ(0U, cluster.info()->http2Settings().hpack_table_size_); + + cluster.info()->stats().upstream_rq_total_.inc(); + EXPECT_EQ(1UL, stats.counter("cluster.name.upstream_rq_total").value()); + + EXPECT_CALL(runtime.snapshot_, featureEnabled("upstream.maintenance_mode.name", 0)); + EXPECT_FALSE(cluster.info()->maintenanceMode()); + + ReadyWatcher membership_updated; + cluster.prioritySet().addMemberUpdateCb( + [&](uint32_t, const HostVector&, const HostVector&) -> void { membership_updated.ready(); }); + + cluster.initialize([] {}); + + resolver1.expectResolve(*dns_resolver); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + EXPECT_EQ("localhost1", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->hostname()); + EXPECT_EQ("localhost1", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[1]->hostname()); + + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.2", "127.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.2", "127.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.3"})); + EXPECT_THAT( + std::list({"127.0.0.3:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + // Make sure we de-dup the same address. + EXPECT_CALL(*resolver2.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver2.dns_callback_(TestUtility::makeDnsResponse({"10.0.0.1", "10.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.3:11001", "10.0.0.1:11002"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); + EXPECT_EQ(1UL, + cluster.prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); + + for (const HostSharedPtr& host : cluster.prioritySet().hostSetsPerPriority()[0]->hosts()) { + EXPECT_EQ(cluster.info().get(), &host->cluster()); + } + + // Make sure we cancel. + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + resolver2.expectResolve(*dns_resolver); + resolver2.timer_->callback_(); + + EXPECT_CALL(resolver1.active_dns_query_, cancel()); + EXPECT_CALL(resolver2.active_dns_query_, cancel()); +} + +TEST(StrictDnsClusterImplTest, LoadAssignmentBasicMultiplePriorities) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + auto dns_resolver = std::make_shared>(); + NiceMock dispatcher; + NiceMock runtime; + NiceMock local_info; + + // gmock matches in LIFO order which is why these are swapped. + ResolverData resolver3(*dns_resolver, dispatcher); + ResolverData resolver2(*dns_resolver, dispatcher); + ResolverData resolver1(*dns_resolver, dispatcher); + + const std::string yaml = R"EOF( + name: name + type: STRICT_DNS + + dns_lookup_family: V4_ONLY + connect_timeout: 0.25s + dns_refresh_rate: 4s + + lb_policy: ROUND_ROBIN + + load_assignment: + endpoints: + - priority: 0 + lb_endpoints: + - endpoint: + address: + socket_address: + address: localhost1 + port_value: 11001 + health_check_config: + port_value: 8000 + - endpoint: + address: + socket_address: + address: localhost2 + port_value: 11002 + health_check_config: + port_value: 8000 + + - priority: 1 + lb_endpoints: + - endpoint: + address: + socket_address: + address: localhost3 + port_value: 11003 + health_check_config: + port_value: 8000 + )EOF"; + + NiceMock cm; + StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, dns_resolver, cm, dispatcher, false); + + ReadyWatcher membership_updated; + cluster.prioritySet().addMemberUpdateCb( + [&](uint32_t, const HostVector&, const HostVector&) -> void { membership_updated.ready(); }); + + cluster.initialize([] {}); + + resolver1.expectResolve(*dns_resolver); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.1", "127.0.0.2"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + EXPECT_EQ("localhost1", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->hostname()); + EXPECT_EQ("localhost1", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[1]->hostname()); + + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.2", "127.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.2", "127.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.1:11001", "127.0.0.2:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + resolver1.timer_->callback_(); + EXPECT_CALL(*resolver1.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver1.dns_callback_(TestUtility::makeDnsResponse({"127.0.0.3"})); + EXPECT_THAT( + std::list({"127.0.0.3:11001"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + // Make sure we de-dup the same address. + EXPECT_CALL(*resolver2.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver2.dns_callback_(TestUtility::makeDnsResponse({"10.0.0.1", "10.0.0.1"})); + EXPECT_THAT( + std::list({"127.0.0.3:11001", "10.0.0.1:11002"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); + + EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); + EXPECT_EQ(1UL, + cluster.prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); + + for (const HostSharedPtr& host : cluster.prioritySet().hostSetsPerPriority()[0]->hosts()) { + EXPECT_EQ(cluster.info().get(), &host->cluster()); + } + + EXPECT_CALL(*resolver3.timer_, enableTimer(std::chrono::milliseconds(4000))); + EXPECT_CALL(membership_updated, ready()); + resolver3.dns_callback_(TestUtility::makeDnsResponse({"192.168.1.1", "192.168.1.2"})); + + // Make sure we have multiple priorities. + EXPECT_THAT( + std::list({"192.168.1.1:11003", "192.168.1.2:11003"}), + ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[1]->hosts()))); + + // Make sure we cancel. + resolver1.expectResolve(*dns_resolver); + resolver1.timer_->callback_(); + resolver2.expectResolve(*dns_resolver); + resolver2.timer_->callback_(); + resolver3.expectResolve(*dns_resolver); + resolver3.timer_->callback_(); + + EXPECT_CALL(resolver1.active_dns_query_, cancel()); + EXPECT_CALL(resolver2.active_dns_query_, cancel()); + EXPECT_CALL(resolver3.active_dns_query_, cancel()); +} + TEST(HostImplTest, HostCluster) { MockCluster cluster; HostSharedPtr host = makeTestHost(cluster.info_, "tcp://10.0.0.1:1234", 1); @@ -419,10 +716,38 @@ TEST(HostImplTest, HostnameCanaryAndLocality) { EXPECT_EQ("world", host.locality().sub_zone()); } +TEST(StaticClusterImplTest, InitialHosts) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + NiceMock runtime; + NiceMock local_info; + const std::string yaml = R"EOF( + name: staticcluster + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + hosts: + - socket_address: + address: 10.0.0.1 + port_value: 443 + )EOF"; + + NiceMock cm; + StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, cm, false); + cluster.initialize([] {}); + + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ("", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->hostname()); + EXPECT_FALSE(cluster.info()->addedViaApi()); +} + TEST(StaticClusterImplTest, EmptyHostname) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string json = R"EOF( { "name": "staticcluster", @@ -434,8 +759,41 @@ TEST(StaticClusterImplTest, EmptyHostname) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, - false); + StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, false); + cluster.initialize([] {}); + + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ("", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->hostname()); + EXPECT_FALSE(cluster.info()->addedViaApi()); +} + +TEST(StaticClusterImplTest, LoadAssignmentEmptyHostname) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + NiceMock runtime; + NiceMock local_info; + + const std::string yaml = R"EOF( + name: staticcluster + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + load_assignment: + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: 10.0.0.1 + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; + + NiceMock cm; + StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, cm, false); cluster.initialize([] {}); EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); @@ -443,10 +801,114 @@ TEST(StaticClusterImplTest, EmptyHostname) { EXPECT_FALSE(cluster.info()->addedViaApi()); } +TEST(StaticClusterImplTest, LoadAssignmentMultiplePriorities) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + NiceMock runtime; + NiceMock local_info; + + const std::string yaml = R"EOF( + name: staticcluster + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + load_assignment: + endpoints: + - priority: 0 + lb_endpoints: + - endpoint: + address: + socket_address: + address: 10.0.0.1 + port_value: 443 + health_check_config: + port_value: 8000 + - endpoint: + address: + socket_address: + address: 10.0.0.2 + port_value: 443 + health_check_config: + port_value: 8000 + + - priority: 1 + lb_endpoints: + - endpoint: + address: + socket_address: + address: 10.0.0.3 + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; + + NiceMock cm; + StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, cm, false); + cluster.initialize([] {}); + + EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[1]->healthyHosts().size()); + EXPECT_EQ("", cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->hostname()); + EXPECT_FALSE(cluster.info()->addedViaApi()); +} + +TEST(StaticClusterImplTest, LoadAssignmentLocality) { + Stats::IsolatedStoreImpl stats; + Ssl::MockContextManager ssl_context_manager; + NiceMock runtime; + NiceMock local_info; + + const std::string yaml = R"EOF( + name: staticcluster + connect_timeout: 0.25s + type: STATIC + lb_policy: ROUND_ROBIN + load_assignment: + endpoints: + - locality: + region: oceania + zone: hello + sub_zone: world + lb_endpoints: + - endpoint: + address: + socket_address: + address: 10.0.0.1 + port_value: 443 + health_check_config: + port_value: 8000 + - endpoint: + address: + socket_address: + address: 10.0.0.2 + port_value: 443 + health_check_config: + port_value: 8000 + )EOF"; + + NiceMock cm; + StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, cm, false); + cluster.initialize([] {}); + + auto& hosts = cluster.prioritySet().hostSetsPerPriority()[0]->hosts(); + EXPECT_EQ(hosts.size(), 2); + for (int i = 0; i < 2; ++i) { + const auto& locality = hosts[i]->locality(); + EXPECT_EQ("oceania", locality.region()); + EXPECT_EQ("hello", locality.zone()); + EXPECT_EQ("world", locality.sub_zone()); + } + EXPECT_EQ(nullptr, cluster.prioritySet().hostSetsPerPriority()[0]->localityWeights()); + EXPECT_FALSE(cluster.info()->addedViaApi()); +} + TEST(StaticClusterImplTest, AltStatName) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; const std::string yaml = R"EOF( name: staticcluster @@ -458,8 +920,8 @@ TEST(StaticClusterImplTest, AltStatName) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, cm, - false); + StaticClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, + local_info, cm, false); cluster.initialize([] {}); // Increment a stat and verify it is emitted with alt_stat_name cluster.info()->stats().upstream_rq_total_.inc(); @@ -470,6 +932,8 @@ TEST(StaticClusterImplTest, RingHash) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string json = R"EOF( { "name": "staticcluster", @@ -481,8 +945,8 @@ TEST(StaticClusterImplTest, RingHash) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, - true); + StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, true); cluster.initialize([] {}); EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); @@ -494,6 +958,8 @@ TEST(StaticClusterImplTest, OutlierDetector) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string json = R"EOF( { "name": "addressportconfig", @@ -506,8 +972,8 @@ TEST(StaticClusterImplTest, OutlierDetector) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, - false); + StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, false); Outlier::MockDetector* detector = new Outlier::MockDetector(); EXPECT_CALL(*detector, addChangedStateCb(_)); @@ -541,6 +1007,8 @@ TEST(StaticClusterImplTest, HealthyStat) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string json = R"EOF( { "name": "addressportconfig", @@ -553,8 +1021,8 @@ TEST(StaticClusterImplTest, HealthyStat) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, - false); + StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, false); Outlier::MockDetector* outlier_detector = new NiceMock(); cluster.setOutlierDetector(Outlier::DetectorSharedPtr{outlier_detector}); @@ -623,6 +1091,8 @@ TEST(StaticClusterImplTest, UrlConfig) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string json = R"EOF( { "name": "addressportconfig", @@ -635,8 +1105,8 @@ TEST(StaticClusterImplTest, UrlConfig) { )EOF"; NiceMock cm; - StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, - false); + StaticClusterImpl cluster(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, false); cluster.initialize([] {}); EXPECT_EQ(1024U, cluster.info()->resourceManager(ResourcePriority::Default).connections().max()); @@ -656,8 +1126,8 @@ TEST(StaticClusterImplTest, UrlConfig) { std::list({"10.0.0.1:11001", "10.0.0.2:11002"}), ContainerEq(hostListToAddresses(cluster.prioritySet().hostSetsPerPriority()[0]->hosts()))); EXPECT_EQ(2UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHosts().size()); - EXPECT_EQ(0UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); - EXPECT_EQ(0UL, + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->hostsPerLocality().get().size()); + EXPECT_EQ(1UL, cluster.prioritySet().hostSetsPerPriority()[0]->healthyHostsPerLocality().get().size()); cluster.prioritySet().hostSetsPerPriority()[0]->hosts()[0]->healthChecker().setUnhealthy(); } @@ -667,6 +1137,8 @@ TEST(StaticClusterImplTest, UnsupportedLBType) { Ssl::MockContextManager ssl_context_manager; NiceMock runtime; NiceMock cm; + NiceMock local_info; + const std::string json = R"EOF( { "name": "addressportconfig", @@ -678,15 +1150,17 @@ TEST(StaticClusterImplTest, UnsupportedLBType) { } )EOF"; - EXPECT_THROW( - StaticClusterImpl(parseClusterFromJson(json), runtime, stats, ssl_context_manager, cm, false), - EnvoyException); + EXPECT_THROW(StaticClusterImpl(parseClusterFromJson(json), runtime, stats, ssl_context_manager, + local_info, cm, false), + EnvoyException); } TEST(StaticClusterImplTest, MalformedHostIP) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + const std::string yaml = R"EOF( name: name connect_timeout: 0.25s @@ -697,7 +1171,7 @@ TEST(StaticClusterImplTest, MalformedHostIP) { NiceMock cm; EXPECT_THROW_WITH_MESSAGE(StaticClusterImpl(parseClusterFromV2Yaml(yaml), runtime, stats, - ssl_context_manager, cm, false), + ssl_context_manager, local_info, cm, false), EnvoyException, "malformed IP address: foo.bar.com. Consider setting resolver_name or " "setting cluster type to 'STRICT_DNS' or 'LOGICAL_DNS'"); @@ -739,6 +1213,8 @@ TEST(StaticClusterImplTest, SourceAddressPriority) { Stats::IsolatedStoreImpl stats; Ssl::MockContextManager ssl_context_manager; NiceMock runtime; + NiceMock local_info; + envoy::api::v2::Cluster config; config.set_name("staticcluster"); config.mutable_connect_timeout(); @@ -747,7 +1223,7 @@ TEST(StaticClusterImplTest, SourceAddressPriority) { // If the cluster manager gets a source address from the bootstrap proto, use it. NiceMock cm; cm.bind_config_.mutable_source_address()->set_address("1.2.3.5"); - StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, cm, false); + StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, local_info, cm, false); EXPECT_EQ("1.2.3.5:0", cluster.info()->sourceAddress()->asString()); } @@ -756,7 +1232,7 @@ TEST(StaticClusterImplTest, SourceAddressPriority) { { // Verify source address from cluster config is used when present. NiceMock cm; - StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, cm, false); + StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, local_info, cm, false); EXPECT_EQ(cluster_address, cluster.info()->sourceAddress()->ip()->addressAsString()); } @@ -764,7 +1240,7 @@ TEST(StaticClusterImplTest, SourceAddressPriority) { // The source address from cluster config takes precedence over one from the bootstrap proto. NiceMock cm; cm.bind_config_.mutable_source_address()->set_address("1.2.3.5"); - StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, cm, false); + StaticClusterImpl cluster(config, runtime, stats, ssl_context_manager, local_info, cm, false); EXPECT_EQ(cluster_address, cluster.info()->sourceAddress()->ip()->addressAsString()); } } @@ -778,6 +1254,7 @@ TEST(ClusterImplTest, CloseConnectionsOnHostHealthFailure) { NiceMock dispatcher; NiceMock runtime; NiceMock cm; + NiceMock local_info; ReadyWatcher initialized; const std::string yaml = R"EOF( @@ -789,7 +1266,7 @@ TEST(ClusterImplTest, CloseConnectionsOnHostHealthFailure) { hosts: [{ socket_address: { address: foo.bar.com, port_value: 443 }}] )EOF"; StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); EXPECT_TRUE(cluster.info()->features() & ClusterInfo::Features::CLOSE_CONNECTIONS_ON_HOST_HEALTH_FAILURE); } @@ -851,6 +1328,7 @@ TEST(ClusterMetadataTest, Metadata) { NiceMock dispatcher; NiceMock runtime; NiceMock cm; + NiceMock local_info; ReadyWatcher initialized; const std::string yaml = R"EOF( @@ -866,7 +1344,7 @@ TEST(ClusterMetadataTest, Metadata) { )EOF"; StrictDnsClusterImpl cluster(parseClusterFromV2Yaml(yaml), runtime, stats, ssl_context_manager, - dns_resolver, cm, dispatcher, false); + local_info, dns_resolver, cm, dispatcher, false); EXPECT_EQ("test_value", Config::Metadata::metadataValue(cluster.info()->metadata(), "com.bar.foo", "baz") .string_value()); diff --git a/test/common/upstream/utility.h b/test/common/upstream/utility.h index 9d1c43d382402..716e2f2edee7c 100644 --- a/test/common/upstream/utility.h +++ b/test/common/upstream/utility.h @@ -1,11 +1,13 @@ #pragma once +#include "envoy/stats/stats.h" #include "envoy/upstream/upstream.h" #include "common/common/utility.h" #include "common/config/cds_json.h" #include "common/json/json_loader.h" #include "common/network/utility.h" +#include "common/stats/stats_impl.h" #include "common/upstream/upstream_impl.h" #include "fmt/printf.h" @@ -46,8 +48,10 @@ inline std::string clustersJson(const std::vector& clusters) { inline envoy::api::v2::Cluster parseClusterFromJson(const std::string& json_string) { envoy::api::v2::Cluster cluster; auto json_object_ptr = Json::Factory::loadFromString(json_string); + Stats::StatsOptionsImpl stats_options; Config::CdsJson::translateCluster(*json_object_ptr, - absl::optional(), cluster); + absl::optional(), cluster, + stats_options); return cluster; } @@ -66,7 +70,8 @@ parseSdsClusterFromJson(const std::string& json_string, const envoy::api::v2::core::ConfigSource eds_config) { envoy::api::v2::Cluster cluster; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster); + Stats::StatsOptionsImpl stats_options; + Config::CdsJson::translateCluster(*json_object_ptr, eds_config, cluster, stats_options); return cluster; } diff --git a/test/config/integration/BUILD b/test/config/integration/BUILD index 2f4916e77be19..21faf9456494d 100644 --- a/test/config/integration/BUILD +++ b/test/config/integration/BUILD @@ -10,6 +10,7 @@ envoy_package() exports_files([ "echo_server.json", "server.json", + "server.yaml", "server_ads.yaml", "server_cors_filter.json", "server_grpc_json_transcoder.json", @@ -19,6 +20,7 @@ exports_files([ "server_ssl.json", "server_uds.json", "server_unix_listener.json", + "server_unix_listener.yaml", "server_xfcc.json", "tcp_proxy.json", ]) @@ -37,8 +39,8 @@ filegroup( filegroup( name = "server_config_files", srcs = [ - "server.json", - "server_unix_listener.json", + "server.yaml", + "server_unix_listener.yaml", ], ) diff --git a/test/config/integration/server.yaml b/test/config/integration/server.yaml new file mode 100644 index 0000000000000..4e68a14193bb2 --- /dev/null +++ b/test/config/integration/server.yaml @@ -0,0 +1,460 @@ +static_resources: + listeners: + - address: + socket_address: + address: "{{ ip_loopback_address }}" + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + drain_timeout_ms: 5000 + route_config: + virtual_hosts: + - require_ssl: all + routes: + - cluster: cluster_1 + prefix: "/" + domains: + - www.redirect.com + name: redirect + - routes: + - prefix: "/" + cluster: cluster_1 + runtime: + key: some_key + default: 0 + - prefix: "/test/long/url" + rate_limits: + - actions: + - type: destination_cluster + cluster: cluster_1 + - prefix: "/test/" + cluster: cluster_2 + - prefix: "/websocket/test" + prefix_rewrite: "/websocket" + cluster: cluster_1 + use_websocket: true + domains: + - "*" + name: integration + codec_type: http1 + stat_prefix: router + filters: + - type: both + name: health_check + config: + endpoint: "/healthcheck" + pass_through_mode: false + - type: decoder + name: rate_limit + config: + domain: foo + - type: decoder + name: router + config: {} + access_log: + - format: '[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% + %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% + %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" + "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" + "%REQUEST_DURATION%" "%RESPONSE_DURATION%"' + path: "/dev/null" + filter: + filters: + - type: status_code + op: ">=" + value: 500 + - type: duration + op: ">=" + value: 1000000 + type: logical_or + - path: "/dev/null" + deprecated_v1: true + deprecated_v1: + type: read + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + filters: + - type: both + name: health_check + config: + endpoint: "/healthcheck" + pass_through_mode: false + - name: rate_limit + config: + domain: foo + type: decoder + - type: decoder + name: router + config: {} + access_log: + - filter: + type: logical_or + filters: + - value: 500 + type: status_code + op: ">=" + - type: duration + op: ">=" + value: 1555500 + format: '[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% + %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% + %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" + "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" + "%REQUEST_DURATION%" "%RESPONSE_DURATION%"' + path: "/dev/null" + - path: "/dev/null" + drain_timeout_ms: 5000 + route_config: + virtual_hosts: + - routes: + - prefix: "/" + cluster: cluster_1 + domains: + - www.redirect.com + name: redirect + require_ssl: all + - routes: + - prefix: "/" + cluster: cluster_1 + domains: + - www.namewithport.com:1234 + name: redirect + require_ssl: all + - routes: + - cluster: cluster_1 + runtime: + key: some_key + default: 0 + prefix: "/" + - rate_limits: + - actions: + - type: destination_cluster + cluster: cluster_1 + prefix: "/test/long/url" + - prefix: "/test/" + cluster: cluster_2 + - cluster: cluster_1 + use_websocket: true + prefix: "/websocket/test" + prefix_rewrite: "/websocket" + domains: + - "*" + name: integration + codec_type: http1 + stat_prefix: router + http1_settings: + allow_absolute_url: true + deprecated_v1: true + deprecated_v1: + type: read + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + route_config: + virtual_hosts: + - routes: + - cluster: cluster_3 + prefix: "/test/long/url" + domains: + - "*" + name: integration + filters: + - name: router + config: {} + type: decoder + codec_type: http1 + stat_prefix: router + deprecated_v1: true + deprecated_v1: + type: read + per_connection_buffer_limit_bytes: 1024 + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + filters: + - type: both + name: http_dynamo_filter + config: {} + - name: router + config: {} + type: decoder + codec_type: http1 + stat_prefix: router + route_config: + virtual_hosts: + - routes: + - cluster: cluster_3 + prefix: "/dynamo/url" + domains: + - "*" + name: integration + deprecated_v1: true + deprecated_v1: + type: read + per_connection_buffer_limit_bytes: 1024 + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + route_config: + virtual_hosts: + - domains: + - "*" + name: integration + routes: + - prefix: "/test/long/url" + cluster: cluster_3 + filters: + - type: both + name: grpc_http1_bridge + config: {} + - type: decoder + name: router + config: {} + codec_type: http1 + stat_prefix: router + deprecated_v1: true + deprecated_v1: + type: read + per_connection_buffer_limit_bytes: 1024 + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + drain_timeout_ms: 5000 + route_config: + virtual_hosts: + - routes: + - cluster: cluster_1 + prefix: "/" + domains: + - www.redirect.com + name: redirect + require_ssl: all + - routes: + - cluster: cluster_1 + runtime: + key: some_key + default: 0 + prefix: "/" + - prefix: "/test/long/url" + rate_limits: + - actions: + - type: destination_cluster + cluster: cluster_1 + - prefix: "/test/" + cluster: cluster_2 + - prefix: "/websocket/test" + prefix_rewrite: "/websocket" + cluster: cluster_1 + use_websocket: true + domains: + - "*" + name: integration + codec_type: http1 + stat_prefix: router + filters: + - type: both + name: health_check + config: + endpoint: "/healthcheck" + pass_through_mode: false + - name: rate_limit + config: + domain: foo + type: decoder + - name: buffer + config: + max_request_time_s: 120 + max_request_bytes: 5242880 + type: decoder + - config: {} + type: decoder + name: router + access_log: + - filter: + filters: + - op: ">=" + value: 500 + type: status_code + - type: duration + op: ">=" + value: 1555500 + type: logical_or + format: '[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% + %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% + %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" + "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" + "%REQUEST_DURATION%" "%RESPONSE_DURATION%"' + path: "/dev/null" + - path: "/dev/null" + deprecated_v1: true + deprecated_v1: + type: read + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + filters: + - type: decoder + name: router + config: {} + codec_type: http1 + stat_prefix: rds_dummy + rds: + route_config_name: foo + cluster: rds + deprecated_v1: true + deprecated_v1: + type: read + - address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 + filter_chains: + - filters: + - name: envoy.redis_proxy + config: + value: + conn_pool: + op_timeout_ms: 400 + stat_prefix: redis + cluster_name: redis + deprecated_v1: true + deprecated_v1: + type: read + clusters: + - name: cds + connect_timeout: 5s + hosts: + - socket_address: + address: {{ ip_loopback_address }} + port_value: 4 + dns_lookup_family: "{{ dns_lookup_family }}" + - name: rds + connect_timeout: 5s + hosts: + - socket_address: + address: {{ ip_loopback_address }} + port_value: 4 + dns_lookup_family: "{{ dns_lookup_family }}" + - name: lds + connect_timeout: 5s + hosts: + - socket_address: + address: {{ ip_loopback_address }} + port_value: 4 + dns_lookup_family: "{{ dns_lookup_family }}" + - name: cluster_1 + connect_timeout: 5s + hosts: + - socket_address: + address: {{ ip_loopback_address }} + port_value: {{ upstream_0 }} + dns_lookup_family: "{{ dns_lookup_family }}" + - name: cluster_2 + type: STRICT_DNS + connect_timeout: 5s + hosts: + - socket_address: + address: localhost + port_value: {{ upstream_1 }} + dns_lookup_family: "{{ dns_lookup_family }}" + - name: cluster_3 + connect_timeout: 5s + per_connection_buffer_limit_bytes: 1024 + hosts: + - socket_address: + address: {{ ip_loopback_address }} + port_value: {{ upstream_0 }} + dns_lookup_family: "{{ dns_lookup_family }}" + - name: statsd + type: STRICT_DNS + connect_timeout: 5s + hosts: + - socket_address: + address: localhost + port_value: 4 + dns_lookup_family: "{{ dns_lookup_family }}" + - name: redis + type: STRICT_DNS + connect_timeout: 5s + lb_policy: RING_HASH + hosts: + - socket_address: + address: localhost + port_value: 4 + dns_lookup_family: "{{ dns_lookup_family }}" + outlier_detection: {} +dynamic_resources: + lds_config: + api_config_source: + cluster_names: + - lds + refresh_delay: 30s + cds_config: + api_config_source: + cluster_names: + - cds + refresh_delay: 30s +cluster_manager: {} +flags_path: "/invalid_flags" +stats_sinks: +- name: envoy.statsd + config: + address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 8125 +- name: envoy.statsd + config: + tcp_cluster_name: statsd +watchdog: {} +runtime: + symlink_root: "{{ test_rundir }}/test/common/runtime/test_data/current" + subdirectory: envoy + override_subdirectory: envoy_override +admin: + access_log_path: "/dev/null" + profile_path: "{{ test_tmpdir }}/envoy.prof" + address: + socket_address: + address: {{ ip_loopback_address }} + port_value: 0 diff --git a/test/config/integration/server_unix_listener.yaml b/test/config/integration/server_unix_listener.yaml new file mode 100644 index 0000000000000..eb1b02f2de3d2 --- /dev/null +++ b/test/config/integration/server_unix_listener.yaml @@ -0,0 +1,44 @@ +static_resources: + listeners: + - address: + pipe: + path: "{{ socket_dir }}/unix-sockets.listener_0" + filter_chains: + - filters: + - name: envoy.http_connection_manager + config: + value: + filters: + - type: decoder + name: router + config: {} + codec_type: auto + stat_prefix: router + drain_timeout_ms: 5000 + route_config: + virtual_hosts: + - domains: + - "*" + name: vhost_0 + routes: + - prefix: "/" + cluster: cluster_0 + deprecated_v1: true + deprecated_v1: + type: read + clusters: + - name: cluster_0 + connect_timeout: 5s + hosts: + - socket_address: + address: "{{ ip_loopback_address }}" + port_value: 0 + dns_lookup_family: V4_ONLY +cluster_manager: {} +watchdog: {} +admin: + access_log_path: "/dev/null" + address: + socket_address: + address: "{{ ip_loopback_address }}" + port_value: 0 diff --git a/test/config/utility.cc b/test/config/utility.cc index 13bf5de7235a6..f0bcc370fcbf0 100644 --- a/test/config/utility.cc +++ b/test/config/utility.cc @@ -114,7 +114,7 @@ name: envoy.squash )EOF"; ConfigHelper::ConfigHelper(const Network::Address::IpVersion version, const std::string& config) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); std::string filename = TestEnvironment::writeStringToFileForTest("basic_config.yaml", config); MessageUtil::loadFromFile(filename, bootstrap_); @@ -140,7 +140,7 @@ ConfigHelper::ConfigHelper(const Network::Address::IpVersion version, const std: } void ConfigHelper::finalize(const std::vector& ports) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); for (auto config_modifier : config_modifiers_) { config_modifier(bootstrap_); } @@ -179,7 +179,7 @@ void ConfigHelper::finalize(const std::vector& ports) { for (int j = 0; j < cluster->hosts_size(); ++j) { if (cluster->mutable_hosts(j)->has_socket_address()) { auto* host_socket_addr = cluster->mutable_hosts(j)->mutable_socket_address(); - RELEASE_ASSERT(ports.size() > port_idx); + RELEASE_ASSERT(ports.size() > port_idx, ""); host_socket_addr->set_port_value(ports[port_idx++]); } } @@ -219,7 +219,7 @@ void ConfigHelper::setCaptureTransportSocket( // Determine inner transport socket. envoy::api::v2::core::TransportSocket inner_transport_socket; if (!transport_socket.name().empty()) { - RELEASE_ASSERT(!tls_config); + RELEASE_ASSERT(!tls_config, ""); inner_transport_socket.MergeFrom(transport_socket); } else if (tls_config.has_value()) { inner_transport_socket.set_name("ssl"); @@ -242,7 +242,7 @@ void ConfigHelper::setCaptureTransportSocket( } void ConfigHelper::setSourceAddress(const std::string& address_string) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); bootstrap_.mutable_cluster_manager() ->mutable_upstream_bind_config() ->mutable_source_address() @@ -255,7 +255,7 @@ void ConfigHelper::setSourceAddress(const std::string& address_string) { } void ConfigHelper::setDefaultHostAndRoute(const std::string& domains, const std::string& prefix) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager hcm_config; loadHttpConnectionManager(hcm_config); @@ -268,8 +268,8 @@ void ConfigHelper::setDefaultHostAndRoute(const std::string& domains, const std: void ConfigHelper::setBufferLimits(uint32_t upstream_buffer_limit, uint32_t downstream_buffer_limit) { - RELEASE_ASSERT(!finalized_); - RELEASE_ASSERT(bootstrap_.mutable_static_resources()->listeners_size() == 1); + RELEASE_ASSERT(!finalized_, ""); + RELEASE_ASSERT(bootstrap_.mutable_static_resources()->listeners_size() == 1, ""); auto* listener = bootstrap_.mutable_static_resources()->mutable_listeners(0); listener->mutable_per_connection_buffer_limit_bytes()->set_value(downstream_buffer_limit); @@ -295,7 +295,7 @@ void ConfigHelper::setBufferLimits(uint32_t upstream_buffer_limit, } void ConfigHelper::setConnectTimeout(std::chrono::milliseconds timeout) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); auto* static_resources = bootstrap_.mutable_static_resources(); for (int i = 0; i < bootstrap_.mutable_static_resources()->clusters_size(); ++i) { @@ -313,7 +313,7 @@ void ConfigHelper::addRoute(const std::string& domains, const std::string& prefi const std::string& cluster, bool validate_clusters, envoy::api::v2::route::RouteAction::ClusterNotFoundResponseCode code, envoy::api::v2::route::VirtualHost::TlsRequirementType type) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager hcm_config; loadHttpConnectionManager(hcm_config); @@ -331,7 +331,7 @@ void ConfigHelper::addRoute(const std::string& domains, const std::string& prefi } void ConfigHelper::addFilter(const std::string& config) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager hcm_config; loadHttpConnectionManager(hcm_config); @@ -349,7 +349,7 @@ void ConfigHelper::addFilter(const std::string& config) { void ConfigHelper::setClientCodec( envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager::CodecType type) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager hcm_config; if (loadHttpConnectionManager(hcm_config)) { hcm_config.set_codec_type(type); @@ -358,7 +358,7 @@ void ConfigHelper::setClientCodec( } void ConfigHelper::addSslConfig() { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); auto* filter_chain = bootstrap_.mutable_static_resources()->mutable_listeners(0)->mutable_filter_chains(0); @@ -390,7 +390,7 @@ void ConfigHelper::renameListener(const std::string& name) { } envoy::api::v2::listener::Filter* ConfigHelper::getFilterFromListener(const std::string& name) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); if (bootstrap_.mutable_static_resources()->listeners_size() == 0) { return nullptr; } @@ -409,7 +409,7 @@ envoy::api::v2::listener::Filter* ConfigHelper::getFilterFromListener(const std: bool ConfigHelper::loadHttpConnectionManager( envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); auto* hcm_filter = getFilterFromListener("envoy.http_connection_manager"); if (hcm_filter) { MessageUtil::jsonConvert(*hcm_filter->mutable_config(), hcm); @@ -420,7 +420,7 @@ bool ConfigHelper::loadHttpConnectionManager( void ConfigHelper::storeHttpConnectionManager( const envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); auto* hcm_config_struct = getFilterFromListener("envoy.http_connection_manager")->mutable_config(); @@ -428,7 +428,7 @@ void ConfigHelper::storeHttpConnectionManager( } void ConfigHelper::addConfigModifier(ConfigModifierFunction function) { - RELEASE_ASSERT(!finalized_); + RELEASE_ASSERT(!finalized_, ""); config_modifiers_.push_back(std::move(function)); } @@ -461,11 +461,11 @@ void EdsHelper::setEds( // FilesystemSubscriptionImpl is subscribed to. std::string path = TestEnvironment::writeStringToFileForTest("eds.update.pb_text", eds_response.DebugString()); - RELEASE_ASSERT(::rename(path.c_str(), eds_path_.c_str()) == 0); + RELEASE_ASSERT(::rename(path.c_str(), eds_path_.c_str()) == 0, ""); // Make sure Envoy has consumed the update now that it is running. server_stats.waitForCounterGe("cluster.cluster_0.update_success", ++update_successes_); - RELEASE_ASSERT(update_successes_ == - server_stats.counter("cluster.cluster_0.update_success")->value()); + RELEASE_ASSERT( + update_successes_ == server_stats.counter("cluster.cluster_0.update_success")->value(), ""); } } // namespace Envoy diff --git a/test/config_test/example_configs_test.cc b/test/config_test/example_configs_test.cc index 3b9c101b96d19..d49262987a930 100644 --- a/test/config_test/example_configs_test.cc +++ b/test/config_test/example_configs_test.cc @@ -12,8 +12,8 @@ TEST(ExampleConfigsTest, All) { // Change working directory, otherwise we won't be able to read files using relative paths. char cwd[PATH_MAX]; const std::string& directory = TestEnvironment::temporaryDirectory() + "/test/config_test"; - RELEASE_ASSERT(::getcwd(cwd, PATH_MAX) != nullptr); - RELEASE_ASSERT(::chdir(directory.c_str()) == 0); + RELEASE_ASSERT(::getcwd(cwd, PATH_MAX) != nullptr, ""); + RELEASE_ASSERT(::chdir(directory.c_str()) == 0, ""); #ifdef __APPLE__ // freebind/freebind.yaml is not supported on OS X and disabled via Bazel. @@ -25,6 +25,6 @@ TEST(ExampleConfigsTest, All) { ConfigTest::testIncompatibleMerge(); // Return to the original working directory, otherwise "bazel.coverage" breaks (...but why?). - RELEASE_ASSERT(::chdir(cwd) == 0); + RELEASE_ASSERT(::chdir(cwd) == 0, ""); } } // namespace Envoy diff --git a/test/extensions/access_loggers/file/config_test.cc b/test/extensions/access_loggers/file/config_test.cc index 6e7facb54c4bf..7bd56fea21107 100644 --- a/test/extensions/access_loggers/file/config_test.cc +++ b/test/extensions/access_loggers/file/config_test.cc @@ -37,7 +37,7 @@ TEST(FileAccessLogConfigTest, ConfigureFromProto) { EXPECT_THROW_WITH_MESSAGE(AccessLog::AccessLogFactory::fromProto(config, context), EnvoyException, "Provided name for static registration lookup was empty."); - config.set_name(AccessLogNames::get().FILE); + config.set_name(AccessLogNames::get().File); AccessLog::InstanceSharedPtr log = AccessLog::AccessLogFactory::fromProto(config, context); @@ -53,7 +53,7 @@ TEST(FileAccessLogConfigTest, ConfigureFromProto) { TEST(FileAccessLogConfigTest, FileAccessLogTest) { auto factory = Registry::FactoryRegistry::getFactory( - AccessLogNames::get().FILE); + AccessLogNames::get().File); ASSERT_NE(nullptr, factory); ProtobufTypes::MessagePtr message = factory->createEmptyConfigProto(); diff --git a/test/extensions/access_loggers/http_grpc/BUILD b/test/extensions/access_loggers/http_grpc/BUILD index 1d5314a999fe2..2f2e07aab3ca9 100644 --- a/test/extensions/access_loggers/http_grpc/BUILD +++ b/test/extensions/access_loggers/http_grpc/BUILD @@ -46,5 +46,6 @@ envoy_extension_cc_test( "//source/extensions/access_loggers/http_grpc:config", "//test/common/grpc:grpc_client_integration_lib", "//test/integration:http_integration_lib", + "//test/test_common:utility_lib", ], ) diff --git a/test/extensions/access_loggers/http_grpc/config_test.cc b/test/extensions/access_loggers/http_grpc/config_test.cc index c04331e52c1c9..04bc4049d34ca 100644 --- a/test/extensions/access_loggers/http_grpc/config_test.cc +++ b/test/extensions/access_loggers/http_grpc/config_test.cc @@ -23,7 +23,7 @@ class HttpGrpcAccessLogConfigTest : public testing::Test { void SetUp() override { factory_ = Registry::FactoryRegistry::getFactory( - AccessLogNames::get().HTTP_GRPC); + AccessLogNames::get().HttpGrpc); ASSERT_NE(nullptr, factory_); message_ = factory_->createEmptyConfigProto(); diff --git a/test/extensions/access_loggers/http_grpc/grpc_access_log_impl_test.cc b/test/extensions/access_loggers/http_grpc/grpc_access_log_impl_test.cc index e365cff61b448..bfbb62bc7eeb6 100644 --- a/test/extensions/access_loggers/http_grpc/grpc_access_log_impl_test.cc +++ b/test/extensions/access_loggers/http_grpc/grpc_access_log_impl_test.cc @@ -219,7 +219,7 @@ TEST_F(HttpGrpcAccessLogTest, Marshalling) { request_info.addBytesReceived(10); request_info.addBytesSent(20); request_info.response_code_ = 200; - ON_CALL(request_info, getResponseFlag(RequestInfo::ResponseFlag::FaultInjected)) + ON_CALL(request_info, hasResponseFlag(RequestInfo::ResponseFlag::FaultInjected)) .WillByDefault(Return(true)); Http::TestHeaderMapImpl request_headers{ @@ -414,7 +414,7 @@ TEST_F(HttpGrpcAccessLogTest, MarshallingAdditionalHeaders) { TEST(responseFlagsToAccessLogResponseFlagsTest, All) { NiceMock request_info; - ON_CALL(request_info, getResponseFlag(_)).WillByDefault(Return(true)); + ON_CALL(request_info, hasResponseFlag(_)).WillByDefault(Return(true)); envoy::data::accesslog::v2::AccessLogCommon common_access_log; HttpGrpcAccessLog::responseFlagsToAccessLogResponseFlags(common_access_log, request_info); diff --git a/test/extensions/access_loggers/http_grpc/grpc_access_log_integration_test.cc b/test/extensions/access_loggers/http_grpc/grpc_access_log_integration_test.cc index 4bdc7c9f24526..7888dfc9a1554 100644 --- a/test/extensions/access_loggers/http_grpc/grpc_access_log_integration_test.cc +++ b/test/extensions/access_loggers/http_grpc/grpc_access_log_integration_test.cc @@ -8,9 +8,12 @@ #include "test/common/grpc/grpc_client_integration.h" #include "test/integration/http_integration.h" +#include "test/test_common/utility.h" #include "gtest/gtest.h" +using testing::AssertionResult; + namespace Envoy { namespace { @@ -49,17 +52,20 @@ class AccessLogIntegrationTest : public HttpIntegrationTest, HttpIntegrationTest::initialize(); } - void waitForAccessLogConnection() { - fake_access_log_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); + ABSL_MUST_USE_RESULT + AssertionResult waitForAccessLogConnection() { + return fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_access_log_connection_); } - void waitForAccessLogStream() { - access_log_request_ = fake_access_log_connection_->waitForNewStream(*dispatcher_); + ABSL_MUST_USE_RESULT + AssertionResult waitForAccessLogStream() { + return fake_access_log_connection_->waitForNewStream(*dispatcher_, access_log_request_); } - void waitForAccessLogRequest(const std::string& expected_request_msg_yaml) { + ABSL_MUST_USE_RESULT + AssertionResult waitForAccessLogRequest(const std::string& expected_request_msg_yaml) { envoy::service::accesslog::v2::StreamAccessLogsMessage request_msg; - access_log_request_->waitForGrpcMessage(*dispatcher_, request_msg); + VERIFY_ASSERTION(access_log_request_->waitForGrpcMessage(*dispatcher_, request_msg)); EXPECT_STREQ("POST", access_log_request_->headers().Method()->value().c_str()); EXPECT_STREQ("/envoy.service.accesslog.v2.AccessLogService/StreamAccessLogs", access_log_request_->headers().Path()->value().c_str()); @@ -78,12 +84,16 @@ class AccessLogIntegrationTest : public HttpIntegrationTest, log_entry->mutable_common_properties()->clear_time_to_last_downstream_tx_byte(); log_entry->mutable_request()->clear_request_id(); EXPECT_EQ(request_msg.DebugString(), expected_request_msg.DebugString()); + + return AssertionSuccess(); } void cleanup() { if (fake_access_log_connection_ != nullptr) { - fake_access_log_connection_->close(); - fake_access_log_connection_->waitForDisconnect(); + AssertionResult result = fake_access_log_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_access_log_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } @@ -97,9 +107,9 @@ INSTANTIATE_TEST_CASE_P(IpVersionsCientType, AccessLogIntegrationTest, // Test a basic full access logging flow. TEST_P(AccessLogIntegrationTest, BasicAccessLogFlow) { testRouterNotFound(); - waitForAccessLogConnection(); - waitForAccessLogStream(); - waitForAccessLogRequest(fmt::format(R"EOF( + ASSERT_TRUE(waitForAccessLogConnection()); + ASSERT_TRUE(waitForAccessLogStream()); + ASSERT_TRUE(waitForAccessLogRequest(fmt::format(R"EOF( identifier: node: id: node_name @@ -124,13 +134,13 @@ TEST_P(AccessLogIntegrationTest, BasicAccessLogFlow) { value: 404 response_headers_bytes: 54 )EOF", - VersionInfo::version())); + VersionInfo::version()))); BufferingStreamDecoderPtr response = IntegrationUtil::makeSingleRequest( lookupPort("http"), "GET", "/notfound", "", downstream_protocol_, version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("404", response->headers().Status()->value().c_str()); - waitForAccessLogRequest(R"EOF( + ASSERT_TRUE(waitForAccessLogRequest(R"EOF( http_logs: log_entry: common_properties: @@ -146,7 +156,7 @@ TEST_P(AccessLogIntegrationTest, BasicAccessLogFlow) { response_code: value: 404 response_headers_bytes: 54 -)EOF"); +)EOF")); // Send an empty response and end the stream. This should never happen but make sure nothing // breaks and we make a new stream on a follow up request. @@ -162,14 +172,14 @@ TEST_P(AccessLogIntegrationTest, BasicAccessLogFlow) { test_server_->waitForCounterGe("grpc.accesslog.streams_closed_0", 1); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } response = IntegrationUtil::makeSingleRequest(lookupPort("http"), "GET", "/notfound", "", downstream_protocol_, version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("404", response->headers().Status()->value().c_str()); - waitForAccessLogStream(); - waitForAccessLogRequest(fmt::format(R"EOF( + ASSERT_TRUE(waitForAccessLogStream()); + ASSERT_TRUE(waitForAccessLogRequest(fmt::format(R"EOF( identifier: node: id: node_name @@ -194,7 +204,7 @@ TEST_P(AccessLogIntegrationTest, BasicAccessLogFlow) { value: 404 response_headers_bytes: 54 )EOF", - VersionInfo::version())); + VersionInfo::version()))); cleanup(); } diff --git a/test/extensions/extensions_build_system.bzl b/test/extensions/extensions_build_system.bzl index 1ea8f3630c1e0..68d496642e1da 100644 --- a/test/extensions/extensions_build_system.bzl +++ b/test/extensions/extensions_build_system.bzl @@ -1,21 +1,32 @@ -load("//bazel:envoy_build_system.bzl", "envoy_cc_test", "envoy_cc_mock") +load("//bazel:envoy_build_system.bzl", "envoy_cc_mock", "envoy_cc_test", "envoy_cc_test_library") load("@envoy_build_config//:extensions_build_config.bzl", "EXTENSIONS") # All extension tests should use this version of envoy_cc_test(). It allows compiling out # tests for extensions that the user does not wish to include in their build. # @param extension_name should match an extension listed in EXTENSIONS. -def envoy_extension_cc_test(name, - extension_name, - **kwargs): +def envoy_extension_cc_test( + name, + extension_name, + **kwargs): if not extension_name in EXTENSIONS: return envoy_cc_test(name, **kwargs) -def envoy_extension_cc_mock(name, - extension_name, - **kwargs): +def envoy_extension_cc_test_library( + name, + extension_name, + **kwargs): if not extension_name in EXTENSIONS: return - envoy_cc_mock(name, **kwargs) \ No newline at end of file + envoy_cc_test_library(name, **kwargs) + +def envoy_extension_cc_mock( + name, + extension_name, + **kwargs): + if not extension_name in EXTENSIONS: + return + + envoy_cc_mock(name, **kwargs) diff --git a/test/extensions/filters/common/ext_authz/BUILD b/test/extensions/filters/common/ext_authz/BUILD index 11da871031dd7..27b273c5ee12d 100644 --- a/test/extensions/filters/common/ext_authz/BUILD +++ b/test/extensions/filters/common/ext_authz/BUILD @@ -10,15 +10,12 @@ load( envoy_package() envoy_cc_test( - name = "ext_authz_impl_test", - srcs = ["ext_authz_impl_test.cc"], + name = "check_request_utils_test", + srcs = ["check_request_utils_test.cc"], deps = [ - "//source/common/http:header_map_lib", - "//source/common/http:headers_lib", "//source/common/network:address_lib", - "//source/extensions/filters/common/ext_authz:ext_authz_lib", - "//test/mocks/grpc:grpc_mocks", - "//test/mocks/http:http_mocks", + "//source/common/protobuf", + "//source/extensions/filters/common/ext_authz:check_request_utils_lib", "//test/mocks/network:network_mocks", "//test/mocks/request_info:request_info_mocks", "//test/mocks/ssl:ssl_mocks", @@ -27,6 +24,24 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "ext_authz_grpc_impl_test", + srcs = ["ext_authz_grpc_impl_test.cc"], + deps = [ + "//source/extensions/filters/common/ext_authz:ext_authz_grpc_lib", + "//test/extensions/filters/common/ext_authz:ext_authz_test_common", + ], +) + +envoy_cc_test( + name = "ext_authz_http_impl_test", + srcs = ["ext_authz_http_impl_test.cc"], + deps = [ + "//source/extensions/filters/common/ext_authz:ext_authz_http_lib", + "//test/extensions/filters/common/ext_authz:ext_authz_test_common", + ], +) + envoy_cc_mock( name = "ext_authz_mocks", srcs = ["mocks.cc"], @@ -35,3 +50,18 @@ envoy_cc_mock( "//source/extensions/filters/common/ext_authz:ext_authz_interface", ], ) + +envoy_cc_mock( + name = "ext_authz_test_common", + srcs = ["test_common.cc"], + hdrs = ["test_common.h"], + deps = [ + "//source/common/http:headers_lib", + "//source/common/protobuf", + "//source/extensions/filters/common/ext_authz:ext_authz_grpc_lib", + "//test/extensions/filters/common/ext_authz:ext_authz_mocks", + "//test/mocks/grpc:grpc_mocks", + "//test/mocks/upstream:upstream_mocks", + "@envoy_api//envoy/api/v2/core:base_cc", + ], +) diff --git a/test/extensions/filters/common/ext_authz/check_request_utils_test.cc b/test/extensions/filters/common/ext_authz/check_request_utils_test.cc new file mode 100644 index 0000000000000..d760e61188cef --- /dev/null +++ b/test/extensions/filters/common/ext_authz/check_request_utils_test.cc @@ -0,0 +1,94 @@ +#include "common/network/address_impl.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/filters/common/ext_authz/check_request_utils.h" + +#include "test/mocks/network/mocks.h" +#include "test/mocks/request_info/mocks.h" +#include "test/mocks/ssl/mocks.h" +#include "test/mocks/upstream/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Return; +using testing::ReturnPointee; +using testing::ReturnRef; + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +class CheckRequestUtilsTest : public testing::Test { +public: + CheckRequestUtilsTest() { + addr_ = std::make_shared("1.2.3.4", 1111); + protocol_ = Envoy::Http::Protocol::Http10; + }; + + Network::Address::InstanceConstSharedPtr addr_; + absl::optional protocol_; + CheckRequestUtils check_request_generator_; + NiceMock callbacks_; + NiceMock net_callbacks_; + NiceMock connection_; + NiceMock ssl_; + NiceMock req_info_; +}; + +// Verify that createTcpCheck's dependencies are invoked when it's called. +TEST_F(CheckRequestUtilsTest, BasicTcp) { + envoy::service::auth::v2alpha::CheckRequest request; + EXPECT_CALL(net_callbacks_, connection()).Times(2).WillRepeatedly(ReturnRef(connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(Const(connection_), ssl()).Times(2).WillRepeatedly(Return(&ssl_)); + + CheckRequestUtils::createTcpCheck(&net_callbacks_, request); +} + +// Verify that createHttpCheck's dependencies are invoked when it's called. +TEST_F(CheckRequestUtilsTest, BasicHttp) { + Http::HeaderMapImpl headers; + envoy::service::auth::v2alpha::CheckRequest request; + EXPECT_CALL(callbacks_, connection()).Times(2).WillRepeatedly(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(Const(connection_), ssl()).Times(2).WillRepeatedly(Return(&ssl_)); + EXPECT_CALL(callbacks_, streamId()).WillOnce(Return(0)); + EXPECT_CALL(callbacks_, requestInfo()).Times(3).WillRepeatedly(ReturnRef(req_info_)); + EXPECT_CALL(req_info_, protocol()).Times(2).WillRepeatedly(ReturnPointee(&protocol_)); + + CheckRequestUtils::createHttpCheck(&callbacks_, headers, request); +} + +// Verify that createHttpCheck extract the proper attributes from the http request into CheckRequest +// proto object. +TEST_F(CheckRequestUtilsTest, CheckAttrContextPeer) { + Http::TestHeaderMapImpl request_headers{{"x-envoy-downstream-service-cluster", "foo"}, + {":path", "/bar"}}; + envoy::service::auth::v2alpha::CheckRequest request; + EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillRepeatedly(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillRepeatedly(ReturnRef(addr_)); + EXPECT_CALL(Const(connection_), ssl()).WillRepeatedly(Return(&ssl_)); + EXPECT_CALL(callbacks_, streamId()).WillRepeatedly(Return(0)); + EXPECT_CALL(callbacks_, requestInfo()).WillRepeatedly(ReturnRef(req_info_)); + EXPECT_CALL(req_info_, protocol()).WillRepeatedly(ReturnPointee(&protocol_)); + EXPECT_CALL(ssl_, uriSanPeerCertificate()).WillOnce(Return("source")); + EXPECT_CALL(ssl_, uriSanLocalCertificate()).WillOnce(Return("destination")); + + CheckRequestUtils::createHttpCheck(&callbacks_, request_headers, request); + + EXPECT_EQ("source", request.attributes().source().principal()); + EXPECT_EQ("destination", request.attributes().destination().principal()); + EXPECT_EQ("foo", request.attributes().source().service()); +} + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/common/ext_authz/ext_authz_grpc_impl_test.cc b/test/extensions/filters/common/ext_authz/ext_authz_grpc_impl_test.cc new file mode 100644 index 0000000000000..6b6c5f8257351 --- /dev/null +++ b/test/extensions/filters/common/ext_authz/ext_authz_grpc_impl_test.cc @@ -0,0 +1,187 @@ +#include "envoy/api/v2/core/base.pb.h" + +#include "common/http/headers.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" + +#include "test/extensions/filters/common/ext_authz/mocks.h" +#include "test/extensions/filters/common/ext_authz/test_common.h" +#include "test/mocks/grpc/mocks.h" +#include "test/mocks/upstream/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::Ref; +using testing::Return; +using testing::ReturnPointee; +using testing::ReturnRef; +using testing::WhenDynamicCastTo; +using testing::WithArg; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +class ExtAuthzGrpcClientTest : public testing::Test { +public: + ExtAuthzGrpcClientTest() + : async_client_(new Grpc::MockAsyncClient()), timeout_(10), + client_(Grpc::AsyncClientPtr{async_client_}, timeout_) {} + + Grpc::MockAsyncClient* async_client_; + absl::optional timeout_; + Grpc::MockAsyncRequest async_request_; + GrpcClientImpl client_; + MockRequestCallbacks request_callbacks_; + Tracing::MockSpan span_; + + void expectCallSend(envoy::service::auth::v2alpha::CheckRequest& request) { + EXPECT_CALL(*async_client_, send(_, ProtoEq(request), Ref(client_), _, _)) + .WillOnce(Invoke( + [this]( + const Protobuf::MethodDescriptor& service_method, const Protobuf::Message&, + Grpc::AsyncRequestCallbacks&, Tracing::Span&, + const absl::optional& timeout) -> Grpc::AsyncRequest* { + // TODO(dio): Use a defined constant value. + EXPECT_EQ("envoy.service.auth.v2alpha.Authorization", + service_method.service()->full_name()); + EXPECT_EQ("Check", service_method.name()); + EXPECT_EQ(timeout_->count(), timeout->count()); + return &async_request_; + })); + } +}; + +// Test the client when an ok response is received. +TEST_F(ExtAuthzGrpcClientTest, AuthorizationOk) { + auto check_response = std::make_unique(); + auto status = check_response->mutable_status(); + status->set_code(Grpc::Status::GrpcStatus::Ok); + auto authz_response = Response{}; + authz_response.status = CheckStatus::OK; + + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + Http::HeaderMapImpl headers; + client_.onCreateInitialMetadata(headers); + + EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_ok")); + EXPECT_CALL(request_callbacks_, onComplete_(WhenDynamicCastTo( + AuthzResponseNoAttributes(authz_response)))); + client_.onSuccess(std::move(check_response), span_); +} + +// Test the client when an ok response is received. +TEST_F(ExtAuthzGrpcClientTest, AuthorizationOkWithAllAtributes) { + const std::string empty_body{}; + const auto expected_headers = TestCommon::makeHeaderValueOption({{"foo", "bar", false}}); + auto check_response = TestCommon::makeCheckResponse( + Grpc::Status::GrpcStatus::Ok, envoy::type::StatusCode::OK, empty_body, expected_headers); + auto authz_response = + TestCommon::makeAuthzResponse(CheckStatus::OK, Http::Code::OK, empty_body, expected_headers); + + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + Http::HeaderMapImpl headers; + client_.onCreateInitialMetadata(headers); + + EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_ok")); + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzOkResponse(authz_response)))); + client_.onSuccess(std::move(check_response), span_); +} + +// Test the client when a denied response is received. +TEST_F(ExtAuthzGrpcClientTest, AuthorizationDenied) { + auto check_response = std::make_unique(); + auto status = check_response->mutable_status(); + status->set_code(Grpc::Status::GrpcStatus::PermissionDenied); + auto authz_response = Response{}; + authz_response.status = CheckStatus::Denied; + + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + Http::HeaderMapImpl headers; + client_.onCreateInitialMetadata(headers); + EXPECT_EQ(nullptr, headers.RequestId()); + EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_unauthorized")); + EXPECT_CALL(request_callbacks_, onComplete_(WhenDynamicCastTo( + AuthzResponseNoAttributes(authz_response)))); + + client_.onSuccess(std::move(check_response), span_); +} + +// Test the client when a denied response with additional HTTP attributes is received. +TEST_F(ExtAuthzGrpcClientTest, AuthorizationDeniedWithAllAttributes) { + const std::string expected_body{"test"}; + const auto expected_headers = + TestCommon::makeHeaderValueOption({{"foo", "bar", false}, {"foobar", "bar", true}}); + auto check_response = TestCommon::makeCheckResponse(Grpc::Status::GrpcStatus::PermissionDenied, + envoy::type::StatusCode::Unauthorized, + expected_body, expected_headers); + auto authz_response = TestCommon::makeAuthzResponse(CheckStatus::Denied, Http::Code::Unauthorized, + expected_body, expected_headers); + + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + Http::HeaderMapImpl headers; + client_.onCreateInitialMetadata(headers); + EXPECT_EQ(nullptr, headers.RequestId()); + EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_unauthorized")); + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); + + client_.onSuccess(std::move(check_response), span_); +} + +// Test the client when an unknown error occurs. +TEST_F(ExtAuthzGrpcClientTest, UnknownError) { + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); + client_.onFailure(Grpc::Status::Unknown, "", span_); +} + +// Test the client when the request is canceled. +TEST_F(ExtAuthzGrpcClientTest, CancelledAuthorizationRequest) { + envoy::service::auth::v2alpha::CheckRequest request; + EXPECT_CALL(*async_client_, send(_, _, _, _, _)).WillOnce(Return(&async_request_)); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(async_request_, cancel()); + client_.cancel(); +} + +// Test the client when the request times out. +TEST_F(ExtAuthzGrpcClientTest, AuthorizationRequestTimeout) { + envoy::service::auth::v2alpha::CheckRequest request; + expectCallSend(request); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); + client_.onFailure(Grpc::Status::DeadlineExceeded, "", span_); +} + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc b/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc new file mode 100644 index 0000000000000..5588c168519b8 --- /dev/null +++ b/test/extensions/filters/common/ext_authz/ext_authz_http_impl_test.cc @@ -0,0 +1,184 @@ +#include "envoy/api/v2/core/base.pb.h" + +#include "common/http/headers.h" +#include "common/http/message_impl.h" +#include "common/protobuf/protobuf.h" +#include "common/tracing/http_tracer_impl.h" + +#include "extensions/filters/common/ext_authz/ext_authz_http_impl.h" + +#include "test/extensions/filters/common/ext_authz/mocks.h" +#include "test/extensions/filters/common/ext_authz/test_common.h" +#include "test/mocks/upstream/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::Ref; +using testing::Return; +using testing::ReturnPointee; +using testing::ReturnRef; +using testing::WhenDynamicCastTo; +using testing::WithArg; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +typedef std::vector HeaderValueOptionVector; + +class ExtAuthzHttpClientTest : public testing::Test { +public: + ExtAuthzHttpClientTest() + : cluster_name_{"foo"}, cluster_manager_{}, timeout_{}, path_prefix_{"/bar"}, + response_headers_to_remove_{Http::LowerCaseString{"bar"}}, async_client_{}, + async_request_{&async_client_}, client_(cluster_name_, cluster_manager_, timeout_, + path_prefix_, response_headers_to_remove_) { + ON_CALL(cluster_manager_, httpAsyncClientForCluster(cluster_name_)) + .WillByDefault(ReturnRef(async_client_)); + } + + std::string cluster_name_; + NiceMock cluster_manager_; + MockRequestCallbacks request_callbacks_; + absl::optional timeout_; + std::string path_prefix_; + std::vector response_headers_to_remove_; + NiceMock async_client_; + NiceMock async_request_; + RawHttpClientImpl client_; +}; + +// Test the client when an ok response is received. +TEST_F(ExtAuthzHttpClientTest, AuthorizationOk) { + const auto expected_headers = TestCommon::makeHeaderValueOption({{":status", "200", false}}); + const auto authz_response = TestCommon::makeAuthzResponse(CheckStatus::OK); + auto check_response = TestCommon::makeMessageResponse(expected_headers); + envoy::service::auth::v2alpha::CheckRequest request; + + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzOkResponse(authz_response)))); + + client_.onSuccess(std::move(check_response)); +} + +// Test the client when an request contains path to be re-written and ok response is received. +TEST_F(ExtAuthzHttpClientTest, AuthorizationOkWithPathRewrite) { + const auto expected_headers = TestCommon::makeHeaderValueOption({{":status", "200", false}}); + const auto authz_response = TestCommon::makeAuthzResponse(CheckStatus::OK); + auto check_response = TestCommon::makeMessageResponse(expected_headers); + + envoy::service::auth::v2alpha::CheckRequest request{}; + auto mutable_headers = + request.mutable_attributes()->mutable_request()->mutable_http()->mutable_headers(); + (*mutable_headers)[std::string{":path"}] = std::string{"foo"}; + (*mutable_headers)[std::string{"foo"}] = std::string{"bar"}; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzOkResponse(authz_response)))); + + client_.onSuccess(std::move(check_response)); +} + +// Test that the client removes certain response headers. +TEST_F(ExtAuthzHttpClientTest, AuthorizationOkWithRemovedHeader) { + const auto expected_headers = TestCommon::makeHeaderValueOption({{"foobar", "foo", false}}); + const std::string empty_body{}; + const auto authz_response = + TestCommon::makeAuthzResponse(CheckStatus::OK, Http::Code::OK, empty_body, expected_headers); + const auto check_response_headers = + TestCommon::makeHeaderValueOption({{":status", "200", false}, + {":path", "/bar", false}, + {":method", "post", false}, + {"content-length", "post", false}, + {"bar", "foo", false}, + {"foobar", "foo", false}}); + auto message_response = TestCommon::makeMessageResponse(check_response_headers); + + envoy::service::auth::v2alpha::CheckRequest request; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzOkResponse(authz_response)))); + + client_.onSuccess(std::move(message_response)); +} + +// Test the client when a denied response is received due to an unknown status code. +TEST_F(ExtAuthzHttpClientTest, AuthorizationDeniedWithInvalidStatusCode) { + const auto expected_headers = TestCommon::makeHeaderValueOption({{":status", "error", false}}); + const auto authz_response = TestCommon::makeAuthzResponse( + CheckStatus::Denied, Http::Code::Forbidden, "", expected_headers); + Http::MessagePtr check_response(new Http::ResponseMessageImpl( + Http::HeaderMapPtr{new Http::TestHeaderMapImpl{{":status", "error"}}})); + envoy::service::auth::v2alpha::CheckRequest request; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); + + client_.onSuccess(std::move(check_response)); +} + +// Test the client when a denied response is received. +TEST_F(ExtAuthzHttpClientTest, AuthorizationDenied) { + const auto expected_headers = TestCommon::makeHeaderValueOption({{":status", "403", false}}); + const auto authz_response = TestCommon::makeAuthzResponse( + CheckStatus::Denied, Http::Code::Forbidden, "", expected_headers); + auto check_response = TestCommon::makeMessageResponse(expected_headers); + + envoy::service::auth::v2alpha::CheckRequest request; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); + + client_.onSuccess(std::move(check_response)); +} + +// Test the client when a denied response is received and it contains additional HTTP attributes. +TEST_F(ExtAuthzHttpClientTest, AuthorizationDeniedWithAllAttributes) { + const auto expected_body = std::string{"test"}; + const auto expected_headers = TestCommon::makeHeaderValueOption({{":status", "401", false}}); + const auto authz_response = TestCommon::makeAuthzResponse( + CheckStatus::Denied, Http::Code::Unauthorized, expected_body, expected_headers); + auto check_response = TestCommon::makeMessageResponse(expected_headers, expected_body); + + envoy::service::auth::v2alpha::CheckRequest request; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzDeniedResponse(authz_response)))); + + client_.onSuccess(std::move(check_response)); +} + +// Test the client when an unknown error occurs. +TEST_F(ExtAuthzHttpClientTest, AuthorizationRequestError) { + envoy::service::auth::v2alpha::CheckRequest request; + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(request_callbacks_, + onComplete_(WhenDynamicCastTo(AuthzErrorResponse(CheckStatus::Error)))); + client_.onFailure(Http::AsyncClient::FailureReason::Reset); +} + +// Test the client when the request is canceled. +TEST_F(ExtAuthzHttpClientTest, CancelledAuthorizationRequest) { + envoy::service::auth::v2alpha::CheckRequest request; + EXPECT_CALL(async_client_, send_(_, _, _)).WillOnce(Return(&async_request_)); + client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); + + EXPECT_CALL(async_request_, cancel()); + client_.cancel(); +} + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/common/ext_authz/ext_authz_impl_test.cc b/test/extensions/filters/common/ext_authz/ext_authz_impl_test.cc deleted file mode 100644 index f2f2ded6e23c7..0000000000000 --- a/test/extensions/filters/common/ext_authz/ext_authz_impl_test.cc +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include - -#include "common/http/header_map_impl.h" -#include "common/http/headers.h" -#include "common/network/address_impl.h" -#include "common/tracing/http_tracer_impl.h" - -#include "extensions/filters/common/ext_authz/ext_authz_impl.h" - -#include "test/mocks/grpc/mocks.h" -#include "test/mocks/http/mocks.h" -#include "test/mocks/network/mocks.h" -#include "test/mocks/request_info/mocks.h" -#include "test/mocks/ssl/mocks.h" -#include "test/mocks/upstream/mocks.h" -#include "test/test_common/printers.h" -#include "test/test_common/utility.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::AtLeast; -using testing::Invoke; -using testing::Ref; -using testing::Return; -using testing::ReturnPointee; -using testing::ReturnRef; -using testing::WithArg; -using testing::_; - -namespace Envoy { -namespace Extensions { -namespace Filters { -namespace Common { -namespace ExtAuthz { - -class MockRequestCallbacks : public RequestCallbacks { -public: - MOCK_METHOD1(onComplete, void(CheckStatus status)); -}; - -class ExtAuthzGrpcClientTest : public testing::Test { -public: - ExtAuthzGrpcClientTest() - : async_client_(new Grpc::MockAsyncClient()), - client_(Grpc::AsyncClientPtr{async_client_}, absl::optional()) {} - - Grpc::MockAsyncClient* async_client_; - Grpc::MockAsyncRequest async_request_; - GrpcClientImpl client_; - MockRequestCallbacks request_callbacks_; - Tracing::MockSpan span_; -}; - -TEST_F(ExtAuthzGrpcClientTest, BasicOK) { - envoy::service::auth::v2alpha::CheckRequest request; - std::unique_ptr response; - Http::HeaderMapImpl headers; - EXPECT_CALL(*async_client_, send(_, ProtoEq(request), _, _, _)).WillOnce(Return(&async_request_)); - - client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); - - client_.onCreateInitialMetadata(headers); - - response = std::make_unique(); - auto status = response->mutable_status(); - status->set_code(Grpc::Status::GrpcStatus::Ok); - EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_ok")); - EXPECT_CALL(request_callbacks_, onComplete(CheckStatus::OK)); - client_.onSuccess(std::move(response), span_); -} - -TEST_F(ExtAuthzGrpcClientTest, BasicDenied) { - envoy::service::auth::v2alpha::CheckRequest request; - std::unique_ptr response; - Http::HeaderMapImpl headers; - - EXPECT_CALL(*async_client_, send(_, ProtoEq(request), Ref(client_), _, _)) - .WillOnce( - Invoke([this](const Protobuf::MethodDescriptor& service_method, const Protobuf::Message&, - Grpc::AsyncRequestCallbacks&, Tracing::Span&, - const absl::optional&) -> Grpc::AsyncRequest* { - // TODO(dio): Use a defined constant value. - EXPECT_EQ("envoy.service.auth.v2alpha.Authorization", - service_method.service()->full_name()); - EXPECT_EQ("Check", service_method.name()); - return &async_request_; - })); - - client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); - - client_.onCreateInitialMetadata(headers); - EXPECT_EQ(nullptr, headers.RequestId()); - - response = std::make_unique(); - auto status = response->mutable_status(); - status->set_code(Grpc::Status::GrpcStatus::PermissionDenied); - EXPECT_CALL(span_, setTag("ext_authz_status", "ext_authz_unauthorized")); - EXPECT_CALL(request_callbacks_, onComplete(CheckStatus::Denied)); - client_.onSuccess(std::move(response), span_); -} - -TEST_F(ExtAuthzGrpcClientTest, BasicError) { - envoy::service::auth::v2alpha::CheckRequest request; - EXPECT_CALL(*async_client_, send(_, ProtoEq(request), _, _, _)).WillOnce(Return(&async_request_)); - - client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); - - EXPECT_CALL(request_callbacks_, onComplete(CheckStatus::Error)); - client_.onFailure(Grpc::Status::Unknown, "", span_); -} - -TEST_F(ExtAuthzGrpcClientTest, Cancel) { - envoy::service::auth::v2alpha::CheckRequest request; - - EXPECT_CALL(*async_client_, send(_, _, _, _, _)).WillOnce(Return(&async_request_)); - - client_.check(request_callbacks_, request, Tracing::NullSpan::instance()); - - EXPECT_CALL(async_request_, cancel()); - client_.cancel(); -} - -class CheckRequestUtilsTest : public testing::Test { -public: - CheckRequestUtilsTest() { - addr_ = std::make_shared("1.2.3.4", 1111); - protocol_ = Envoy::Http::Protocol::Http10; - }; - - Network::Address::InstanceConstSharedPtr addr_; - absl::optional protocol_; - CheckRequestUtils check_request_generator_; - NiceMock callbacks_; - NiceMock net_callbacks_; - NiceMock connection_; - NiceMock ssl_; - NiceMock req_info_; -}; - -TEST_F(CheckRequestUtilsTest, BasicTcp) { - - envoy::service::auth::v2alpha::CheckRequest request; - - EXPECT_CALL(net_callbacks_, connection()).Times(2).WillRepeatedly(ReturnRef(connection_)); - EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); - EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); - EXPECT_CALL(Const(connection_), ssl()).Times(2).WillRepeatedly(Return(&ssl_)); - - CheckRequestUtils::createTcpCheck(&net_callbacks_, request); -} - -TEST_F(CheckRequestUtilsTest, BasicHttp) { - - Http::HeaderMapImpl headers; - envoy::service::auth::v2alpha::CheckRequest request; - - EXPECT_CALL(callbacks_, connection()).Times(2).WillRepeatedly(Return(&connection_)); - EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); - EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); - EXPECT_CALL(Const(connection_), ssl()).Times(2).WillRepeatedly(Return(&ssl_)); - EXPECT_CALL(callbacks_, streamId()).WillOnce(Return(0)); - EXPECT_CALL(callbacks_, requestInfo()).Times(3).WillRepeatedly(ReturnRef(req_info_)); - EXPECT_CALL(req_info_, protocol()).Times(2).WillRepeatedly(ReturnPointee(&protocol_)); - CheckRequestUtils::createHttpCheck(&callbacks_, headers, request); -} - -TEST_F(CheckRequestUtilsTest, CheckAttrContextPeer) { - - Http::TestHeaderMapImpl request_headers{{"x-envoy-downstream-service-cluster", "foo"}, - {":path", "/bar"}}; - envoy::service::auth::v2alpha::CheckRequest request; - - EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection_)); - EXPECT_CALL(connection_, remoteAddress()).WillRepeatedly(ReturnRef(addr_)); - EXPECT_CALL(connection_, localAddress()).WillRepeatedly(ReturnRef(addr_)); - EXPECT_CALL(Const(connection_), ssl()).WillRepeatedly(Return(&ssl_)); - EXPECT_CALL(callbacks_, streamId()).WillRepeatedly(Return(0)); - EXPECT_CALL(callbacks_, requestInfo()).WillRepeatedly(ReturnRef(req_info_)); - EXPECT_CALL(req_info_, protocol()).WillRepeatedly(ReturnPointee(&protocol_)); - - EXPECT_CALL(ssl_, uriSanPeerCertificate()).WillOnce(Return("source")); - EXPECT_CALL(ssl_, uriSanLocalCertificate()).WillOnce(Return("destination")); - CheckRequestUtils::createHttpCheck(&callbacks_, request_headers, request); - - EXPECT_EQ("source", request.attributes().source().principal()); - EXPECT_EQ("destination", request.attributes().destination().principal()); - EXPECT_EQ("foo", request.attributes().source().service()); -} - -} // namespace ExtAuthz -} // namespace Common -} // namespace Filters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/common/ext_authz/mocks.cc b/test/extensions/filters/common/ext_authz/mocks.cc index 7416e537dcfe7..99bbd23ab0ea0 100644 --- a/test/extensions/filters/common/ext_authz/mocks.cc +++ b/test/extensions/filters/common/ext_authz/mocks.cc @@ -9,6 +9,9 @@ namespace ExtAuthz { MockClient::MockClient() {} MockClient::~MockClient() {} +MockRequestCallbacks::MockRequestCallbacks() {} +MockRequestCallbacks::~MockRequestCallbacks() {} + } // namespace ExtAuthz } // namespace Common } // namespace Filters diff --git a/test/extensions/filters/common/ext_authz/mocks.h b/test/extensions/filters/common/ext_authz/mocks.h index e70e4db8e324b..0ec15d0c86023 100644 --- a/test/extensions/filters/common/ext_authz/mocks.h +++ b/test/extensions/filters/common/ext_authz/mocks.h @@ -25,6 +25,16 @@ class MockClient : public Client { Tracing::Span& parent_span)); }; +class MockRequestCallbacks : public RequestCallbacks { +public: + MockRequestCallbacks(); + ~MockRequestCallbacks(); + + void onComplete(ResponsePtr&& response) override { onComplete_(response); } + + MOCK_METHOD1(onComplete_, void(ResponsePtr& response)); +}; + } // namespace ExtAuthz } // namespace Common } // namespace Filters diff --git a/test/extensions/filters/common/ext_authz/test_common.cc b/test/extensions/filters/common/ext_authz/test_common.cc new file mode 100644 index 0000000000000..cc7cfb8e60999 --- /dev/null +++ b/test/extensions/filters/common/ext_authz/test_common.cc @@ -0,0 +1,99 @@ +#include "test/extensions/filters/common/ext_authz/test_common.h" + +#include "test/mocks/upstream/mocks.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +CheckResponsePtr TestCommon::makeCheckResponse(Grpc::Status::GrpcStatus response_status, + envoy::type::StatusCode http_status_code, + const std::string& body, + const HeaderValueOptionVector& headers) { + auto response = std::make_unique(); + auto status = response->mutable_status(); + status->set_code(response_status); + + if (response_status != Grpc::Status::GrpcStatus::Ok) { + const auto denied_response = response->mutable_denied_response(); + if (!body.empty()) { + denied_response->set_body(body); + } + + auto status_code = denied_response->mutable_status(); + status_code->set_code(http_status_code); + + auto denied_response_headers = denied_response->mutable_headers(); + if (!headers.empty()) { + for (const auto& header : headers) { + auto* item = denied_response_headers->Add(); + item->CopyFrom(header); + } + } + } else { + if (!headers.empty()) { + const auto ok_response_headers = response->mutable_ok_response()->mutable_headers(); + for (const auto& header : headers) { + auto* item = ok_response_headers->Add(); + item->CopyFrom(header); + } + } + } + return response; +} + +Response TestCommon::makeAuthzResponse(CheckStatus status, Http::Code status_code, + const std::string& body, + const HeaderValueOptionVector& headers) { + auto authz_response = Response{}; + authz_response.status = status; + authz_response.status_code = status_code; + if (!body.empty()) { + authz_response.body = body; + } + if (!headers.empty()) { + for (auto& header : headers) { + if (header.append().value()) { + authz_response.headers_to_append.emplace_back(Http::LowerCaseString(header.header().key()), + header.header().value()); + } else { + authz_response.headers_to_add.emplace_back(Http::LowerCaseString(header.header().key()), + header.header().value()); + } + } + } + return authz_response; +} + +HeaderValueOptionVector TestCommon::makeHeaderValueOption(KeyValueOptionVector&& headers) { + HeaderValueOptionVector header_option_vector{}; + for (auto header : headers) { + envoy::api::v2::core::HeaderValueOption header_value_option; + auto* mutable_header = header_value_option.mutable_header(); + mutable_header->set_key(header.key); + mutable_header->set_value(header.value); + header_value_option.mutable_append()->set_value(header.append); + header_option_vector.push_back(header_value_option); + } + return header_option_vector; +} + +Http::MessagePtr TestCommon::makeMessageResponse(const HeaderValueOptionVector& headers, + const std::string& body) { + Http::MessagePtr response( + new Http::ResponseMessageImpl(Http::HeaderMapPtr{new Http::TestHeaderMapImpl{}})); + for (auto& header : headers) { + response->headers().addCopy(Http::LowerCaseString(header.header().key()), + header.header().value()); + } + response->body().reset(new Buffer::OwnedImpl(body)); + return response; +}; + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/common/ext_authz/test_common.h b/test/extensions/filters/common/ext_authz/test_common.h new file mode 100644 index 0000000000000..268f4cb7e202b --- /dev/null +++ b/test/extensions/filters/common/ext_authz/test_common.h @@ -0,0 +1,105 @@ +#pragma once + +#include "envoy/api/v2/core/base.pb.h" + +#include "common/http/headers.h" + +#include "extensions/filters/common/ext_authz/ext_authz_grpc_impl.h" + +#include "test/extensions/filters/common/ext_authz/mocks.h" + +namespace Envoy { +namespace Extensions { +namespace Filters { +namespace Common { +namespace ExtAuthz { + +MATCHER_P(AuthzErrorResponse, status, "") { return arg->status == status; } + +MATCHER_P(AuthzResponseNoAttributes, response, "") { + if (arg->status != response.status) { + return false; + } + return true; +} + +MATCHER_P(AuthzDeniedResponse, response, "") { + if (arg->status != response.status) { + return false; + } + if (arg->status_code != response.status_code) { + return false; + } + if (arg->body.compare(response.body)) { + return false; + } + // Compare headers_to_add. + if (!arg->headers_to_add.empty() && response.headers_to_add.empty()) { + return false; + } + if (!std::equal(arg->headers_to_add.begin(), arg->headers_to_add.end(), + response.headers_to_add.begin())) { + return false; + } + + return true; +} + +MATCHER_P(AuthzOkResponse, response, "") { + if (arg->status != response.status) { + return false; + } + // Compare headers_to_apppend. + if (!arg->headers_to_append.empty() && response.headers_to_append.empty()) { + return false; + } + if (!std::equal(arg->headers_to_append.begin(), arg->headers_to_append.end(), + response.headers_to_append.begin())) { + return false; + } + // Compare headers_to_add. + if (!arg->headers_to_add.empty() && response.headers_to_add.empty()) { + return false; + } + if (!std::equal(arg->headers_to_add.begin(), arg->headers_to_add.end(), + response.headers_to_add.begin())) { + return false; + } + + return true; +} + +struct KeyValueOption { + std::string key; + std::string value; + bool append; +}; + +typedef std::vector KeyValueOptionVector; +typedef std::vector HeaderValueOptionVector; +typedef std::unique_ptr CheckResponsePtr; + +class TestCommon { +public: + static Http::MessagePtr makeMessageResponse(const HeaderValueOptionVector& headers, + const std::string& body = std::string{}); + + static CheckResponsePtr + makeCheckResponse(Grpc::Status::GrpcStatus response_status = Grpc::Status::GrpcStatus::Ok, + envoy::type::StatusCode http_status_code = envoy::type::StatusCode::OK, + const std::string& body = std::string{}, + const HeaderValueOptionVector& headers = HeaderValueOptionVector{}); + + static Response + makeAuthzResponse(CheckStatus status, Http::Code status_code = Http::Code::OK, + const std::string& body = std::string{}, + const HeaderValueOptionVector& headers = HeaderValueOptionVector{}); + + static HeaderValueOptionVector makeHeaderValueOption(KeyValueOptionVector&& headers); +}; + +} // namespace ExtAuthz +} // namespace Common +} // namespace Filters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/common/lua/BUILD b/test/extensions/filters/common/lua/BUILD index 254c6f4a90c91..8ba2dbcd32ade 100644 --- a/test/extensions/filters/common/lua/BUILD +++ b/test/extensions/filters/common/lua/BUILD @@ -27,6 +27,8 @@ envoy_cc_test( ":lua_wrappers_lib", "//source/common/buffer:buffer_lib", "//source/extensions/filters/common/lua:wrappers_lib", + "//test/mocks/network:network_mocks", + "//test/mocks/ssl:ssl_mocks", "//test/test_common:utility_lib", ], ) diff --git a/test/extensions/filters/common/lua/wrappers_test.cc b/test/extensions/filters/common/lua/wrappers_test.cc index 0bbbc282d0148..a540c49b2965e 100644 --- a/test/extensions/filters/common/lua/wrappers_test.cc +++ b/test/extensions/filters/common/lua/wrappers_test.cc @@ -3,6 +3,8 @@ #include "extensions/filters/common/lua/wrappers.h" #include "test/extensions/filters/common/lua/lua_wrappers.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/ssl/mocks.h" #include "test/test_common/utility.h" namespace Envoy { @@ -27,6 +29,42 @@ class LuaMetadataMapWrapperTest : public LuaWrappersTestBase } }; +class LuaConnectionWrapperTest : public LuaWrappersTestBase { +public: + virtual void setup(const std::string& script) { + LuaWrappersTestBase::setup(script); + state_->registerType(); + } + +protected: + void expectSecureConnection(const bool secure) { + const std::string SCRIPT{R"EOF( + function callMe(object) + if object:ssl() == nil then + testPrint("plain") + else + testPrint("secure") + end + testPrint(type(object:ssl())) + end + )EOF"}; + testing::InSequence s; + setup(SCRIPT); + + // Setup secure connection if required. + EXPECT_CALL(Const(connection_), ssl()).WillOnce(Return(secure ? &ssl_ : nullptr)); + + ConnectionWrapper::create(coroutine_->luaState(), &connection_); + EXPECT_CALL(*this, testPrint(secure ? "secure" : "plain")); + EXPECT_CALL(Const(connection_), ssl()).WillOnce(Return(secure ? &ssl_ : nullptr)); + EXPECT_CALL(*this, testPrint(secure ? "userdata" : "nil")); + start("callMe"); + } + + NiceMock connection_; + NiceMock ssl_; +}; + // Basic buffer wrapper methods test. TEST_F(LuaBufferWrapperTest, Methods) { const std::string SCRIPT{R"EOF( @@ -224,6 +262,11 @@ TEST_F(LuaMetadataMapWrapperTest, DontFinishIteration) { "[string \"...\"]:5: cannot create a second iterator before completing the first"); } +TEST_F(LuaConnectionWrapperTest, Secure) { + expectSecureConnection(true); + expectSecureConnection(false); +} + } // namespace Lua } // namespace Common } // namespace Filters diff --git a/test/extensions/filters/common/rbac/engine_impl_test.cc b/test/extensions/filters/common/rbac/engine_impl_test.cc index 85e8f891ea538..b4f441f44cab1 100644 --- a/test/extensions/filters/common/rbac/engine_impl_test.cc +++ b/test/extensions/filters/common/rbac/engine_impl_test.cc @@ -20,10 +20,12 @@ namespace Common { namespace RBAC { namespace { -void checkEngine(const RBAC::RoleBasedAccessControlEngineImpl& engine, bool expected, - const Envoy::Network::Connection& connection = Envoy::Network::MockConnection(), - const Envoy::Http::HeaderMap& headers = Envoy::Http::HeaderMapImpl()) { - EXPECT_EQ(expected, engine.allowed(connection, headers)); +void checkEngine( + const RBAC::RoleBasedAccessControlEngineImpl& engine, bool expected, + const Envoy::Network::Connection& connection = Envoy::Network::MockConnection(), + const Envoy::Http::HeaderMap& headers = Envoy::Http::HeaderMapImpl(), + const envoy::api::v2::core::Metadata& metadata = envoy::api::v2::core::Metadata()) { + EXPECT_EQ(expected, engine.allowed(connection, headers, metadata)); } TEST(RoleBasedAccessControlEngineImpl, Disabled) { diff --git a/test/extensions/filters/common/rbac/matchers_test.cc b/test/extensions/filters/common/rbac/matchers_test.cc index 680293acb8f0c..dae4eb73e07fc 100644 --- a/test/extensions/filters/common/rbac/matchers_test.cc +++ b/test/extensions/filters/common/rbac/matchers_test.cc @@ -20,10 +20,12 @@ namespace Common { namespace RBAC { namespace { -void checkMatcher(const RBAC::Matcher& matcher, bool expected, - const Envoy::Network::Connection& connection = Envoy::Network::MockConnection(), - const Envoy::Http::HeaderMap& headers = Envoy::Http::HeaderMapImpl()) { - EXPECT_EQ(expected, matcher.matches(connection, headers)); +void checkMatcher( + const RBAC::Matcher& matcher, bool expected, + const Envoy::Network::Connection& connection = Envoy::Network::MockConnection(), + const Envoy::Http::HeaderMap& headers = Envoy::Http::HeaderMapImpl(), + const envoy::api::v2::core::Metadata& metadata = envoy::api::v2::core::Metadata()) { + EXPECT_EQ(expected, matcher.matches(connection, headers, metadata)); } TEST(AlwaysMatcher, AlwaysMatches) { checkMatcher(RBAC::AlwaysMatcher(), true); } @@ -114,6 +116,20 @@ TEST(OrMatcher, Principal_Set) { checkMatcher(RBAC::OrMatcher(set), true, conn); } +TEST(NotMatcher, Permission) { + envoy::config::rbac::v2alpha::Permission perm; + perm.set_any(true); + + checkMatcher(RBAC::NotMatcher(perm), false, Envoy::Network::MockConnection()); +} + +TEST(NotMatcher, Principal) { + envoy::config::rbac::v2alpha::Principal principal; + principal.set_any(true); + + checkMatcher(RBAC::NotMatcher(principal), false, Envoy::Network::MockConnection()); +} + TEST(HeaderMatcher, HeaderMatcher) { envoy::api::v2::route::HeaderMatcher config; config.set_name("foo"); @@ -210,6 +226,27 @@ TEST(AuthenticatedMatcher, NoSSL) { checkMatcher(AuthenticatedMatcher({}), false, conn); } +TEST(MetadataMatcher, MetadataMatcher) { + Envoy::Network::MockConnection conn; + Envoy::Http::HeaderMapImpl header; + + auto label = MessageUtil::keyValueStruct("label", "prod"); + envoy::api::v2::core::Metadata metadata; + metadata.mutable_filter_metadata()->insert( + Protobuf::MapPair("other", label)); + metadata.mutable_filter_metadata()->insert( + Protobuf::MapPair("rbac", label)); + + envoy::type::matcher::MetadataMatcher matcher; + matcher.set_filter("rbac"); + matcher.add_path()->set_key("label"); + + matcher.mutable_value()->mutable_string_match()->set_exact("test"); + checkMatcher(MetadataMatcher(matcher), false, conn, header, metadata); + matcher.mutable_value()->mutable_string_match()->set_exact("prod"); + checkMatcher(MetadataMatcher(matcher), true, conn, header, metadata); +} + TEST(PolicyMatcher, PolicyMatcher) { envoy::config::rbac::v2alpha::Policy policy; policy.add_permissions()->set_destination_port(123); diff --git a/test/extensions/filters/common/rbac/mocks.h b/test/extensions/filters/common/rbac/mocks.h index e95f63f21d256..a289b6cea8c37 100644 --- a/test/extensions/filters/common/rbac/mocks.h +++ b/test/extensions/filters/common/rbac/mocks.h @@ -15,8 +15,8 @@ class MockEngine : public RoleBasedAccessControlEngineImpl { MockEngine(const envoy::config::rbac::v2alpha::RBAC& rules) : RoleBasedAccessControlEngineImpl(rules){}; - MOCK_CONST_METHOD2(allowed, - bool(const Envoy::Network::Connection&, const Envoy::Http::HeaderMap&)); + MOCK_CONST_METHOD3(allowed, bool(const Envoy::Network::Connection&, const Envoy::Http::HeaderMap&, + const envoy::api::v2::core::Metadata&)); }; } // namespace RBAC diff --git a/test/extensions/filters/http/buffer/buffer_filter_integration_test.cc b/test/extensions/filters/http/buffer/buffer_filter_integration_test.cc index a40ad1732a9e7..c44fb91fadca8 100644 --- a/test/extensions/filters/http/buffer/buffer_filter_integration_test.cc +++ b/test/extensions/filters/http/buffer/buffer_filter_integration_test.cc @@ -56,7 +56,7 @@ TEST_P(BufferIntegrationTest, RouterRequestBufferLimitExceeded) { ConfigHelper::HttpModifierFunction overrideConfig(const std::string& json_config) { ProtobufWkt::Struct pfc; - RELEASE_ASSERT(Protobuf::util::JsonStringToMessage(json_config, &pfc).ok()); + RELEASE_ASSERT(Protobuf::util::JsonStringToMessage(json_config, &pfc).ok(), ""); return [pfc]( diff --git a/test/extensions/filters/http/buffer/buffer_filter_test.cc b/test/extensions/filters/http/buffer/buffer_filter_test.cc index b1749f70cbf44..6a2cf88b6dde2 100644 --- a/test/extensions/filters/http/buffer/buffer_filter_test.cc +++ b/test/extensions/filters/http/buffer/buffer_filter_test.cc @@ -46,10 +46,10 @@ class BufferFilterTest : public testing::Test { void routeLocalConfig(const Router::RouteSpecificFilterConfig* route_settings, const Router::RouteSpecificFilterConfig* vhost_settings) { - ON_CALL(callbacks_.route_->route_entry_, perFilterConfig(HttpFilterNames::get().BUFFER)) + ON_CALL(callbacks_.route_->route_entry_, perFilterConfig(HttpFilterNames::get().Buffer)) .WillByDefault(Return(route_settings)); ON_CALL(callbacks_.route_->route_entry_.virtual_host_, - perFilterConfig(HttpFilterNames::get().BUFFER)) + perFilterConfig(HttpFilterNames::get().Buffer)) .WillByDefault(Return(vhost_settings)); } diff --git a/test/extensions/filters/http/cors/cors_filter_integration_test.cc b/test/extensions/filters/http/cors/cors_filter_integration_test.cc index d35d84d581bca..dac0c8600c019 100644 --- a/test/extensions/filters/http/cors/cors_filter_integration_test.cc +++ b/test/extensions/filters/http/cors/cors_filter_integration_test.cc @@ -46,9 +46,9 @@ class CorsFilterIntegrationTest : public HttpIntegrationTest, cors->add_allow_origin("test-host-2"); cors->set_allow_headers("content-type"); cors->set_allow_methods("POST"); - cors->set_expose_headers("content-type"); cors->set_max_age("100"); } + { auto* route = virtual_host->add_routes(); route->mutable_match()->set_prefix("/cors-credentials-allowed"); @@ -57,6 +57,23 @@ class CorsFilterIntegrationTest : public HttpIntegrationTest, cors->add_allow_origin("test-origin-1"); cors->mutable_allow_credentials()->set_value(true); } + + { + auto* route = virtual_host->add_routes(); + route->mutable_match()->set_prefix("/cors-allow-origin-regex"); + route->mutable_route()->set_cluster("cluster_0"); + auto* cors = route->mutable_route()->mutable_cors(); + cors->add_allow_origin_regex(".*\\.envoyproxy\\.io"); + } + + { + auto* route = virtual_host->add_routes(); + route->mutable_match()->set_prefix("/cors-expose-headers"); + route->mutable_route()->set_cluster("cluster_0"); + auto* cors = route->mutable_route()->mutable_cors(); + cors->add_allow_origin("test-origin-1"); + cors->set_expose_headers("custom-header-1,custom-header-2"); + } }); HttpIntegrationTest::initialize(); } @@ -128,7 +145,6 @@ TEST_P(CorsFilterIntegrationTest, TestRouteConfigSuccess) { {"access-control-allow-origin", "test-origin-1"}, {"access-control-allow-methods", "POST"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "100"}, {"server", "envoy"}, {"content-length", "0"}, @@ -204,4 +220,40 @@ TEST_P(CorsFilterIntegrationTest, TestEncodeHeadersCredentialsAllowed) { {":status", "200"}, }); } + +TEST_P(CorsFilterIntegrationTest, TestAllowedOriginRegex) { + testNormalRequest( + Http::TestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/cors-allow-origin-regex/test"}, + {":scheme", "http"}, + {":authority", "test-host"}, + {"origin", "www.envoyproxy.io"}, + }, + Http::TestHeaderMapImpl{ + {"access-control-allow-origin", "www.envoyproxy.io"}, + {"access-control-allow-credentials", "true"}, + {"server", "envoy"}, + {"content-length", "0"}, + {":status", "200"}, + }); +} + +TEST_P(CorsFilterIntegrationTest, TestExposeHeaders) { + testNormalRequest( + Http::TestHeaderMapImpl{ + {":method", "GET"}, + {":path", "/cors-expose-headers/test"}, + {":scheme", "http"}, + {":authority", "test-host"}, + {"origin", "test-origin-1"}, + }, + Http::TestHeaderMapImpl{ + {"access-control-allow-origin", "test-origin-1"}, + {"access-control-expose-headers", "custom-header-1,custom-header-2"}, + {"server", "envoy"}, + {"content-length", "0"}, + {":status", "200"}, + }); +} } // namespace Envoy diff --git a/test/extensions/filters/http/cors/cors_filter_test.cc b/test/extensions/filters/http/cors/cors_filter_test.cc index faf2dabd89f06..885f8830051be 100644 --- a/test/extensions/filters/http/cors/cors_filter_test.cc +++ b/test/extensions/filters/http/cors/cors_filter_test.cc @@ -164,7 +164,6 @@ TEST_F(CorsFilterTest, OptionsRequestMatchingOriginByWildcard) { {"access-control-allow-origin", "test-host"}, {"access-control-allow-methods", "GET"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "0"}, }; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); @@ -229,7 +228,6 @@ TEST_F(CorsFilterTest, ValidOptionsRequestWithAllowCredentialsTrue) { {"access-control-allow-credentials", "true"}, {"access-control-allow-methods", "GET"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "0"}, }; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); @@ -254,7 +252,6 @@ TEST_F(CorsFilterTest, ValidOptionsRequestWithAllowCredentialsFalse) { {"access-control-allow-origin", "localhost"}, {"access-control-allow-methods", "GET"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "0"}, }; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); @@ -321,6 +318,27 @@ TEST_F(CorsFilterTest, EncodeWithAllowCredentialsTrue) { EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(request_headers_)); } +TEST_F(CorsFilterTest, EncodeWithExposeHeaders) { + Http::TestHeaderMapImpl request_headers{{"origin", "localhost"}}; + cors_policy_->expose_headers_ = "custom-header-1"; + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.decodeTrailers(request_headers_)); + + Http::TestHeaderMapImpl continue_headers{{":status", "100"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, + filter_.encode100ContinueHeaders(continue_headers)); + + Http::TestHeaderMapImpl response_headers{}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.encodeHeaders(response_headers, false)); + EXPECT_EQ("localhost", response_headers.get_("access-control-allow-origin")); + EXPECT_EQ("custom-header-1", response_headers.get_("access-control-expose-headers")); + + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(request_headers_)); +} + TEST_F(CorsFilterTest, EncodeWithAllowCredentialsFalse) { Http::TestHeaderMapImpl request_headers{{"origin", "localhost"}}; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); @@ -437,7 +455,6 @@ TEST_F(CorsFilterTest, NoRouteCorsEntry) { {"access-control-allow-origin", "localhost"}, {"access-control-allow-methods", "GET"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "0"}, }; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); @@ -466,7 +483,6 @@ TEST_F(CorsFilterTest, NoVHostCorsEntry) { {":status", "200"}, {"access-control-allow-origin", "localhost"}, {"access-control-allow-headers", "content-type"}, - {"access-control-expose-headers", "content-type"}, {"access-control-max-age", "0"}, }; EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); @@ -482,6 +498,54 @@ TEST_F(CorsFilterTest, NoVHostCorsEntry) { EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(request_headers_)); } +TEST_F(CorsFilterTest, OptionsRequestMatchingOriginByRegex) { + Http::TestHeaderMapImpl request_headers{{":method", "OPTIONS"}, + {"origin", "www.envoyproxy.io"}, + {"access-control-request-method", "GET"}}; + + Http::TestHeaderMapImpl response_headers{ + {":status", "200"}, + {"access-control-allow-origin", "www.envoyproxy.io"}, + {"access-control-allow-methods", "GET"}, + {"access-control-allow-headers", "content-type"}, + {"access-control-max-age", "0"}, + }; + + cors_policy_->allow_origin_.clear(); + cors_policy_->allow_origin_regex_.push_back(std::regex(".*")); + + EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_.decodeHeaders(request_headers, false)); + EXPECT_EQ(true, IsCorsRequest()); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.decodeTrailers(request_headers_)); + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.encodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(request_headers_)); +} + +TEST_F(CorsFilterTest, OptionsRequestNotMatchingOriginByRegex) { + Http::TestHeaderMapImpl request_headers{{":method", "OPTIONS"}, + {"origin", "www.envoyproxy.com"}, + {"access-control-request-method", "GET"}}; + + cors_policy_->allow_origin_.clear(); + cors_policy_->allow_origin_regex_.push_back(std::regex(".*.envoyproxy.io")); + + EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, false)).Times(0); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + EXPECT_EQ(false, IsCorsRequest()); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.decodeTrailers(request_headers_)); + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.encodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(request_headers_)); +} + } // namespace Cors } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/dynamo/BUILD b/test/extensions/filters/http/dynamo/BUILD index f4a644d0ff404..01797af74456c 100644 --- a/test/extensions/filters/http/dynamo/BUILD +++ b/test/extensions/filters/http/dynamo/BUILD @@ -46,6 +46,7 @@ envoy_extension_cc_test( deps = [ "//source/common/stats:stats_lib", "//source/extensions/filters/http/dynamo:dynamo_utility_lib", + "//test/mocks/stats:stats_mocks", ], ) diff --git a/test/extensions/filters/http/dynamo/dynamo_filter_test.cc b/test/extensions/filters/http/dynamo/dynamo_filter_test.cc index 005822a42ec53..5215f814eee45 100644 --- a/test/extensions/filters/http/dynamo/dynamo_filter_test.cc +++ b/test/extensions/filters/http/dynamo/dynamo_filter_test.cc @@ -44,7 +44,7 @@ class DynamoFilterTest : public testing::Test { std::unique_ptr filter_; NiceMock loader_; std::string stat_prefix_{"prefix."}; - Stats::MockStore stats_; + NiceMock stats_; NiceMock decoder_callbacks_; NiceMock encoder_callbacks_; }; diff --git a/test/extensions/filters/http/dynamo/dynamo_utility_test.cc b/test/extensions/filters/http/dynamo/dynamo_utility_test.cc index 2661e1819a321..0e4b31872cce5 100644 --- a/test/extensions/filters/http/dynamo/dynamo_utility_test.cc +++ b/test/extensions/filters/http/dynamo/dynamo_utility_test.cc @@ -4,9 +4,13 @@ #include "extensions/filters/http/dynamo/dynamo_utility.h" +#include "test/mocks/stats/mocks.h" + #include "gmock/gmock.h" #include "gtest/gtest.h" +using testing::NiceMock; +using testing::Return; using testing::_; namespace Envoy { @@ -15,17 +19,20 @@ namespace HttpFilters { namespace Dynamo { TEST(DynamoUtility, PartitionIdStatString) { + Stats::StatsOptionsImpl stats_options; + stats_options.max_obj_name_length_ = 60; + { std::string stat_prefix = "stat.prefix."; std::string table_name = "locations"; std::string operation = "GetItem"; std::string partition_id = "6235c781-1d0d-47a3-a4ea-eec04c5883ca"; - std::string partition_stat_string = - Utility::buildPartitionStatString(stat_prefix, table_name, operation, partition_id); + std::string partition_stat_string = Utility::buildPartitionStatString( + stat_prefix, table_name, operation, partition_id, stats_options); std::string expected_stat_string = "stat.prefix.table.locations.capacity.GetItem.__partition_id=c5883ca"; EXPECT_EQ(expected_stat_string, partition_stat_string); - EXPECT_TRUE(partition_stat_string.size() <= Stats::RawStatData::maxNameLength()); + EXPECT_TRUE(partition_stat_string.size() <= stats_options.maxNameLength()); } { @@ -34,13 +41,13 @@ TEST(DynamoUtility, PartitionIdStatString) { std::string operation = "GetItem"; std::string partition_id = "6235c781-1d0d-47a3-a4ea-eec04c5883ca"; - std::string partition_stat_string = - Utility::buildPartitionStatString(stat_prefix, table_name, operation, partition_id); + std::string partition_stat_string = Utility::buildPartitionStatString( + stat_prefix, table_name, operation, partition_id, stats_options); std::string expected_stat_string = "http.egress_dynamodb_iad.dynamodb.table.locations-sandbox-" "partition-test-iad-mytest-rea.capacity.GetItem.__partition_" "id=c5883ca"; EXPECT_EQ(expected_stat_string, partition_stat_string); - EXPECT_TRUE(partition_stat_string.size() == Stats::RawStatData::maxNameLength()); + EXPECT_TRUE(partition_stat_string.size() <= stats_options.maxNameLength()); } { std::string stat_prefix = "http.egress_dynamodb_iad.dynamodb."; @@ -48,14 +55,14 @@ TEST(DynamoUtility, PartitionIdStatString) { std::string operation = "GetItem"; std::string partition_id = "6235c781-1d0d-47a3-a4ea-eec04c5883ca"; - std::string partition_stat_string = - Utility::buildPartitionStatString(stat_prefix, table_name, operation, partition_id); + std::string partition_stat_string = Utility::buildPartitionStatString( + stat_prefix, table_name, operation, partition_id, stats_options); std::string expected_stat_string = "http.egress_dynamodb_iad.dynamodb.table.locations-sandbox-" "partition-test-iad-mytest-rea.capacity.GetItem.__partition_" "id=c5883ca"; EXPECT_EQ(expected_stat_string, partition_stat_string); - EXPECT_TRUE(partition_stat_string.size() == Stats::RawStatData::maxNameLength()); + EXPECT_TRUE(partition_stat_string.size() <= stats_options.maxNameLength()); } } diff --git a/test/extensions/filters/http/ext_authz/BUILD b/test/extensions/filters/http/ext_authz/BUILD index 15e1e5e6d8c2b..0692b53be3c3a 100644 --- a/test/extensions/filters/http/ext_authz/BUILD +++ b/test/extensions/filters/http/ext_authz/BUILD @@ -16,6 +16,7 @@ envoy_extension_cc_test( srcs = ["ext_authz_test.cc"], extension_name = "envoy.filters.http.ext_authz", deps = [ + "//include/envoy/http:codes_interface", "//source/common/buffer:buffer_lib", "//source/common/common:empty_string", "//source/common/config:filter_json_lib", @@ -23,7 +24,7 @@ envoy_extension_cc_test( "//source/common/json:json_loader_lib", "//source/common/network:address_lib", "//source/common/protobuf:utility_lib", - "//source/extensions/filters/common/ext_authz:ext_authz_lib", + "//source/extensions/filters/common/ext_authz:ext_authz_grpc_lib", "//source/extensions/filters/http/ext_authz", "//test/extensions/filters/common/ext_authz:ext_authz_mocks", "//test/mocks/http:http_mocks", diff --git a/test/extensions/filters/http/ext_authz/config_test.cc b/test/extensions/filters/http/ext_authz/config_test.cc index 26463bbd945f5..d1b344ad45fe1 100644 --- a/test/extensions/filters/http/ext_authz/config_test.cc +++ b/test/extensions/filters/http/ext_authz/config_test.cc @@ -15,21 +15,24 @@ namespace Extensions { namespace HttpFilters { namespace ExtAuthz { -TEST(HttpExtAuthzConfigTest, ExtAuthzCorrectProto) { +TEST(HttpExtAuthzConfigTest, CorrectProtoGrpc) { std::string yaml = R"EOF( grpc_service: google_grpc: target_uri: ext_authz_server stat_prefix: google failure_mode_allow: false -)EOF"; + )EOF"; ExtAuthzFilterConfig factory; ProtobufTypes::MessagePtr proto_config = factory.createEmptyConfigProto(); MessageUtil::loadFromYaml(yaml, *proto_config); - NiceMock context; - + testing::StrictMock context; + EXPECT_CALL(context, localInfo()).Times(1); + EXPECT_CALL(context, clusterManager()).Times(2); + EXPECT_CALL(context, runtime()).Times(1); + EXPECT_CALL(context, scope()).Times(2); EXPECT_CALL(context.cluster_manager_.async_client_manager_, factoryForGrpcService(_, _, _)) .WillOnce(Invoke([](const envoy::api::v2::core::GrpcService&, Stats::Scope&, bool) { return std::make_unique>(); @@ -40,6 +43,34 @@ TEST(HttpExtAuthzConfigTest, ExtAuthzCorrectProto) { cb(filter_callback); } +TEST(HttpExtAuthzConfigTest, CorrectProtoHttp) { + std::string yaml = R"EOF( + http_service: + server_uri: + uri: "ext_authz:9000" + cluster: "ext_authz" + timeout: 0.25s + path_prefix: "/test" + response_headers_to_remove: + - foo_header_key + - baz_header_key + failure_mode_allow: true + )EOF"; + + ExtAuthzFilterConfig factory; + ProtobufTypes::MessagePtr proto_config = factory.createEmptyConfigProto(); + MessageUtil::loadFromYaml(yaml, *proto_config); + testing::StrictMock context; + EXPECT_CALL(context, localInfo()).Times(1); + EXPECT_CALL(context, clusterManager()).Times(1); + EXPECT_CALL(context, runtime()).Times(1); + EXPECT_CALL(context, scope()).Times(1); + Http::FilterFactoryCb cb = factory.createFilterFactoryFromProto(*proto_config, "stats", context); + testing::StrictMock filter_callback; + EXPECT_CALL(filter_callback, addStreamDecoderFilter(_)); + cb(filter_callback); +} + } // namespace ExtAuthz } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/ext_authz/ext_authz_test.cc b/test/extensions/filters/http/ext_authz/ext_authz_test.cc index 972ae716b25fa..e6aac4d730a19 100644 --- a/test/extensions/filters/http/ext_authz/ext_authz_test.cc +++ b/test/extensions/filters/http/ext_authz/ext_authz_test.cc @@ -2,7 +2,9 @@ #include #include +#include "envoy/config/filter/http/ext_authz/v2alpha/ext_authz.pb.h" #include "envoy/config/filter/http/ext_authz/v2alpha/ext_authz.pb.validate.h" +#include "envoy/http/codes.h" #include "common/buffer/buffer_impl.h" #include "common/common/empty_string.h" @@ -142,6 +144,7 @@ TEST_P(HttpExtAuthzFilterParamTest, OkResponse) { ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(*client_, check(_, _, testing::A())) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { @@ -157,7 +160,10 @@ TEST_P(HttpExtAuthzFilterParamTest, OkResponse) { EXPECT_CALL(filter_callbacks_.request_info_, setResponseFlag(Envoy::RequestInfo::ResponseFlag::UnauthorizedExternalService)) .Times(0); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::OK); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::OK; + request_callbacks_->onComplete(std::make_unique(response)); EXPECT_EQ(1U, cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.ok").value()); @@ -171,10 +177,14 @@ TEST_P(HttpExtAuthzFilterParamTest, ImmediateOkResponse) { ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::OK; + EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::OK); + callbacks.onComplete(std::make_unique(response)); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -186,6 +196,82 @@ TEST_P(HttpExtAuthzFilterParamTest, ImmediateOkResponse) { cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.ok").value()); } +// Test that an synchronous denied response from the authorization service passing additional HTTP +// attributes to the downstream. +TEST_P(HttpExtAuthzFilterParamTest, ImmediateDeniedResponseWithHttpAttributes) { + InSequence s; + + ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; + response.status_code = Http::Code::Unauthorized; + response.headers_to_add = Http::HeaderVector{{Http::LowerCaseString{"foo"}, "bar"}}; + response.body = std::string{"baz"}; + + auto response_ptr = std::make_unique(response); + + EXPECT_CALL(*client_, check(_, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { + callbacks.onComplete(std::move(response_ptr)); + }))); + + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndWatermark, filter_->decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::StopIteration, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.denied").value()); +} + +// Test that an synchronous ok response from the authorization service passing additional HTTP +// attributes to the upstream. +TEST_P(HttpExtAuthzFilterParamTest, ImmediateOkResponseWithHttpAttributes) { + InSequence s; + + // `bar` will be appended to this header. + const Http::LowerCaseString request_header_key{"baz"}; + request_headers_.addCopy(request_header_key, "foo"); + + // `foo` will be added to this key. + const Http::LowerCaseString key_to_add{"bar"}; + + // `foo` will be override with `bar`. + const Http::LowerCaseString key_to_override{"foobar"}; + request_headers_.addCopy("foobar", "foo"); + + ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::OK; + response.headers_to_append = Http::HeaderVector{{request_header_key, "bar"}}; + response.headers_to_add = Http::HeaderVector{{key_to_add, "foo"}, {key_to_override, "bar"}}; + + auto response_ptr = std::make_unique(response); + + EXPECT_CALL(*client_, check(_, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { + callbacks.onComplete(std::move(response_ptr)); + }))); + + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers_, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data_, false)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_->decodeTrailers(request_headers_)); + EXPECT_EQ(request_headers_.get_(request_header_key), "foo,bar"); + EXPECT_EQ(request_headers_.get_(key_to_add), "foo"); + EXPECT_EQ(request_headers_.get_(key_to_override), "bar"); +} + // Test that an synchronous denied response from the authorization service, on the call stack, // results in request not continuing. TEST_P(HttpExtAuthzFilterParamTest, ImmediateDeniedResponse) { @@ -194,10 +280,13 @@ TEST_P(HttpExtAuthzFilterParamTest, ImmediateDeniedResponse) { ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::Denied); + callbacks.onComplete(std::make_unique(response)); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); @@ -211,8 +300,8 @@ TEST_P(HttpExtAuthzFilterParamTest, ImmediateDeniedResponse) { cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.denied").value()); } -// Test that a denied response results in the connection closing with a 403 response to the client. -TEST_P(HttpExtAuthzFilterParamTest, DeniedResponse) { +// Test that a denied response results in the connection closing with a 401 response to the client. +TEST_P(HttpExtAuthzFilterParamTest, DeniedResponseWith401) { InSequence s; ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); @@ -226,12 +315,171 @@ TEST_P(HttpExtAuthzFilterParamTest, DeniedResponse) { EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_->decodeHeaders(request_headers_, false)); + + Http::TestHeaderMapImpl response_headers{{":status", "401"}}; + + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); + + EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); + EXPECT_CALL(filter_callbacks_.request_info_, + setResponseFlag(Envoy::RequestInfo::ResponseFlag::UnauthorizedExternalService)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; + response.status_code = Http::Code::Unauthorized; + request_callbacks_->onComplete(std::make_unique(response)); + + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.denied").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_4xx").value()); +} + +// Test that a denied response results in the connection closing with a 403 response to the client. +TEST_P(HttpExtAuthzFilterParamTest, DeniedResponseWith403) { + InSequence s; + + ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(*client_, check(_, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + Http::TestHeaderMapImpl response_headers{{":status", "403"}}; + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), true)); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); EXPECT_CALL(filter_callbacks_.request_info_, setResponseFlag(Envoy::RequestInfo::ResponseFlag::UnauthorizedExternalService)); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Denied); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; + response.status_code = Http::Code::Forbidden; + request_callbacks_->onComplete(std::make_unique(response)); + + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.denied").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_4xx").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_403").value()); +} + +// Verify that authz response memory is not used after free. +TEST_P(HttpExtAuthzFilterParamTest, DestroyResponseBeforeSendLocalReply) { + InSequence s; + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; + response.status_code = Http::Code::Forbidden; + response.body = std::string{"foo"}; + response.headers_to_add = Http::HeaderVector{{Http::LowerCaseString{"foo"}, "bar"}, + {Http::LowerCaseString{"bar"}, "foo"}}; + Filters::Common::ExtAuthz::ResponsePtr response_ptr = + std::make_unique(response); + + ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(*client_, check(_, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + + Http::TestHeaderMapImpl response_headers{{":status", "403"}, + {"content-length", "3"}, + {"content-type", "text/plain"}, + {"foo", "bar"}, + {"bar", "foo"}}; + + Http::HeaderMap* saved_headers; + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), false)) + .WillOnce(Invoke([&](Http::HeaderMap& headers, bool) { saved_headers = &headers; })); + + EXPECT_CALL(filter_callbacks_, encodeData(_, true)) + .WillOnce(Invoke([&](Buffer::Instance& data, bool) { + response_ptr.reset(); + Http::TestHeaderMapImpl test_headers{*saved_headers}; + EXPECT_EQ(test_headers.get_("foo"), "bar"); + EXPECT_EQ(test_headers.get_("bar"), "foo"); + EXPECT_EQ(data.toString(), "foo"); + })); + + request_callbacks_->onComplete(std::move(response_ptr)); + + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("ext_authz.denied").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_4xx").value()); + EXPECT_EQ( + 1U, + cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_403").value()); +} + +// Verify that authz denied response headers overrides the existing encoding headers. +TEST_P(HttpExtAuthzFilterParamTest, OverrideEncodingHeaders) { + InSequence s; + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Denied; + response.status_code = Http::Code::Forbidden; + response.body = std::string{"foo"}; + response.headers_to_add = Http::HeaderVector{{Http::LowerCaseString{"foo"}, "bar"}, + {Http::LowerCaseString{"bar"}, "foo"}}; + Filters::Common::ExtAuthz::ResponsePtr response_ptr = + std::make_unique(response); + + ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); + EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + EXPECT_CALL(*client_, check(_, _, _)) + .WillOnce( + WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { + request_callbacks_ = &callbacks; + }))); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers_, false)); + + Http::TestHeaderMapImpl response_headers{{":status", "403"}, + {"content-length", "3"}, + {"content-type", "text/plain"}, + {"foo", "bar"}, + {"bar", "foo"}}; + + Http::HeaderMap* saved_headers; + EXPECT_CALL(filter_callbacks_, encodeHeaders_(HeaderMapEqualRef(&response_headers), false)) + .WillOnce(Invoke([&](Http::HeaderMap& headers, bool) { + headers.addCopy(Http::LowerCaseString{"foo"}, std::string{"OVERRIDE_WITH_bar"}); + headers.addCopy(Http::LowerCaseString{"foobar"}, std::string{"DO_NOT_OVERRIDE"}); + saved_headers = &headers; + })); + + EXPECT_CALL(filter_callbacks_, encodeData(_, true)) + .WillOnce(Invoke([&](Buffer::Instance& data, bool) { + response_ptr.reset(); + Http::TestHeaderMapImpl test_headers{*saved_headers}; + EXPECT_EQ(test_headers.get_("foo"), "bar"); + EXPECT_EQ(test_headers.get_("bar"), "foo"); + EXPECT_EQ(test_headers.get_("foobar"), "DO_NOT_OVERRIDE"); + EXPECT_EQ(data.toString(), "foo"); + })); + + request_callbacks_->onComplete(std::move(response_ptr)); EXPECT_EQ( 1U, @@ -244,8 +492,8 @@ TEST_P(HttpExtAuthzFilterParamTest, DeniedResponse) { cm_.thread_local_cluster_.cluster_.info_->stats_store_.counter("upstream_rq_403").value()); } -// Test that when a connection awaiting a authorization response is canceled then the authorization -// call is closed. +// Test that when a connection awaiting a authorization response is canceled then the +// authorization call is closed. TEST_P(HttpExtAuthzFilterParamTest, ResetDuringCall) { InSequence s; @@ -300,11 +548,13 @@ TEST_F(HttpExtAuthzFilterTest, ErrorFailClose) { WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { request_callbacks_ = &callbacks; }))); - EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_->decodeHeaders(request_headers_, false)); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Error; + request_callbacks_->onComplete(std::make_unique(response)); EXPECT_EQ( 1U, @@ -325,11 +575,13 @@ TEST_F(HttpExtAuthzFilterTest, ErrorOpen) { WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { request_callbacks_ = &callbacks; }))); - EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, filter_->decodeHeaders(request_headers_, false)); EXPECT_CALL(filter_callbacks_, continueDecoding()); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Error; + request_callbacks_->onComplete(std::make_unique(response)); EXPECT_EQ( 1U, @@ -345,10 +597,13 @@ TEST_F(HttpExtAuthzFilterTest, ImmediateErrorOpen) { ON_CALL(filter_callbacks_, connection()).WillByDefault(Return(&connection_)); EXPECT_CALL(connection_, remoteAddress()).WillOnce(ReturnRef(addr_)); EXPECT_CALL(connection_, localAddress()).WillOnce(ReturnRef(addr_)); + + Filters::Common::ExtAuthz::Response response{}; + response.status = Filters::Common::ExtAuthz::CheckStatus::Error; EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + callbacks.onComplete(std::make_unique(response)); }))); EXPECT_CALL(filter_callbacks_, continueDecoding()).Times(0); diff --git a/test/extensions/filters/http/fault/fault_filter_test.cc b/test/extensions/filters/http/fault/fault_filter_test.cc index c231a6796bd11..420c065208b56 100644 --- a/test/extensions/filters/http/fault/fault_filter_test.cc +++ b/test/extensions/filters/http/fault/fault_filter_test.cc @@ -756,10 +756,10 @@ void FaultFilterTest::TestPerFilterConfigFault( const Router::RouteSpecificFilterConfig* vhost_fault) { ON_CALL(filter_callbacks_.route_->route_entry_, - perFilterConfig(Extensions::HttpFilters::HttpFilterNames::get().FAULT)) + perFilterConfig(Extensions::HttpFilters::HttpFilterNames::get().Fault)) .WillByDefault(Return(route_fault)); ON_CALL(filter_callbacks_.route_->route_entry_.virtual_host_, - perFilterConfig(Extensions::HttpFilters::HttpFilterNames::get().FAULT)) + perFilterConfig(Extensions::HttpFilters::HttpFilterNames::get().Fault)) .WillByDefault(Return(vhost_fault)); const std::string upstream_cluster("www1"); diff --git a/test/extensions/filters/http/grpc_json_transcoder/grpc_json_transcoder_integration_test.cc b/test/extensions/filters/http/grpc_json_transcoder/grpc_json_transcoder_integration_test.cc index 7a9d898777287..a4bc68362c246 100644 --- a/test/extensions/filters/http/grpc_json_transcoder/grpc_json_transcoder_integration_test.cc +++ b/test/extensions/filters/http/grpc_json_transcoder/grpc_json_transcoder_integration_test.cc @@ -71,10 +71,10 @@ class GrpcJsonTranscoderIntegrationTest response = codec_client_->makeHeaderOnlyRequest(request_headers); } - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); if (!grpc_request_messages.empty()) { - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); Grpc::Decoder grpc_decoder; std::vector frames; @@ -84,8 +84,7 @@ class GrpcJsonTranscoderIntegrationTest for (size_t i = 0; i < grpc_request_messages.size(); ++i) { RequestType actual_message; if (frames[i].length_ > 0) { - EXPECT_TRUE( - actual_message.ParseFromString(TestUtility::bufferToString(*frames[i].data_))); + EXPECT_TRUE(actual_message.ParseFromString(frames[i].data_->toString())); } RequestType expected_message; EXPECT_TRUE(TextFormat::ParseFromString(grpc_request_messages[i], &expected_message)); @@ -115,7 +114,7 @@ class GrpcJsonTranscoderIntegrationTest } EXPECT_TRUE(upstream_request_->complete()); } else { - upstream_request_->waitForReset(); + ASSERT_TRUE(upstream_request_->waitForReset()); } response->waitForEndStream(); @@ -137,8 +136,8 @@ class GrpcJsonTranscoderIntegrationTest } codec_client_->close(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } }; @@ -173,6 +172,17 @@ TEST_P(GrpcJsonTranscoderIntegrationTest, UnaryGet) { R"({"shelves":[{"id":"20","theme":"Children"},{"id":"1","theme":"Foo"}]})"); } +TEST_P(GrpcJsonTranscoderIntegrationTest, UnaryGetHttpBody) { + testTranscoding( + Http::TestHeaderMapImpl{{":method", "GET"}, {":path", "/index"}, {":authority", "host"}}, "", + {""}, {R"(content_type: "text/html" data: "

Hello!

" )"}, Status(), + Http::TestHeaderMapImpl{{":status", "200"}, + {"content-type", "text/html"}, + {"content-length", "15"}, + {"grpc-status", "0"}}, + R"(

Hello!

)"); +} + TEST_P(GrpcJsonTranscoderIntegrationTest, UnaryGetError) { testTranscoding( Http::TestHeaderMapImpl{ diff --git a/test/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter_test.cc b/test/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter_test.cc index 10a4deeb9fa8a..f388588a085e4 100644 --- a/test/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter_test.cc +++ b/test/extensions/filters/http/grpc_json_transcoder/json_transcoder_filter_test.cc @@ -97,7 +97,7 @@ class GrpcJsonTranscoderConfigTest : public testing::Test { 0 == file.name().compare(file.name().length() - file_name.length(), ProtobufTypes::String::npos, file_name); }); - RELEASE_ASSERT(file_itr != descriptor_set.file().end()); + RELEASE_ASSERT(file_itr != descriptor_set.file().end(), ""); file_descriptor = *file_itr; descriptor_set.clear_file(); @@ -302,7 +302,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryPost) { expected_request.mutable_shelf()->set_theme("Children"); bookstore::CreateShelfRequest request; - request.ParseFromString(TestUtility::bufferToString(*frames[0].data_)); + request.ParseFromString(frames[0].data_->toString()); EXPECT_EQ(expected_request.ByteSize(), frames[0].length_); EXPECT_TRUE(MessageDifferencer::Equals(expected_request, request)); @@ -327,7 +327,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryPost) { EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_.encodeData(*response_data, false)); - std::string response_json = TestUtility::bufferToString(*response_data); + std::string response_json = response_data->toString(); EXPECT_EQ("{\"id\":\"20\",\"theme\":\"Children\"}", response_json); @@ -366,7 +366,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryPostWithPackageServiceMetho expected_request.mutable_shelf()->set_theme("Children"); bookstore::CreateShelfRequest request; - request.ParseFromString(TestUtility::bufferToString(*frames[0].data_)); + request.ParseFromString(frames[0].data_->toString()); EXPECT_EQ(expected_request.ByteSize(), frames[0].length_); EXPECT_TRUE(MessageDifferencer::Equals(expected_request, request)); @@ -391,7 +391,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryPostWithPackageServiceMetho EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_.encodeData(*response_data, false)); - std::string response_json = TestUtility::bufferToString(*response_data); + std::string response_json = response_data->toString(); EXPECT_EQ("{\"id\":\"20\",\"theme\":\"Children\"}", response_json); @@ -427,7 +427,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, ForwardUnaryPostGrpc) { expected_request.mutable_shelf()->set_theme("Children"); bookstore::CreateShelfRequest forwarded_request; - forwarded_request.ParseFromString(TestUtility::bufferToString(*frames[0].data_)); + forwarded_request.ParseFromString(frames[0].data_->toString()); EXPECT_EQ(expected_request.ByteSize(), frames[0].length_); EXPECT_TRUE(MessageDifferencer::Equals(expected_request, forwarded_request)); @@ -458,7 +458,7 @@ TEST_F(GrpcJsonTranscoderFilterTest, ForwardUnaryPostGrpc) { EXPECT_EQ(1, frames.size()); bookstore::Shelf forwarded_response; - forwarded_response.ParseFromString(TestUtility::bufferToString(*frames[0].data_)); + forwarded_response.ParseFromString(frames[0].data_->toString()); EXPECT_EQ(expected_response.ByteSize(), frames[0].length_); EXPECT_TRUE(MessageDifferencer::Equals(expected_response, forwarded_response)); @@ -550,6 +550,84 @@ TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryNotGrpcResponse) { EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(request_data, true)); } +TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryWithHttpBodyAsOutput) { + Http::TestHeaderMapImpl request_headers{{":method", "GET"}, {":path", "/index"}}; + + EXPECT_CALL(decoder_callbacks_, clearRouteCache()); + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + EXPECT_EQ("application/grpc", request_headers.get_("content-type")); + EXPECT_EQ("/index", request_headers.get_("x-envoy-original-path")); + EXPECT_EQ("/bookstore.Bookstore/GetIndex", request_headers.get_(":path")); + EXPECT_EQ("trailers", request_headers.get_("te")); + + Http::TestHeaderMapImpl response_headers{{"content-type", "application/grpc"}, + {":status", "200"}}; + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_.encodeHeaders(response_headers, false)); + EXPECT_EQ("application/json", response_headers.get_("content-type")); + + google::api::HttpBody response; + response.set_content_type("text/html"); + response.set_data("

Hello, world!

"); + + auto response_data = Grpc::Common::serializeBody(response); + + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, + filter_.encodeData(*response_data, false)); + + EXPECT_EQ(response.content_type(), response_headers.get_("content-type")); + EXPECT_EQ(response.data(), response_data->toString()); + + Http::TestHeaderMapImpl response_trailers{{"grpc-status", "0"}, {"grpc-message", ""}}; + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.decodeTrailers(response_trailers)); +} + +TEST_F(GrpcJsonTranscoderFilterTest, TranscodingUnaryWithHttpBodyAsOutputAndSplitTwoEncodeData) { + Http::TestHeaderMapImpl request_headers{{":method", "GET"}, {":path", "/index"}}; + + EXPECT_CALL(decoder_callbacks_, clearRouteCache()); + + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(request_headers, false)); + EXPECT_EQ("application/grpc", request_headers.get_("content-type")); + EXPECT_EQ("/index", request_headers.get_("x-envoy-original-path")); + EXPECT_EQ("/bookstore.Bookstore/GetIndex", request_headers.get_(":path")); + EXPECT_EQ("trailers", request_headers.get_("te")); + + Http::TestHeaderMapImpl response_headers{{"content-type", "application/grpc"}, + {":status", "200"}}; + + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_.encodeHeaders(response_headers, false)); + EXPECT_EQ("application/json", response_headers.get_("content-type")); + + google::api::HttpBody response; + response.set_content_type("text/html"); + response.set_data("

Hello, world!

"); + + auto response_data = Grpc::Common::serializeBody(response); + + // Firstly, the response data buffer is splitted into two parts. + Buffer::OwnedImpl response_data_first_part; + response_data_first_part.move(*response_data, response_data->length() / 2); + + // Secondly, we send the first part of response data to the data encoding step. + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, + filter_.encodeData(response_data_first_part, false)); + + // Finaly, since half of the response data buffer is moved already, here we can send the rest + // of it to the next data encoding step. + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, + filter_.encodeData(*response_data, false)); + + EXPECT_EQ(response.content_type(), response_headers.get_("content-type")); + EXPECT_EQ(response.data(), response_data->toString()); + + Http::TestHeaderMapImpl response_trailers{{"grpc-status", "0"}, {"grpc-message", ""}}; + EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.decodeTrailers(response_trailers)); +} + struct GrpcJsonTranscoderFilterPrintTestParam { std::string config_json_; std::string expected_response_; @@ -594,7 +672,7 @@ TEST_P(GrpcJsonTranscoderFilterPrintTest, PrintOptions) { EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_->encodeData(*response_data, false)); - std::string response_json = TestUtility::bufferToString(*response_data); + std::string response_json = response_data->toString(); EXPECT_EQ(GetParam().expected_response_, response_json); } diff --git a/test/extensions/filters/http/grpc_web/grpc_web_filter_test.cc b/test/extensions/filters/http/grpc_web/grpc_web_filter_test.cc index 6d4b7ba7ffe34..c36eb271448f2 100644 --- a/test/extensions/filters/http/grpc_web/grpc_web_filter_test.cc +++ b/test/extensions/filters/http/grpc_web/grpc_web_filter_test.cc @@ -88,9 +88,8 @@ class GrpcWebFilterTest : public testing::TestWithParam(expected_code), code); })); EXPECT_CALL(decoder_callbacks_, encodeData(_, _)) - .WillOnce(Invoke([=](Buffer::Instance& data, bool) { - EXPECT_EQ(expected_message, TestUtility::bufferToString(data)); - })); + .WillOnce(Invoke( + [=](Buffer::Instance& data, bool) { EXPECT_EQ(expected_message, data.toString()); })); } void expectRequiredGrpcUpstreamHeaders(const Http::HeaderMap& request_headers) { @@ -247,7 +246,7 @@ TEST_P(GrpcWebFilterTest, Unary) { EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.decodeData(request_buffer, true)); decoded_buffer.move(request_buffer); } - EXPECT_EQ(std::string(MESSAGE, MESSAGE_SIZE), TestUtility::bufferToString(decoded_buffer)); + EXPECT_EQ(std::string(MESSAGE, MESSAGE_SIZE), decoded_buffer.toString()); } else if (isTextRequest()) { Buffer::OwnedImpl request_buffer; Buffer::OwnedImpl decoded_buffer; @@ -266,8 +265,7 @@ TEST_P(GrpcWebFilterTest, Unary) { } decoded_buffer.move(request_buffer); } - EXPECT_EQ(std::string(TEXT_MESSAGE, TEXT_MESSAGE_SIZE), - TestUtility::bufferToString(decoded_buffer)); + EXPECT_EQ(std::string(TEXT_MESSAGE, TEXT_MESSAGE_SIZE), decoded_buffer.toString()); } else { FAIL() << "Unsupported gRPC-Web request content-type: " << request_content_type(); } @@ -304,7 +302,7 @@ TEST_P(GrpcWebFilterTest, Unary) { EXPECT_EQ(Http::FilterDataStatus::Continue, filter_.encodeData(response_buffer, false)); encoded_buffer.move(response_buffer); } - EXPECT_EQ(std::string(MESSAGE, MESSAGE_SIZE), TestUtility::bufferToString(encoded_buffer)); + EXPECT_EQ(std::string(MESSAGE, MESSAGE_SIZE), encoded_buffer.toString()); } else if (accept_text_response()) { Buffer::OwnedImpl response_buffer; Buffer::OwnedImpl encoded_buffer; @@ -318,8 +316,7 @@ TEST_P(GrpcWebFilterTest, Unary) { } encoded_buffer.move(response_buffer); } - EXPECT_EQ(std::string(B64_MESSAGE, B64_MESSAGE_SIZE), - TestUtility::bufferToString(encoded_buffer)); + EXPECT_EQ(std::string(B64_MESSAGE, B64_MESSAGE_SIZE), encoded_buffer.toString()); } else { FAIL() << "Unsupported gRPC-Web response content-type: " << response_headers.ContentType()->value().c_str(); @@ -334,10 +331,9 @@ TEST_P(GrpcWebFilterTest, Unary) { response_trailers.addCopy(Http::Headers::get().GrpcMessage, "ok"); EXPECT_EQ(Http::FilterTrailersStatus::Continue, filter_.encodeTrailers(response_trailers)); if (accept_binary_response()) { - EXPECT_EQ(std::string(TRAILERS, TRAILERS_SIZE), TestUtility::bufferToString(trailers_buffer)); + EXPECT_EQ(std::string(TRAILERS, TRAILERS_SIZE), trailers_buffer.toString()); } else if (accept_text_response()) { - EXPECT_EQ(std::string(TRAILERS, TRAILERS_SIZE), - Base64::decode(TestUtility::bufferToString(trailers_buffer))); + EXPECT_EQ(std::string(TRAILERS, TRAILERS_SIZE), Base64::decode(trailers_buffer.toString())); } else { FAIL() << "Unsupported gRPC-Web response content-type: " << response_headers.ContentType()->value().c_str(); diff --git a/test/extensions/filters/http/gzip/gzip_filter_test.cc b/test/extensions/filters/http/gzip/gzip_filter_test.cc index e7987c13ed70b..170aa143f790d 100644 --- a/test/extensions/filters/http/gzip/gzip_filter_test.cc +++ b/test/extensions/filters/http/gzip/gzip_filter_test.cc @@ -68,7 +68,7 @@ class GzipFilterTest : public testing::Test { void verifyCompressedData() { decompressor_.decompress(data_, decompressed_data_); - const std::string uncompressed_str{TestUtility::bufferToString(decompressed_data_)}; + const std::string uncompressed_str{decompressed_data_.toString()}; ASSERT_EQ(expected_str_.length(), uncompressed_str.length()); EXPECT_EQ(expected_str_, uncompressed_str); EXPECT_EQ(expected_str_.length(), stats_.counter("test.gzip.total_uncompressed_bytes").value()); @@ -77,7 +77,7 @@ class GzipFilterTest : public testing::Test { void feedBuffer(uint64_t size) { TestUtility::feedBufferWithRandomCharacters(data_, size); - expected_str_ += TestUtility::bufferToString(data_); + expected_str_ += data_.toString(); } void drainBuffer() { diff --git a/test/extensions/filters/http/health_check/config_test.cc b/test/extensions/filters/http/health_check/config_test.cc index 3ec2e405ef204..a966cae896c85 100644 --- a/test/extensions/filters/http/health_check/config_test.cc +++ b/test/extensions/filters/http/health_check/config_test.cc @@ -71,8 +71,10 @@ TEST(HealthCheckFilterConfig, FailsWhenNotPassThroughButTimeoutSetProto) { NiceMock context; config.mutable_pass_through_mode()->set_value(false); - config.set_endpoint("foo"); config.mutable_cache_time()->set_seconds(10); + envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); + header.set_name(":path"); + header.set_exact_match("foo"); EXPECT_THROW( healthCheckFilterConfig.createFilterFactoryFromProto(config, "dummy_stats_prefix", context), @@ -85,7 +87,9 @@ TEST(HealthCheckFilterConfig, NotFailingWhenNotPassThroughAndTimeoutNotSetProto) NiceMock context; config.mutable_pass_through_mode()->set_value(false); - config.set_endpoint("foo"); + envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); + header.set_name(":path"); + header.set_exact_match("foo"); healthCheckFilterConfig.createFilterFactoryFromProto(config, "dummy_stats_prefix", context); } @@ -97,7 +101,9 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterWithEmptyProto) { healthCheckFilterConfig.createEmptyConfigProto().get()); config.mutable_pass_through_mode()->set_value(false); - config.set_endpoint("foo"); + envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); + header.set_name(":path"); + header.set_exact_match("foo"); healthCheckFilterConfig.createFilterFactoryFromProto(config, "dummy_stats_prefix", context); } @@ -152,7 +158,7 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterHeaderMatch) { envoy::api::v2::route::HeaderMatcher& yheader = *config.add_headers(); yheader.set_name("y-healthcheck"); - yheader.set_value("foo"); + yheader.set_exact_match("foo"); Http::TestHeaderMapImpl headers{{"x-healthcheck", "arbitrary_value"}, {"y-healthcheck", "foo"}}; @@ -170,7 +176,7 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterHeaderMatchWrongValue) { envoy::api::v2::route::HeaderMatcher& yheader = *config.add_headers(); yheader.set_name("y-healthcheck"); - yheader.set_value("foo"); + yheader.set_exact_match("foo"); Http::TestHeaderMapImpl headers{{"x-healthcheck", "arbitrary_value"}, {"y-healthcheck", "bar"}}; @@ -188,49 +194,13 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterHeaderMatchMissingHeader) { envoy::api::v2::route::HeaderMatcher& yheader = *config.add_headers(); yheader.set_name("y-healthcheck"); - yheader.set_value("foo"); + yheader.set_exact_match("foo"); Http::TestHeaderMapImpl headers{{"y-healthcheck", "foo"}}; testHealthCheckHeaderMatch(config, headers, false); } -// If an endpoint is specified and the path matches, it should match regardless of any :path -// conditions given in the headers field. -TEST(HealthCheckFilterConfig, HealthCheckFilterEndpoint) { - envoy::config::filter::http::health_check::v2::HealthCheck config; - - config.mutable_pass_through_mode()->set_value(false); - - config.set_endpoint("foo"); - - envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); - header.set_name(Http::Headers::get().Path.get()); - header.set_value("bar"); - - Http::TestHeaderMapImpl headers{{Http::Headers::get().Path.get(), "foo"}}; - - testHealthCheckHeaderMatch(config, headers, true); -} - -// If an endpoint is specified and the path does not match, the filter should not match regardless -// of any :path conditions given in the headers field. -TEST(HealthCheckFilterConfig, HealthCheckFilterEndpointOverride) { - envoy::config::filter::http::health_check::v2::HealthCheck config; - - config.mutable_pass_through_mode()->set_value(false); - - config.set_endpoint("foo"); - - envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); - header.set_name(Http::Headers::get().Path.get()); - header.set_value("bar"); - - Http::TestHeaderMapImpl headers{{Http::Headers::get().Path.get(), "bar"}}; - - testHealthCheckHeaderMatch(config, headers, false); -} - // Conditions for the same header should match if they are both satisfied. TEST(HealthCheckFilterConfig, HealthCheckFilterDuplicateMatch) { envoy::config::filter::http::health_check::v2::HealthCheck config; @@ -239,7 +209,7 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterDuplicateMatch) { envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); header.set_name("x-healthcheck"); - header.set_value("foo"); + header.set_exact_match("foo"); envoy::api::v2::route::HeaderMatcher& dup_header = *config.add_headers(); dup_header.set_name("x-healthcheck"); @@ -257,11 +227,11 @@ TEST(HealthCheckFilterConfig, HealthCheckFilterDuplicateNoMatch) { envoy::api::v2::route::HeaderMatcher& header = *config.add_headers(); header.set_name("x-healthcheck"); - header.set_value("foo"); + header.set_exact_match("foo"); envoy::api::v2::route::HeaderMatcher& dup_header = *config.add_headers(); dup_header.set_name("x-healthcheck"); - dup_header.set_value("bar"); + dup_header.set_exact_match("bar"); Http::TestHeaderMapImpl headers{{"x-healthcheck", "foo"}}; diff --git a/test/extensions/filters/http/jwt_authn/filter_integration_test.cc b/test/extensions/filters/http/jwt_authn/filter_integration_test.cc index b6a2870014c20..eb3f8803de900 100644 --- a/test/extensions/filters/http/jwt_authn/filter_integration_test.cc +++ b/test/extensions/filters/http/jwt_authn/filter_integration_test.cc @@ -26,7 +26,7 @@ std::string getFilterConfig(bool use_local_jwks) { } HttpFilter filter; - filter.set_name(HttpFilterNames::get().JWT_AUTHN); + filter.set_name(HttpFilterNames::get().JwtAuthn); MessageUtil::jsonConvert(proto_config, *filter.mutable_config()); return MessageUtil::getJsonStringFromMessage(filter); } @@ -109,9 +109,13 @@ class RemoteJwksIntegrationTest : public HttpProtocolIntegrationTest { } void waitForJwksResponse(const std::string& status, const std::string& jwks_body) { - fake_jwks_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - jwks_request_ = fake_jwks_connection_->waitForNewStream(*dispatcher_); - jwks_request_->waitForEndStream(*dispatcher_); + AssertionResult result = + fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_jwks_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_jwks_connection_->waitForNewStream(*dispatcher_, jwks_request_); + RELEASE_ASSERT(result, result.message()); + result = jwks_request_->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); Http::TestHeaderMapImpl response_headers{{":status", status}}; jwks_request_->encodeHeaders(response_headers, false); @@ -122,12 +126,16 @@ class RemoteJwksIntegrationTest : public HttpProtocolIntegrationTest { void cleanup() { codec_client_->close(); if (fake_jwks_connection_ != nullptr) { - fake_jwks_connection_->close(); - fake_jwks_connection_->waitForDisconnect(); + AssertionResult result = fake_jwks_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_jwks_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } if (fake_upstream_connection_ != nullptr) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + AssertionResult result = fake_upstream_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } diff --git a/test/extensions/filters/http/lua/BUILD b/test/extensions/filters/http/lua/BUILD index 39eceafebf201..5f6c20712a8c2 100644 --- a/test/extensions/filters/http/lua/BUILD +++ b/test/extensions/filters/http/lua/BUILD @@ -16,8 +16,11 @@ envoy_extension_cc_test( srcs = ["lua_filter_test.cc"], extension_name = "envoy.filters.http.lua", deps = [ + "//source/common/request_info:request_info_lib", "//source/extensions/filters/http/lua:lua_filter_lib", "//test/mocks/http:http_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/ssl:ssl_mocks", "//test/mocks/thread_local:thread_local_mocks", "//test/mocks/upstream:upstream_mocks", "//test/test_common:utility_lib", @@ -29,8 +32,10 @@ envoy_extension_cc_test( srcs = ["wrappers_test.cc"], extension_name = "envoy.filters.http.lua", deps = [ + "//source/common/request_info:request_info_lib", "//source/extensions/filters/http/lua:wrappers_lib", "//test/extensions/filters/common/lua:lua_wrappers_lib", + "//test/mocks/request_info:request_info_mocks", "//test/test_common:utility_lib", ], ) diff --git a/test/extensions/filters/http/lua/lua_filter_test.cc b/test/extensions/filters/http/lua/lua_filter_test.cc index 2562c2fa9fe50..0283baa92e33e 100644 --- a/test/extensions/filters/http/lua/lua_filter_test.cc +++ b/test/extensions/filters/http/lua/lua_filter_test.cc @@ -1,9 +1,12 @@ #include "common/buffer/buffer_impl.h" #include "common/http/message_impl.h" +#include "common/request_info/request_info_impl.h" #include "extensions/filters/http/lua/lua_filter.h" #include "test/mocks/http/mocks.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/ssl/mocks.h" #include "test/mocks/thread_local/mocks.h" #include "test/mocks/upstream/mocks.h" #include "test/test_common/printers.h" @@ -15,6 +18,8 @@ using testing::AtLeast; using testing::InSequence; using testing::Invoke; using testing::Return; +using testing::ReturnPointee; +using testing::ReturnRef; using testing::StrEq; using testing::_; @@ -61,11 +66,20 @@ class LuaHttpFilterTest : public testing::Test { void setup(const std::string& lua_code) { config_.reset(new FilterConfig(lua_code, tls_, cluster_manager_)); + setupFilter(); + } + + void setupFilter() { filter_.reset(new TestFilter(config_)); filter_->setDecoderFilterCallbacks(decoder_callbacks_); filter_->setEncoderFilterCallbacks(encoder_callbacks_); } + void setupSecureConnection(const bool secure) { + EXPECT_CALL(decoder_callbacks_, connection()).WillOnce(Return(&connection_)); + EXPECT_CALL(Const(connection_), ssl()).Times(1).WillOnce(Return(secure ? &ssl_ : nullptr)); + } + void setupMetadata(const std::string& yaml) { MessageUtil::loadFromYaml(yaml, metadata_); EXPECT_CALL(decoder_callbacks_.route_->route_entry_, metadata()) @@ -79,6 +93,9 @@ class LuaHttpFilterTest : public testing::Test { Http::MockStreamDecoderFilterCallbacks decoder_callbacks_; Http::MockStreamEncoderFilterCallbacks encoder_callbacks_; envoy::api::v2::core::Metadata metadata_; + NiceMock ssl_; + NiceMock connection_; + NiceMock request_info_; const std::string HEADER_ONLY_SCRIPT{R"EOF( function envoy_on_request(request_handle) @@ -1211,6 +1228,8 @@ TEST_F(LuaHttpFilterTest, HttpCallInvalidHeaders) { } // Respond right away. +// This is also a regression test for https://github.com/envoyproxy/envoy/issues/3570 which runs +// the request flow 2000 times and does a GC at the end to make sure we don't leak memory. TEST_F(LuaHttpFilterTest, ImmediateResponse) { const std::string SCRIPT{R"EOF( function envoy_on_request(request_handle) @@ -1227,12 +1246,31 @@ TEST_F(LuaHttpFilterTest, ImmediateResponse) { InSequence s; setup(SCRIPT); - Http::TestHeaderMapImpl request_headers{{":path", "/"}}; - Http::TestHeaderMapImpl expected_headers{{":status", "503"}, {"content-length", "4"}}; - EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), false)); - EXPECT_CALL(decoder_callbacks_, encodeData(_, true)); - EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, - filter_->decodeHeaders(request_headers, false)); + // Perform a GC and snap bytes currently used by the runtime. + config_->runtimeGC(); + const uint64_t mem_use_at_start = config_->runtimeBytesUsed(); + + for (uint64_t i = 0; i < 2000; i++) { + Http::TestHeaderMapImpl request_headers{{":path", "/"}}; + Http::TestHeaderMapImpl expected_headers{{":status", "503"}, {"content-length", "4"}}; + EXPECT_CALL(decoder_callbacks_, encodeHeaders_(HeaderMapEqualRef(&expected_headers), false)); + EXPECT_CALL(decoder_callbacks_, encodeData(_, true)); + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers, false)); + filter_->onDestroy(); + setupFilter(); + } + + // Perform GC and compare bytes currently used by the runtime to the original value. + // NOTE: This value is not the same as the original value for reasons that I do not fully + // understand. Depending on the number of requests tested, it increases incrementally, but + // then goes down again at a certain point. There must be some type of interpreter caching + // going on because I'm pretty certain this is not another leak. Because of this, we need + // to do a soft comparison here. In my own testing, without a fix for #3570, the memory + // usage after is at least 20x higher after 2000 iterations so we just check to see if it's + // within 2x. + config_->runtimeGC(); + EXPECT_TRUE(config_->runtimeBytesUsed() < mem_use_at_start * 2); } // Respond with bad status. @@ -1460,6 +1498,78 @@ TEST_F(LuaHttpFilterTest, GetMetadataFromHandleNoLuaMetadata) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); } +// Get the current protocol. +TEST_F(LuaHttpFilterTest, GetCurrentProtocol) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:logTrace(request_handle:requestInfo():protocol()) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + EXPECT_CALL(decoder_callbacks_, requestInfo()).WillOnce(ReturnRef(request_info_)); + EXPECT_CALL(request_info_, protocol()).WillOnce(Return(Http::Protocol::Http11)); + + Http::TestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("HTTP/1.1"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); +} + +// Set and get request info dynamic metadata. +TEST_F(LuaHttpFilterTest, SetGetDynamicMetadata) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:requestInfo():dynamicMetadata():set("envoy.lb", "foo", "bar") + request_handle:logTrace(request_handle:requestInfo():dynamicMetadata():get("envoy.lb")["foo"]) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestHeaderMapImpl request_headers{{":path", "/"}}; + RequestInfo::RequestInfoImpl request_info(Http::Protocol::Http2); + EXPECT_EQ(0, request_info.dynamicMetadata().filter_metadata_size()); + EXPECT_CALL(decoder_callbacks_, requestInfo()).WillOnce(ReturnRef(request_info)); + EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("bar"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + EXPECT_EQ(1, request_info.dynamicMetadata().filter_metadata_size()); + EXPECT_EQ("bar", request_info.dynamicMetadata() + .filter_metadata() + .at("envoy.lb") + .fields() + .at("foo") + .string_value()); +} + +// Check the connection. +TEST_F(LuaHttpFilterTest, CheckConnection) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + if request_handle:connection():ssl() == nil then + request_handle:logTrace("plain") + else + request_handle:logTrace("secure") + end + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestHeaderMapImpl request_headers{{":path", "/"}}; + + setupSecureConnection(false); + EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("plain"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + + setupSecureConnection(true); + EXPECT_CALL(*filter_, scriptLog(spdlog::level::trace, StrEq("secure"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); +} + } // namespace Lua } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/lua/lua_integration_test.cc b/test/extensions/filters/http/lua/lua_integration_test.cc index 7820f546e400f..c3b62e5ab6546 100644 --- a/test/extensions/filters/http/lua/lua_integration_test.cc +++ b/test/extensions/filters/http/lua/lua_integration_test.cc @@ -45,7 +45,7 @@ class LuaIntegrationTest : public HttpIntegrationTest, new_route->mutable_match()->set_prefix("/alt/route"); new_route->mutable_route()->set_cluster("alt_cluster"); - const std::string key = Extensions::HttpFilters::HttpFilterNames::get().LUA; + const std::string key = Extensions::HttpFilters::HttpFilterNames::get().Lua; const std::string yaml = R"EOF( foo.bar: @@ -72,12 +72,16 @@ class LuaIntegrationTest : public HttpIntegrationTest, void cleanup() { codec_client_->close(); if (fake_lua_connection_ != nullptr) { - fake_lua_connection_->close(); - fake_lua_connection_->waitForDisconnect(); + AssertionResult result = fake_lua_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_lua_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } if (fake_upstream_connection_ != nullptr) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + AssertionResult result = fake_upstream_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } @@ -106,9 +110,20 @@ name: envoy.lua local metadata = request_handle:metadata():get("foo.bar") local body_length = request_handle:body():length() + + request_handle:requestInfo():dynamicMetadata():set("envoy.lb", "foo", "bar") + local dynamic_metadata_value = request_handle:requestInfo():dynamicMetadata():get("envoy.lb")["foo"] + request_handle:headers():add("request_body_size", body_length) request_handle:headers():add("request_metadata_foo", metadata["foo"]) request_handle:headers():add("request_metadata_baz", metadata["baz"]) + if request_handle:connection():ssl() == nil then + request_handle:headers():add("request_secure", "false") + else + request_handle:headers():add("request_secure", "true") + end + request_handle:headers():add("request_protocol", request_handle:requestInfo():protocol()) + request_handle:headers():add("request_dynamic_metadata_value", dynamic_metadata_value) end function envoy_on_response(response_handle) @@ -117,6 +132,7 @@ name: envoy.lua response_handle:headers():add("response_metadata_foo", metadata["foo"]) response_handle:headers():add("response_metadata_baz", metadata["baz"]) response_handle:headers():add("response_body_size", body_length) + response_handle:headers():add("request_protocol", response_handle:requestInfo():protocol()) response_handle:headers():remove("foo") end )EOF"; @@ -152,6 +168,18 @@ name: envoy.lua .get(Http::LowerCaseString("request_metadata_baz")) ->value() .c_str()); + EXPECT_STREQ( + "false", + upstream_request_->headers().get(Http::LowerCaseString("request_secure"))->value().c_str()); + + EXPECT_STREQ( + "HTTP/1.1", + upstream_request_->headers().get(Http::LowerCaseString("request_protocol"))->value().c_str()); + + EXPECT_STREQ("bar", upstream_request_->headers() + .get(Http::LowerCaseString("request_dynamic_metadata_value")) + ->value() + .c_str()); Http::TestHeaderMapImpl response_headers{{":status", "200"}, {"foo", "bar"}}; upstream_request_->encodeHeaders(response_headers, false); @@ -170,6 +198,8 @@ name: envoy.lua EXPECT_STREQ( "bat", response->headers().get(Http::LowerCaseString("response_metadata_baz"))->value().c_str()); + EXPECT_STREQ("HTTP/1.1", + response->headers().get(Http::LowerCaseString("request_protocol"))->value().c_str()); EXPECT_EQ(nullptr, response->headers().get(Http::LowerCaseString("foo"))); cleanup(); @@ -208,9 +238,9 @@ name: envoy.lua {"x-forwarded-for", "10.0.0.1"}}; auto response = codec_client_->makeHeaderOnlyRequest(request_headers); - fake_lua_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - lua_request_ = fake_lua_connection_->waitForNewStream(*dispatcher_); - lua_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_lua_connection_)); + ASSERT_TRUE(fake_lua_connection_->waitForNewStream(*dispatcher_, lua_request_)); + ASSERT_TRUE(lua_request_->waitForEndStream(*dispatcher_)); Http::TestHeaderMapImpl response_headers{{":status", "200"}, {"foo", "bar"}}; lua_request_->encodeHeaders(response_headers, false); Buffer::OwnedImpl response_data1("good"); @@ -266,9 +296,9 @@ name: envoy.lua {"x-forwarded-for", "10.0.0.1"}}; auto response = codec_client_->makeHeaderOnlyRequest(request_headers); - fake_lua_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - lua_request_ = fake_lua_connection_->waitForNewStream(*dispatcher_); - lua_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_lua_connection_)); + ASSERT_TRUE(fake_lua_connection_->waitForNewStream(*dispatcher_, lua_request_)); + ASSERT_TRUE(lua_request_->waitForEndStream(*dispatcher_)); Http::TestHeaderMapImpl response_headers{{":status", "200"}, {"foo", "bar"}}; lua_request_->encodeHeaders(response_headers, true); diff --git a/test/extensions/filters/http/lua/wrappers_test.cc b/test/extensions/filters/http/lua/wrappers_test.cc index ce35a5a38bfb8..15e78db73561d 100644 --- a/test/extensions/filters/http/lua/wrappers_test.cc +++ b/test/extensions/filters/http/lua/wrappers_test.cc @@ -1,9 +1,15 @@ +#include "common/http/utility.h" +#include "common/request_info/request_info_impl.h" + #include "extensions/filters/http/lua/wrappers.h" #include "test/extensions/filters/common/lua/lua_wrappers.h" +#include "test/mocks/request_info/mocks.h" #include "test/test_common/utility.h" using testing::InSequence; +using testing::Return; +using testing::ReturnPointee; namespace Envoy { namespace Extensions { @@ -210,6 +216,182 @@ TEST_F(LuaHeaderMapWrapperTest, IteratorAcrossYield) { "[string \"...\"]:5: object used outside of proper scope"); } +class LuaRequestInfoWrapperTest + : public Filters::Common::Lua::LuaWrappersTestBase { +public: + virtual void setup(const std::string& script) { + Filters::Common::Lua::LuaWrappersTestBase::setup(script); + state_->registerType(); + state_->registerType(); + } + +protected: + void expectToPrintCurrentProtocol(const absl::optional& protocol) { + const std::string SCRIPT{R"EOF( + function callMe(object) + testPrint(string.format("'%s'", object:protocol())) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + NiceMock request_info; + ON_CALL(request_info, protocol()).WillByDefault(ReturnPointee(&protocol)); + Filters::Common::Lua::LuaDeathRef wrapper( + RequestInfoWrapper::create(coroutine_->luaState(), request_info), true); + EXPECT_CALL(*this, + testPrint(fmt::format("'{}'", Http::Utility::getProtocolString(protocol.value())))); + start("callMe"); + wrapper.reset(); + } + + envoy::api::v2::core::Metadata parseMetadataFromYaml(const std::string& yaml_string) { + envoy::api::v2::core::Metadata metadata; + MessageUtil::loadFromYaml(yaml_string, metadata); + return metadata; + } +}; + +// Return the current request protocol. +TEST_F(LuaRequestInfoWrapperTest, ReturnCurrentProtocol) { + expectToPrintCurrentProtocol(Http::Protocol::Http10); + expectToPrintCurrentProtocol(Http::Protocol::Http11); + expectToPrintCurrentProtocol(Http::Protocol::Http2); +} + +// Set, get and iterate request info dynamic metadata. +TEST_F(LuaRequestInfoWrapperTest, SetGetAndIterateDynamicMetadata) { + const std::string SCRIPT{R"EOF( + function callMe(object) + testPrint(type(object:dynamicMetadata())) + object:dynamicMetadata():set("envoy.lb", "foo", "bar") + object:dynamicMetadata():set("envoy.lb", "so", "cool") + + testPrint(object:dynamicMetadata():get("envoy.lb")["foo"]) + testPrint(object:dynamicMetadata():get("envoy.lb")["so"]) + + for filter, entry in pairs(object:dynamicMetadata()) do + for key, value in pairs(entry) do + testPrint(string.format("'%s' '%s'", key, value)) + end + end + + local function nRetVals(...) + return select('#',...) + end + testPrint(tostring(nRetVals(object:dynamicMetadata():get("envoy.ngx")))) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + RequestInfo::RequestInfoImpl request_info(Http::Protocol::Http2); + EXPECT_EQ(0, request_info.dynamicMetadata().filter_metadata_size()); + Filters::Common::Lua::LuaDeathRef wrapper( + RequestInfoWrapper::create(coroutine_->luaState(), request_info), true); + EXPECT_CALL(*this, testPrint("userdata")); + EXPECT_CALL(*this, testPrint("bar")); + EXPECT_CALL(*this, testPrint("cool")); + EXPECT_CALL(*this, testPrint("'foo' 'bar'")); + EXPECT_CALL(*this, testPrint("'so' 'cool'")); + EXPECT_CALL(*this, testPrint("0")); + start("callMe"); + + EXPECT_EQ(1, request_info.dynamicMetadata().filter_metadata_size()); + EXPECT_EQ("bar", request_info.dynamicMetadata() + .filter_metadata() + .at("envoy.lb") + .fields() + .at("foo") + .string_value()); + wrapper.reset(); +} + +// Modify during iteration. +TEST_F(LuaRequestInfoWrapperTest, ModifyDuringIterationForDynamicMetadata) { + const std::string SCRIPT{R"EOF( + function callMe(object) + object:dynamicMetadata():set("envoy.lb", "hello", "world") + for key, value in pairs(object:dynamicMetadata()) do + object:dynamicMetadata():set("envoy.lb", "hello", "envoy") + end + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + RequestInfo::RequestInfoImpl request_info(Http::Protocol::Http2); + Filters::Common::Lua::LuaDeathRef wrapper( + RequestInfoWrapper::create(coroutine_->luaState(), request_info), true); + EXPECT_THROW_WITH_MESSAGE( + start("callMe"), Filters::Common::Lua::LuaException, + "[string \"...\"]:5: dynamic metadata map cannot be modified while iterating"); +} + +// Modify after iteration. +TEST_F(LuaRequestInfoWrapperTest, ModifyAfterIterationForDynamicMetadata) { + const std::string SCRIPT{R"EOF( + function callMe(object) + object:dynamicMetadata():set("envoy.lb", "hello", "world") + for filter, entry in pairs(object:dynamicMetadata()) do + testPrint(filter) + for key, value in pairs(entry) do + testPrint(string.format("'%s' '%s'", key, value)) + end + end + + object:dynamicMetadata():set("envoy.lb", "hello", "envoy") + object:dynamicMetadata():set("envoy.proxy", "proto", "grpc") + for filter, entry in pairs(object:dynamicMetadata()) do + testPrint(filter) + for key, value in pairs(entry) do + testPrint(string.format("'%s' '%s'", key, value)) + end + end + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + RequestInfo::RequestInfoImpl request_info(Http::Protocol::Http2); + EXPECT_EQ(0, request_info.dynamicMetadata().filter_metadata_size()); + Filters::Common::Lua::LuaDeathRef wrapper( + RequestInfoWrapper::create(coroutine_->luaState(), request_info), true); + EXPECT_CALL(*this, testPrint("envoy.lb")); + EXPECT_CALL(*this, testPrint("'hello' 'world'")); + EXPECT_CALL(*this, testPrint("envoy.proxy")); + EXPECT_CALL(*this, testPrint("'proto' 'grpc'")); + EXPECT_CALL(*this, testPrint("envoy.lb")); + EXPECT_CALL(*this, testPrint("'hello' 'envoy'")); + start("callMe"); +} + +// Don't finish iteration. +TEST_F(LuaRequestInfoWrapperTest, DontFinishIterationForDynamicMetadata) { + const std::string SCRIPT{R"EOF( + function callMe(object) + object:dynamicMetadata():set("envoy.lb", "foo", "bar") + iterator = pairs(object:dynamicMetadata()) + key, value = iterator() + iterator2 = pairs(object:dynamicMetadata()) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + RequestInfo::RequestInfoImpl request_info(Http::Protocol::Http2); + Filters::Common::Lua::LuaDeathRef wrapper( + RequestInfoWrapper::create(coroutine_->luaState(), request_info), true); + EXPECT_THROW_WITH_MESSAGE( + start("callMe"), Filters::Common::Lua::LuaException, + "[string \"...\"]:6: cannot create a second iterator before completing the first"); +} + } // namespace Lua } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/rbac/rbac_filter_integration_test.cc b/test/extensions/filters/http/rbac/rbac_filter_integration_test.cc index fd41f57beb8ed..9e766f7bc4f53 100644 --- a/test/extensions/filters/http/rbac/rbac_filter_integration_test.cc +++ b/test/extensions/filters/http/rbac/rbac_filter_integration_test.cc @@ -77,7 +77,7 @@ TEST_P(RBACIntegrationTest, RouteOverride) { ->Mutable(0) ->mutable_per_filter_config(); - (*config)[Extensions::HttpFilters::HttpFilterNames::get().RBAC] = pfc; + (*config)[Extensions::HttpFilters::HttpFilterNames::get().Rbac] = pfc; }); config_helper_.addFilter(RBAC_CONFIG); diff --git a/test/extensions/filters/http/rbac/rbac_filter_test.cc b/test/extensions/filters/http/rbac/rbac_filter_test.cc index 58253bd94878b..7e1d89ff852b3 100644 --- a/test/extensions/filters/http/rbac/rbac_filter_test.cc +++ b/test/extensions/filters/http/rbac/rbac_filter_test.cc @@ -46,13 +46,9 @@ class RoleBasedAccessControlFilterTest : public testing::Test { filter_.setDecoderFilterCallbacks(callbacks_); } - void setDestinationPort(uint16_t port, int times = 2) { + void setDestinationPort(uint16_t port) { address_ = Envoy::Network::Utility::parseInternetAddress("1.2.3.4", port, false); - auto& expect = EXPECT_CALL(connection_, localAddress()); - if (times > 0) { - expect.Times(times); - } - expect.WillRepeatedly(ReturnRef(address_)); + ON_CALL(connection_, localAddress()).WillByDefault(ReturnRef(address_)); } NiceMock callbacks_; @@ -94,7 +90,7 @@ TEST_F(RoleBasedAccessControlFilterTest, Denied) { } TEST_F(RoleBasedAccessControlFilterTest, RouteLocalOverride) { - setDestinationPort(456, 0); + setDestinationPort(456); envoy::config::filter::http::rbac::v2::RBACPerRoute route_config; route_config.mutable_rbac()->mutable_rules()->set_action( @@ -102,10 +98,10 @@ TEST_F(RoleBasedAccessControlFilterTest, RouteLocalOverride) { NiceMock engine{route_config.rbac().rules()}; NiceMock per_route_config_{route_config}; - EXPECT_CALL(engine, allowed(_, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(engine, allowed(_, _, _)).WillRepeatedly(Return(true)); EXPECT_CALL(per_route_config_, engine()).WillRepeatedly(ReturnRef(engine)); - EXPECT_CALL(callbacks_.route_->route_entry_, perFilterConfig(HttpFilterNames::get().RBAC)) + EXPECT_CALL(callbacks_.route_->route_entry_, perFilterConfig(HttpFilterNames::get().Rbac)) .WillRepeatedly(Return(&per_route_config_)); EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_.decodeHeaders(headers_, true)); diff --git a/test/extensions/filters/http/router/config_test.cc b/test/extensions/filters/http/router/config_test.cc index fe6943aa4041f..89732808e5784 100644 --- a/test/extensions/filters/http/router/config_test.cc +++ b/test/extensions/filters/http/router/config_test.cc @@ -73,7 +73,7 @@ TEST(RouterFilterConfigTest, DoubleRegistrationTest) { (Registry::RegisterFactory()), EnvoyException, - fmt::format("Double registration for name: '{}'", HttpFilterNames::get().ROUTER)); + fmt::format("Double registration for name: '{}'", HttpFilterNames::get().Router)); } } // namespace RouterFilter diff --git a/test/extensions/filters/http/squash/squash_filter_integration_test.cc b/test/extensions/filters/http/squash/squash_filter_integration_test.cc index 5f1a6788a78eb..fd8379194900f 100644 --- a/test/extensions/filters/http/squash/squash_filter_integration_test.cc +++ b/test/extensions/filters/http/squash/squash_filter_integration_test.cc @@ -22,19 +22,27 @@ class SquashFilterIntegrationTest : public HttpIntegrationTest, ~SquashFilterIntegrationTest() { if (fake_squash_connection_) { - fake_squash_connection_->close(); - fake_squash_connection_->waitForDisconnect(); + AssertionResult result = fake_squash_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_squash_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } FakeStreamPtr sendSquash(const std::string& status, const std::string& body) { if (!fake_squash_connection_) { - fake_squash_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); + AssertionResult result = + fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_squash_connection_); + RELEASE_ASSERT(result, result.message()); } - FakeStreamPtr request_stream = fake_squash_connection_->waitForNewStream(*dispatcher_); - request_stream->waitForEndStream(*dispatcher_); + FakeStreamPtr request_stream; + AssertionResult result = + fake_squash_connection_->waitForNewStream(*dispatcher_, request_stream); + RELEASE_ASSERT(result, result.message()); + result = request_stream->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); if (body.empty()) { request_stream->encodeHeaders(Http::TestHeaderMapImpl{{":status", status}}, true); } else { @@ -127,7 +135,7 @@ TEST_P(SquashFilterIntegrationTest, TestHappyPath) { EXPECT_STREQ("/api/v2/debugattachment/", create_stream->headers().Path()->value().c_str()); // Make sure the env var was replaced ProtobufWkt::Struct actualbody; - MessageUtil::loadFromJson(TestUtility::bufferToString(create_stream->body()), actualbody); + MessageUtil::loadFromJson(create_stream->body().toString(), actualbody); ProtobufWkt::Struct expectedbody; MessageUtil::loadFromJson("{\"spec\": { \"attachment\" : { \"env\": \"" ENV_VAR_VALUE diff --git a/test/extensions/filters/listener/proxy_protocol/BUILD b/test/extensions/filters/listener/proxy_protocol/BUILD index 57128bbb9bcc1..656aeb09e8a6a 100644 --- a/test/extensions/filters/listener/proxy_protocol/BUILD +++ b/test/extensions/filters/listener/proxy_protocol/BUILD @@ -24,11 +24,13 @@ envoy_extension_cc_test( "//source/common/stats:stats_lib", "//source/extensions/filters/listener/proxy_protocol:proxy_protocol_lib", "//source/server:connection_handler_lib", + "//test/mocks/api:api_mocks", "//test/mocks/buffer:buffer_mocks", "//test/mocks/network:network_mocks", "//test/mocks/server:server_mocks", "//test/test_common:environment_lib", "//test/test_common:network_utility_lib", + "//test/test_common:threadsafe_singleton_injector_lib", "//test/test_common:utility_lib", ], ) diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc index 5bc31d35fc4ef..b72e6ddd0c35e 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc @@ -14,18 +14,22 @@ #include "extensions/filters/listener/proxy_protocol/proxy_protocol.h" +#include "test/mocks/api/mocks.h" #include "test/mocks/buffer/mocks.h" #include "test/mocks/network/mocks.h" #include "test/mocks/server/mocks.h" #include "test/test_common/environment.h" #include "test/test_common/network_utility.h" #include "test/test_common/printers.h" +#include "test/test_common/threadsafe_singleton_injector.h" #include "test/test_common/utility.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +using testing::AnyNumber; using testing::AtLeast; +using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Return; @@ -47,6 +51,7 @@ class ProxyProtocolTest : public testing::TestWithParamaddListener(*this); conn_ = dispatcher_.createClientConnection(socket_.localAddress(), Network::Address::InstanceConstSharedPtr(), @@ -94,6 +99,11 @@ class ProxyProtocolTest : public testing::TestWithParamwrite(buf, false); + } + void write(const std::string& s) { Buffer::OwnedImpl buf(s); conn_->write(buf, false); @@ -103,7 +113,7 @@ class ProxyProtocolTest : public testing::TestWithParam Network::FilterStatus { - EXPECT_EQ(TestUtility::bufferToString(buffer), expected); + EXPECT_EQ(buffer.toString(), expected); buffer.drain(expected.length()); dispatcher_.exit(); return Network::FilterStatus::Continue; @@ -150,7 +160,7 @@ INSTANTIATE_TEST_CASE_P(IpVersions, ProxyProtocolTest, testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), TestUtility::ipTestParamsToString); -TEST_P(ProxyProtocolTest, Basic) { +TEST_P(ProxyProtocolTest, v1Basic) { connect(); write("PROXY TCP4 1.2.3.4 253.253.253.253 65535 1234\r\nmore data"); @@ -162,6 +172,39 @@ TEST_P(ProxyProtocolTest, Basic) { disconnect(); } +TEST_P(ProxyProtocolTest, v1Minimal) { + connect(); + write("PROXY UNKNOWN\r\nmore data"); + + expectData("more data"); + + if (GetParam() == Envoy::Network::Address::IpVersion::v4) { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "127.0.0.1"); + } else { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "::1"); + } + EXPECT_FALSE(server_connection_->localAddressRestored()); + + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2Basic) { + // A well-formed ipv4/tcp message, no extensions + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, sizeof(buffer)); + + expectData("more data"); + + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); + EXPECT_TRUE(server_connection_->localAddressRestored()); + + disconnect(); +} + TEST_P(ProxyProtocolTest, BasicV6) { connect(); write("PROXY TCP6 1:2:3::4 5:6::7:8 65535 1234\r\nmore data"); @@ -174,6 +217,316 @@ TEST_P(ProxyProtocolTest, BasicV6) { disconnect(); } +TEST_P(ProxyProtocolTest, v2BasicV6) { + // A well-formed ipv6/tcp message, no extensions + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, + 0x0a, 0x21, 0x22, 0x00, 0x24, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', + 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, sizeof(buffer)); + + expectData("more data"); + + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1:2:3::4"); + EXPECT_TRUE(server_connection_->localAddressRestored()); + + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2UnsupportedAF) { + // A well-formed message with an unsupported address family + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x41, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + write(buffer, sizeof(buffer)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, errorRecv_2) { + // A well formed v4/tcp message, no extensions, but introduce an error on recv (e.g. socket close) + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)).Times(AnyNumber()).WillOnce(Return((errno = 0, -1))); + EXPECT_CALL(os_sys_calls, ioctl(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](int fd, unsigned long int request, void* argp) { + return ::ioctl(fd, request, argp); + })); + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::writev(fd, iov, iovcnt); })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::readv(fd, iov, iovcnt); })); + + connect(false); + write(buffer, sizeof(buffer)); + + errno = 0; + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, errorFIONREAD_1) { + // A well formed v4/tcp message, no extensions, but introduce an error on ioctl(...FIONREAD...) + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + EXPECT_CALL(os_sys_calls, ioctl(_, FIONREAD, _)).WillOnce(Return(-1)); + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::writev(fd, iov, iovcnt); })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::readv(fd, iov, iovcnt); })); + + connect(false); + write(buffer, sizeof(buffer)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2NotLocalOrOnBehalf) { + // An illegal command type: neither 'local' nor 'proxy' command + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x23, 0x1f, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + write(buffer, sizeof(buffer)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2LocalConnection) { + // A 'local' connection, e.g. health-checking, no address, no extensions + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, + 0x49, 0x54, 0x0a, 0x20, 0x00, 0x00, 0x00, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, sizeof(buffer)); + expectData("more data"); + if (server_connection_->remoteAddress()->ip()->version() == + Envoy::Network::Address::IpVersion::v6) { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "::1"); + } else if (server_connection_->remoteAddress()->ip()->version() == + Envoy::Network::Address::IpVersion::v4) { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "127.0.0.1"); + } + EXPECT_FALSE(server_connection_->localAddressRestored()); + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2LocalConnectionExtension) { + // A 'local' connection, e.g. health-checking, no address, 1 TLV (0x00,0x00,0x01,0xff) is present. + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x20, 0x00, 0x00, 0x04, 0x00, 0x00, 0x01, 0xff, + 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, sizeof(buffer)); + expectData("more data"); + if (server_connection_->remoteAddress()->ip()->version() == + Envoy::Network::Address::IpVersion::v6) { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "::1"); + } else if (server_connection_->remoteAddress()->ip()->version() == + Envoy::Network::Address::IpVersion::v4) { + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "127.0.0.1"); + } + EXPECT_FALSE(server_connection_->localAddressRestored()); + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2ShortV4) { + // An ipv4/tcp connection that has incorrect addr-len encoded + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x21, 0x00, 0x04, 0x00, 0x08, 0x00, 0x02, + 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2ShortAddrV4) { + // An ipv4/tcp connection that has insufficient header-length encoded + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0b, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2ShortV6) { + // An ipv6/tcp connection that has incorrect addr-len encoded + constexpr uint8_t buffer[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x22, 0x00, + 0x14, 0x00, 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2ShortAddrV6) { + // An ipv6/tcp connection that has insufficient header-length encoded + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, + 0x0a, 0x21, 0x22, 0x00, 0x23, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', + 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2AF_UNIX) { + // A well-formed AF_UNIX (0x32 in b14) connection is rejected + constexpr uint8_t buffer[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x32, 0x00, + 0x14, 0x00, 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + write(buffer, sizeof(buffer)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2BadCommand) { + // A non local/proxy command (0x29 in b13) is rejected + constexpr uint8_t buffer[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x29, 0x32, 0x00, + 0x14, 0x00, 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + write(buffer, sizeof(buffer)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2WrongVersion) { + // A non '2' version is rejected (0x93 in b13) + constexpr uint8_t buffer[] = { + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x93, 0x00, + 0x14, 0x00, 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x02, 'm', 'o', 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(false); + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v1TooLong) { + constexpr uint8_t buffer[] = {' ', ' ', ' ', ' ', ' ', ' ', ' ', ' '}; + connect(false); + write("PROXY TCP4 1.2.3.4 2.3.4.5 100 100"); + for (size_t i = 0; i < 256; i += sizeof(buffer)) + write(buffer, sizeof(buffer)); + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2ParseExtensions) { + // A well-formed ipv4/tcp with a pair of TLV extensions is accepted + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x14, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv[] = {0x0, 0x0, 0x1, 0xff}; + + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + + connect(); + write(buffer, sizeof(buffer)); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + for (int i = 0; i < 2; i++) { + write(tlv, sizeof(tlv)); + } + write(data, sizeof(data)); + expectData("DATA"); + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2ParseExtensionsIoctlError) { + // A well-formed ipv4/tcp with a TLV extension. An error is created in the ioctl(...FIONREAD...) + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv[] = {0x0, 0x0, 0x1, 0xff}; + + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, ioctl(_, FIONREAD, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](int fd, unsigned long int request, void* argp) { + int x = ::ioctl(fd, request, argp); + if (x == 0 && *static_cast(argp) == sizeof(tlv)) { + return -1; + } else { + return x; + } + })); + + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, void* buf, size_t len, int flags) { return ::recv(fd, buf, len, flags); })); + + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::writev(fd, iov, iovcnt); })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::readv(fd, iov, iovcnt); })); + + connect(false); + write(buffer, sizeof(buffer)); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(tlv, sizeof(tlv)); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2ParseExtensionsFrag) { + // A well-formed ipv4/tcp header with 2 TLV/extenions, these are fragmented on delivery + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x14, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + constexpr uint8_t tlv[] = {0x0, 0x0, 0x1, 0xff}; + + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + + connect(); + write(buffer, sizeof(buffer)); + for (int i = 0; i < 2; i++) { + write(tlv, sizeof(tlv)); + } + write(data, sizeof(data)); + expectData("DATA"); + disconnect(); +} + TEST_P(ProxyProtocolTest, Fragmented) { connect(); write("PROXY TCP4"); @@ -194,6 +547,124 @@ TEST_P(ProxyProtocolTest, Fragmented) { disconnect(); } +TEST_P(ProxyProtocolTest, v2Fragmented1) { + // A well-formed ipv4/tcp header, delivering part of the signature, then part of + // the address, then the remainder + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, 10); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(buffer + 10, 10); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(buffer + 20, 17); + + expectData("more data"); + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); + EXPECT_TRUE(server_connection_->localAddressRestored()); + + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2Fragmented2) { + // A well-formed ipv4/tcp header, delivering all of the signature + 1, then the remainder + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + connect(); + write(buffer, 17); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(buffer + 17, 10); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(buffer + 27, 10); + + expectData("more data"); + + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); + EXPECT_TRUE(server_connection_->localAddressRestored()); + + disconnect(); +} + +TEST_P(ProxyProtocolTest, v2Fragmented3Error) { + // A well-formed ipv4/tcp header, delivering all of the signature +1, w/ an error + // simulated in recv() on the +1 + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, void* buf, size_t len, int flags) { return ::recv(fd, buf, len, flags); })); + EXPECT_CALL(os_sys_calls, recv(_, _, 1, _)).Times(AnyNumber()).WillOnce(Return(-1)); + + EXPECT_CALL(os_sys_calls, ioctl(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](int fd, unsigned long int request, void* argp) { + return ::ioctl(fd, request, argp); + })); + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::writev(fd, iov, iovcnt); })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::readv(fd, iov, iovcnt); })); + + connect(false); + write(buffer, 17); + + expectProxyProtoError(); +} + +TEST_P(ProxyProtocolTest, v2Fragmented4Error) { + // A well-formed ipv4/tcp header, part of the signature with an error introduced + // in recv() on the remainder + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02, 'm', 'o', + 'r', 'e', ' ', 'd', 'a', 't', 'a'}; + + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, void* buf, size_t len, int flags) { return ::recv(fd, buf, len, flags); })); + EXPECT_CALL(os_sys_calls, recv(_, _, 4, _)).Times(AnyNumber()).WillOnce(Return(-1)); + + EXPECT_CALL(os_sys_calls, ioctl(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([](int fd, unsigned long int request, void* argp) { + return ::ioctl(fd, request, argp); + })); + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::writev(fd, iov, iovcnt); })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [](int fd, const struct iovec* iov, int iovcnt) { return ::readv(fd, iov, iovcnt); })); + + connect(false); + write(buffer, 10); + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + write(buffer + 10, 10); + + expectProxyProtoError(); +} + TEST_P(ProxyProtocolTest, PartialRead) { connect(); @@ -214,6 +685,29 @@ TEST_P(ProxyProtocolTest, PartialRead) { disconnect(); } +TEST_P(ProxyProtocolTest, v2PartialRead) { + // A well-formed ipv4/tcp header, delivered with part of the signature, + // part of the header, rest of header + body + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, + 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, + 0x03, 0x04, 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, + 0x02, 'm', 'o', 'r', 'e', 'd', 'a', 't', 'a'}; + connect(); + + for (size_t i = 0; i < sizeof(buffer); i += 9) { + write(&buffer[i], 9); + if (i == 0) + dispatcher_.run(Event::Dispatcher::RunType::NonBlock); + } + + expectData("moredata"); + + EXPECT_EQ(server_connection_->remoteAddress()->ip()->addressAsString(), "1.2.3.4"); + EXPECT_TRUE(server_connection_->localAddressRestored()); + + disconnect(); +} + TEST_P(ProxyProtocolTest, MalformedProxyLine) { connect(false); @@ -391,7 +885,7 @@ class WildcardProxyProtocolTest : public testing::TestWithParam Network::FilterStatus { - EXPECT_EQ(TestUtility::bufferToString(buffer), expected); + EXPECT_EQ(buffer.toString(), expected); buffer.drain(expected.length()); dispatcher_.exit(); return Network::FilterStatus::Continue; diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc index 48b8c6fceec2d..a7697248e5705 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc @@ -34,7 +34,7 @@ class FastMockListenerFilterCallbacks : public Network::MockListenerFilterCallba : socket_(socket), dispatcher_(dispatcher) {} Network::ConnectionSocket& socket() override { return socket_; } Event::Dispatcher& dispatcher() override { return dispatcher_; } - void continueFilterChain(bool success) override { RELEASE_ASSERT(success); } + void continueFilterChain(bool success) override { RELEASE_ASSERT(success, ""); } Network::ConnectionSocket& socket_; Event::Dispatcher& dispatcher_; @@ -63,7 +63,7 @@ class FastMockOsSysCalls : public Api::MockOsSysCalls { FastMockOsSysCalls(const std::vector& client_hello) : client_hello_(client_hello) {} ssize_t recv(int, void* buffer, size_t length, int) override { - RELEASE_ASSERT(length >= client_hello_.size()); + RELEASE_ASSERT(length >= client_hello_.size(), ""); memcpy(buffer, client_hello_.data(), client_hello_.size()); return client_hello_.size(); } @@ -85,10 +85,11 @@ static void BM_TlsInspector(benchmark::State& state) { Filter filter(cfg); filter.onAccept(cb); dispatcher.file_event_callback_(Event::FileReadyType::Read); - RELEASE_ASSERT(socket.detectedTransportProtocol() == "tls"); - RELEASE_ASSERT(socket.requestedServerName() == "example.com"); + RELEASE_ASSERT(socket.detectedTransportProtocol() == "tls", ""); + RELEASE_ASSERT(socket.requestedServerName() == "example.com", ""); RELEASE_ASSERT(socket.requestedApplicationProtocols().size() == 2 && - socket.requestedApplicationProtocols().front() == "h2"); + socket.requestedApplicationProtocols().front() == "h2", + ""); socket.setDetectedTransportProtocol(""); socket.setRequestedServerName(""); socket.setRequestedApplicationProtocols({}); diff --git a/test/extensions/filters/network/client_ssl_auth/config_test.cc b/test/extensions/filters/network/client_ssl_auth/config_test.cc index a3c19a2110872..d815680aec900 100644 --- a/test/extensions/filters/network/client_ssl_auth/config_test.cc +++ b/test/extensions/filters/network/client_ssl_auth/config_test.cc @@ -99,7 +99,7 @@ TEST(ClientSslAuthConfigFactoryTest, DoubleRegistrationTest) { (Registry::RegisterFactory()), EnvoyException, - fmt::format("Double registration for name: '{}'", NetworkFilterNames::get().CLIENT_SSL_AUTH)); + fmt::format("Double registration for name: '{}'", NetworkFilterNames::get().ClientSslAuth)); } } // namespace ClientSslAuth diff --git a/test/extensions/filters/network/ext_authz/ext_authz_test.cc b/test/extensions/filters/network/ext_authz/ext_authz_test.cc index 310a90983a3fc..9fff77c7ca5cb 100644 --- a/test/extensions/filters/network/ext_authz/ext_authz_test.cc +++ b/test/extensions/filters/network/ext_authz/ext_authz_test.cc @@ -61,6 +61,14 @@ class ExtAuthzFilterTest : public testing::Test { filter_->onBelowWriteBufferLowWatermark(); } + Filters::Common::ExtAuthz::ResponsePtr + makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus status) { + Filters::Common::ExtAuthz::ResponsePtr response = + std::make_unique(); + response->status = status; + return response; + } + ~ExtAuthzFilterTest() { for (const Stats::GaugeSharedPtr& gauge : stats_store_.gauges()) { EXPECT_EQ(0U, gauge->value()); @@ -114,7 +122,7 @@ TEST_F(ExtAuthzFilterTest, OKWithOnData) { EXPECT_EQ(1U, stats_store_.gauge("ext_authz.name.active").value()); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::OK); + request_callbacks_->onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::OK)); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -152,7 +160,7 @@ TEST_F(ExtAuthzFilterTest, DeniedWithOnData) { EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); EXPECT_CALL(*client_, cancel()).Times(0); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Denied); + request_callbacks_->onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Denied)); EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onData(data, false)); @@ -182,7 +190,7 @@ TEST_F(ExtAuthzFilterTest, FailOpen) { EXPECT_CALL(filter_callbacks_.connection_, close(_)).Times(0); EXPECT_CALL(*client_, cancel()).Times(0); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + request_callbacks_->onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Error)); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -213,7 +221,7 @@ TEST_F(ExtAuthzFilterTest, FailClose) { EXPECT_CALL(filter_callbacks_.connection_, close(_)).Times(1); EXPECT_CALL(filter_callbacks_, continueReading()).Times(0); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + request_callbacks_->onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Error)); EXPECT_EQ(1U, stats_store_.counter("ext_authz.name.total").value()); EXPECT_EQ(1U, stats_store_.counter("ext_authz.name.error").value()); @@ -241,7 +249,7 @@ TEST_F(ExtAuthzFilterTest, DoNotCallCancelonRemoteClose) { EXPECT_EQ(Network::FilterStatus::StopIteration, filter_->onData(data, false)); EXPECT_CALL(filter_callbacks_, continueReading()); - request_callbacks_->onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + request_callbacks_->onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Error)); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onData(data, false)); @@ -295,7 +303,7 @@ TEST_F(ExtAuthzFilterTest, ImmediateOK) { EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::OK); + callbacks.onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::OK)); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); @@ -325,7 +333,7 @@ TEST_F(ExtAuthzFilterTest, ImmediateNOK) { EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::Denied); + callbacks.onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Denied)); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); @@ -351,7 +359,7 @@ TEST_F(ExtAuthzFilterTest, ImmediateErrorFailOpen) { EXPECT_CALL(*client_, check(_, _, _)) .WillOnce( WithArgs<0>(Invoke([&](Filters::Common::ExtAuthz::RequestCallbacks& callbacks) -> void { - callbacks.onComplete(Filters::Common::ExtAuthz::CheckStatus::Error); + callbacks.onComplete(makeAuthzResponse(Filters::Common::ExtAuthz::CheckStatus::Error)); }))); EXPECT_EQ(Network::FilterStatus::Continue, filter_->onNewConnection()); diff --git a/test/extensions/filters/network/http_connection_manager/config_test.cc b/test/extensions/filters/network/http_connection_manager/config_test.cc index d6acb20c5b24e..1fc646bdba6f0 100644 --- a/test/extensions/filters/network/http_connection_manager/config_test.cc +++ b/test/extensions/filters/network/http_connection_manager/config_test.cc @@ -29,7 +29,17 @@ parseHttpConnectionManagerFromJson(const std::string& json_string) { envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager http_connection_manager; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Config::FilterJson::translateHttpConnectionManager(*json_object_ptr, http_connection_manager); + NiceMock scope; + Config::FilterJson::translateHttpConnectionManager(*json_object_ptr, http_connection_manager, + scope.statsOptions()); + return http_connection_manager; +} + +envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager +parseHttpConnectionManagerFromV2Yaml(const std::string& yaml) { + envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager + http_connection_manager; + MessageUtil::loadFromYaml(yaml, http_connection_manager); return http_connection_manager; } @@ -118,6 +128,23 @@ TEST_F(HttpConnectionManagerConfigTest, MiscConfig) { ContainerEq(config.tracingConfig()->request_headers_for_tags_)); EXPECT_EQ(*context_.local_info_.address_, config.localAddress()); EXPECT_EQ("foo", config.serverName()); + EXPECT_EQ(5 * 60 * 1000, config.streamIdleTimeout().count()); +} + +// Validated that an explicit zero stream idle timeout disables. +TEST_F(HttpConnectionManagerConfigTest, DisabledStreamIdleTimeout) { + const std::string yaml_string = R"EOF( + stat_prefix: ingress_http + stream_idle_timeout: 0s + route_config: + name: local_route + http_filters: + - name: envoy.router + )EOF"; + + HttpConnectionManagerConfig config(parseHttpConnectionManagerFromV2Yaml(yaml_string), context_, + date_provider_, route_config_provider_manager_); + EXPECT_EQ(0, config.streamIdleTimeout().count()); } TEST_F(HttpConnectionManagerConfigTest, SingleDateProvider) { diff --git a/test/extensions/filters/network/redis_proxy/codec_impl_test.cc b/test/extensions/filters/network/redis_proxy/codec_impl_test.cc index 819fb7665256b..9ecf214707635 100644 --- a/test/extensions/filters/network/redis_proxy/codec_impl_test.cc +++ b/test/extensions/filters/network/redis_proxy/codec_impl_test.cc @@ -35,7 +35,7 @@ TEST_F(RedisEncoderDecoderImplTest, Null) { RespValue value; EXPECT_EQ("null", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ("$-1\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ("$-1\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -47,7 +47,7 @@ TEST_F(RedisEncoderDecoderImplTest, Error) { value.asString() = "error"; EXPECT_EQ("\"error\"", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ("-error\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ("-error\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -59,7 +59,7 @@ TEST_F(RedisEncoderDecoderImplTest, SimpleString) { value.asString() = "simple string"; EXPECT_EQ("\"simple string\"", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ("+simple string\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ("+simple string\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -71,7 +71,7 @@ TEST_F(RedisEncoderDecoderImplTest, Integer) { value.asInteger() = std::numeric_limits::max(); EXPECT_EQ("9223372036854775807", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ(":9223372036854775807\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ(":9223372036854775807\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -82,7 +82,7 @@ TEST_F(RedisEncoderDecoderImplTest, NegativeIntegerSmall) { value.type(RespType::Integer); value.asInteger() = -1; encoder_.encode(value, buffer_); - EXPECT_EQ(":-1\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ(":-1\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -93,7 +93,7 @@ TEST_F(RedisEncoderDecoderImplTest, NegativeIntegerLarge) { value.type(RespType::Integer); value.asInteger() = std::numeric_limits::min(); encoder_.encode(value, buffer_); - EXPECT_EQ(":-9223372036854775808\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ(":-9223372036854775808\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -104,7 +104,7 @@ TEST_F(RedisEncoderDecoderImplTest, EmptyArray) { value.type(RespType::Array); EXPECT_EQ("[]", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ("*0\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ("*0\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -122,7 +122,7 @@ TEST_F(RedisEncoderDecoderImplTest, Array) { value.asArray().swap(values); EXPECT_EQ("[\"hello\", -5]", value.toString()); encoder_.encode(value, buffer_); - EXPECT_EQ("*2\r\n$5\r\nhello\r\n:-5\r\n", TestUtility::bufferToString(buffer_)); + EXPECT_EQ("*2\r\n$5\r\nhello\r\n:-5\r\n", buffer_.toString()); decoder_.decode(buffer_); EXPECT_EQ(value, *decoded_values_[0]); EXPECT_EQ(0UL, buffer_.length()); @@ -145,11 +145,10 @@ TEST_F(RedisEncoderDecoderImplTest, NestedArray) { value.type(RespType::Array); value.asArray().swap(values); encoder_.encode(value, buffer_); - EXPECT_EQ("*2\r\n*3\r\n$5\r\nhello\r\n:0\r\n$-1\r\n$5\r\nworld\r\n", - TestUtility::bufferToString(buffer_)); + EXPECT_EQ("*2\r\n*3\r\n$5\r\nhello\r\n:0\r\n$-1\r\n$5\r\nworld\r\n", buffer_.toString()); // To test partial decode we will feed the buffer in 1 char at a time. - for (char c : TestUtility::bufferToString(buffer_)) { + for (char c : buffer_.toString()) { Buffer::OwnedImpl temp_buffer(&c, 1); decoder_.decode(temp_buffer); EXPECT_EQ(0UL, temp_buffer.length()); diff --git a/test/extensions/filters/network/redis_proxy/mocks.cc b/test/extensions/filters/network/redis_proxy/mocks.cc index d64ce6c698331..5cc170a32400f 100644 --- a/test/extensions/filters/network/redis_proxy/mocks.cc +++ b/test/extensions/filters/network/redis_proxy/mocks.cc @@ -50,7 +50,7 @@ bool operator==(const RespValue& lhs, const RespValue& rhs) { } } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } MockEncoder::MockEncoder() { diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index da7a3d6c933be..1eea2452232bb 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -2,30 +2,37 @@ licenses(["notice"]) # Apache 2 load( "//bazel:envoy_build_system.bzl", - "envoy_cc_mock", - "envoy_cc_test_library", "envoy_package", ) load( "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_mock", "envoy_extension_cc_test", + "envoy_extension_cc_test_library", ) envoy_package() -envoy_cc_mock( +envoy_extension_cc_mock( name = "mocks", srcs = ["mocks.cc"], hdrs = ["mocks.h"], + extension_name = "envoy.filters.network.thrift_proxy", deps = [ + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", "//test/test_common:printers_lib", ], ) -envoy_cc_test_library( +envoy_extension_cc_test_library( name = "utility_lib", hdrs = ["utility.h"], + extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/common/buffer:buffer_lib", "//source/common/common:byte_order_lib", @@ -34,8 +41,8 @@ envoy_cc_test_library( ) envoy_extension_cc_test( - name = "binary_protocol_test", - srcs = ["binary_protocol_test.cc"], + name = "binary_protocol_impl_test", + srcs = ["binary_protocol_impl_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", @@ -59,8 +66,8 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "compact_protocol_test", - srcs = ["compact_protocol_test.cc"], + name = "compact_protocol_impl_test", + srcs = ["compact_protocol_impl_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", @@ -77,7 +84,27 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "conn_manager_test", + srcs = ["conn_manager_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:printers_lib", ], ) @@ -87,6 +114,7 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", + ":utility_lib", "//source/extensions/filters/network/thrift_proxy:decoder_lib", "//test/test_common:printers_lib", "//test/test_common:utility_lib", @@ -94,21 +122,21 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "filter_test", - srcs = ["filter_test.cc"], + name = "framed_transport_impl_test", + srcs = ["framed_transport_impl_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ + ":mocks", ":utility_lib", - "//source/common/stats:stats_lib", - "//source/extensions/filters/network/thrift_proxy:filter_lib", - "//test/mocks/network:network_mocks", + "//source/extensions/filters/network/thrift_proxy:transport_lib", "//test/test_common:printers_lib", + "//test/test_common:utility_lib", ], ) envoy_extension_cc_test( - name = "protocol_test", - srcs = ["protocol_test.cc"], + name = "protocol_impl_test", + srcs = ["protocol_impl_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", @@ -120,8 +148,39 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "transport_test", - srcs = ["transport_test.cc"], + name = "router_test", + srcs = ["router_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", + "//test/mocks/network:network_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + +envoy_extension_cc_test( + name = "transport_impl_test", + srcs = ["transport_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//test/test_common:printers_lib", + "//test/test_common:utility_lib", + ], +) + +envoy_extension_cc_test( + name = "unframed_transport_impl_test", + srcs = ["unframed_transport_impl_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", @@ -131,3 +190,22 @@ envoy_extension_cc_test( "//test/test_common:utility_lib", ], ) + +envoy_extension_cc_test( + name = "integration_test", + srcs = ["integration_test.cc"], + data = [ + "//test/extensions/filters/network/thrift_proxy/driver:generate_fixture", + ], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + "//source/extensions/filters/network/tcp_proxy:config", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//test/integration:integration_lib", + "//test/test_common:environment_lib", + "//test/test_common:network_utility_lib", + "//test/test_common:printers_lib", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc similarity index 68% rename from test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc rename to test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc index c4f2c6f443390..3db9d588d5338 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc @@ -2,31 +2,26 @@ #include "common/buffer/buffer_impl.h" -#include "extensions/filters/network/thrift_proxy/binary_protocol.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(BinaryProtocolTest, Name) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary"); } TEST(BinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -97,7 +92,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addInt32(buffer, 0); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -139,7 +133,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addString(buffer, "the_name"); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -150,34 +143,29 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { TEST(BinaryProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; - EXPECT_CALL(cb, messageComplete()); EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(BinaryProtocolTest, ReadStructBegin) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); } TEST(BinaryProtocolTest, ReadStructEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); - EXPECT_CALL(cb, structEnd()); + BinaryProtocolImpl proto; + EXPECT_TRUE(proto.readStructEnd(buffer)); } TEST(BinaryProtocolTest, ReadFieldBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -201,7 +189,6 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::Stop); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -224,7 +211,7 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { EXPECT_EQ(field_id, 1); } - // Non-terminal field + // Non-stop field { Buffer::OwnedImpl buffer; std::string name = "-"; @@ -234,25 +221,40 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::I32); addInt16(buffer, 99); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 99)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); EXPECT_EQ(field_id, 99); EXPECT_EQ(buffer.length(), 0); } + + // field id < 0 + { + Buffer::OwnedImpl buffer; + std::string name = "-"; + FieldType field_type = FieldType::String; + int16_t field_id = 1; + + addInt8(buffer, FieldType::I32); + addInt16(buffer, -1); + + EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), + EnvoyException, "invalid binary protocol field id -1"); + EXPECT_EQ(name, "-"); + EXPECT_EQ(field_type, FieldType::String); + EXPECT_EQ(field_id, 1); + EXPECT_EQ(buffer.length(), 3); + } } TEST(BinaryProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(BinaryProtocolTest, ReadMapBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -310,14 +312,12 @@ TEST(BinaryProtocolTest, ReadMapBegin) { TEST(BinaryProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(BinaryProtocolTest, ReadListBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -367,14 +367,12 @@ TEST(BinaryProtocolTest, ReadListBegin) { TEST(BinaryProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(BinaryProtocolTest, ReadSetBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -392,14 +390,12 @@ TEST(BinaryProtocolTest, ReadSetBegin) { TEST(BinaryProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(BinaryProtocolTest, ReadIntegerTypes) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Bool { @@ -517,8 +513,7 @@ TEST(BinaryProtocolTest, ReadIntegerTypes) { } TEST(BinaryProtocolTest, ReadDouble) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -546,8 +541,7 @@ TEST(BinaryProtocolTest, ReadDouble) { } TEST(BinaryProtocolTest, ReadString) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data to read length { @@ -614,8 +608,7 @@ TEST(BinaryProtocolTest, ReadString) { TEST(BinaryProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; @@ -627,15 +620,279 @@ TEST(BinaryProtocolTest, ReadBinary) { EXPECT_EQ(buffer.length(), 0); } +TEST(BinaryProtocolTest, WriteMessageBegin) { + BinaryProtocolImpl proto; + + // Named call + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + EXPECT_EQ(std::string("\x80\x1\0\x1\0\0\0\x7message\0\0\0\x1", 19), buffer.toString()); + } + + // Unnamed oneway + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + EXPECT_EQ(std::string("\x80\x1\0\x4\0\0\0\0\0\0\0\x2", 12), buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteMessageEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeMessageEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteStructBegin) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeStructBegin(buffer, "unused"); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteStructEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeStructEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteFieldBegin) { + BinaryProtocolImpl proto; + + // Stop field + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Stop, 1); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + + // Normal field + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::I32, 1); + EXPECT_EQ(std::string("\x8\0\x1", 3), buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteFieldEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeFieldEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteMapBegin) { + BinaryProtocolImpl proto; + + // Non-empty map + { + Buffer::OwnedImpl buffer; + proto.writeMapBegin(buffer, FieldType::I32, FieldType::String, 3); + EXPECT_EQ(std::string("\x8\xb\0\0\0\x3", 6), buffer.toString()); + } + + // Empty map + { + Buffer::OwnedImpl buffer; + proto.writeMapBegin(buffer, FieldType::I32, FieldType::String, 0); + EXPECT_EQ(std::string("\x8\xb\0\0\0\0", 6), buffer.toString()); + } + + // Oversized map + { + Buffer::OwnedImpl buffer; + EXPECT_THROW_WITH_MESSAGE( + proto.writeMapBegin(buffer, FieldType::I32, FieldType::String, 3000000000), EnvoyException, + "illegal binary protocol map size 3000000000"); + } +} + +TEST(BinaryProtocolTest, WriteMapEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeMapEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteListBegin) { + BinaryProtocolImpl proto; + + // Non-empty list + { + Buffer::OwnedImpl buffer; + proto.writeListBegin(buffer, FieldType::String, 3); + EXPECT_EQ(std::string("\xb\0\0\0\x3", 5), buffer.toString()); + } + + // Empty list + { + Buffer::OwnedImpl buffer; + proto.writeListBegin(buffer, FieldType::String, 0); + EXPECT_EQ(std::string("\xb\0\0\0\0", 5), buffer.toString()); + } + + // Oversized list + { + Buffer::OwnedImpl buffer; + EXPECT_THROW_WITH_MESSAGE(proto.writeListBegin(buffer, FieldType::String, 3000000000), + EnvoyException, "illegal binary protocol list/set size 3000000000"); + } +} + +TEST(BinaryProtocolTest, WriteListEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeListEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteSetBegin) { + BinaryProtocolImpl proto; + + // Only test the happy path, as this shares an implementation with writeListBegin + // Non-empty list + Buffer::OwnedImpl buffer; + proto.writeSetBegin(buffer, FieldType::String, 3); + EXPECT_EQ(std::string("\xb\0\0\0\x3", 5), buffer.toString()); +} + +TEST(BinaryProtocolTest, WriteSetEnd) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeSetEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(BinaryProtocolTest, WriteBool) { + BinaryProtocolImpl proto; + + // True + { + Buffer::OwnedImpl buffer; + proto.writeBool(buffer, true); + EXPECT_EQ("\x1", buffer.toString()); + } + + // False + { + Buffer::OwnedImpl buffer; + proto.writeBool(buffer, false); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteByte) { + BinaryProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeByte(buffer, -1); + EXPECT_EQ("\xFF", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeByte(buffer, 127); + EXPECT_EQ("\x7F", buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteInt16) { + BinaryProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, -1); + EXPECT_EQ("\xFF\xFF", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, 0x0102); + EXPECT_EQ("\x1\x2", buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteInt32) { + BinaryProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, -1); + EXPECT_EQ("\xFF\xFF\xFF\xFF", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, 0x01020304); + EXPECT_EQ("\x1\x2\x3\x4", buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteInt64) { + BinaryProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, -1); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, 0x0102030405060708); + EXPECT_EQ("\x1\x2\x3\x4\x5\x6\x7\x8", buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteDouble) { + BinaryProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeDouble(buffer, 3.0); + EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); +} + +TEST(BinaryProtocolTest, WriteString) { + BinaryProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeString(buffer, "abc"); + EXPECT_EQ(std::string("\0\0\0\x3" + "abc", + 7), + buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeString(buffer, ""); + EXPECT_EQ(std::string("\0\0\0\0", 4), buffer.toString()); + } +} + +TEST(BinaryProtocolTest, WriteBinary) { + BinaryProtocolImpl proto; + + // Happy path only, since this is just a synonym for writeString + Buffer::OwnedImpl buffer; + proto.writeBinary(buffer, "abc"); + EXPECT_EQ(std::string("\0\0\0\x3" + "abc", + 7), + buffer.toString()); +} + TEST(LaxBinaryProtocolTest, Name) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary/non-strict"); } TEST(LaxBinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; // Insufficient data { @@ -685,7 +942,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -723,7 +979,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -732,6 +987,24 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { } } +TEST(LaxBinaryProtocolTest, WriteMessageBegin) { + LaxBinaryProtocolImpl proto; + + // Named call + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + EXPECT_EQ(std::string("\0\0\0\x7message\x1\0\0\0\x1", 16), buffer.toString()); + } + + // Unnamed oneway + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + EXPECT_EQ(std::string("\0\0\0\0\x4\0\0\0\x2", 9), buffer.toString()); + } +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc index 49a12ba56dae3..26030cd7bd9a7 100644 --- a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc +++ b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc @@ -1,3 +1,5 @@ +#include + #include "envoy/common/exception.h" #include "common/buffer/buffer_impl.h" @@ -15,50 +17,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(BufferWrapperTest, ImplementedFunctions) { - Buffer::OwnedImpl buffer; - addString(buffer, "abcdefghij"); - - BufferWrapper wrapper(buffer); - { - char s[4] = {0}; - wrapper.copyOut(0, 3, s); - EXPECT_EQ("abc", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - char s[6] = {0}; - wrapper.copyOut(5, 5, s); - EXPECT_EQ("fghij", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(5)), 5); - EXPECT_EQ("abcde", s); - EXPECT_EQ(0, wrapper.position()); - } - - wrapper.drain(2); - - { - char s[4] = {0}; - wrapper.copyOut(4, 3, s); - EXPECT_EQ("ghi", std::string(s)); - EXPECT_EQ(8, wrapper.length()); - EXPECT_EQ(2, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(8)), 8); - EXPECT_EQ("cdefghij", s); - EXPECT_EQ(2, wrapper.position()); - } -} - TEST(BufferHelperTest, PeekI8) { { Buffer::OwnedImpl buffer; @@ -283,7 +241,7 @@ TEST(BufferHelperTest, DrainDouble) { addSeq(buffer, {0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}); EXPECT_EQ(BufferHelper::drainDouble(buffer), 3.0); - EXPECT_EQ(BufferHelper::drainDouble(buffer), -DBL_MAX); + EXPECT_EQ(BufferHelper::drainDouble(buffer), std::numeric_limits::lowest()); EXPECT_EQ(buffer.length(), 0); } @@ -354,16 +312,16 @@ TEST(BufferHelperTest, PeekVarInt32BufferUnderflow) { TEST(BufferHelperTest, PeekZigZagI32) { Buffer::OwnedImpl buffer; - addInt8(buffer, 0); // zigzag(0) = 0 - addInt8(buffer, 1); // zigzag(1) = -1 - addInt8(buffer, 2); // zigzag(2) = 1 - addSeq(buffer, {0xFE, 0x01}); // zigzag(0xFE) = 127 - addSeq(buffer, {0xFF, 0x01}); // zigzag(0xFF) = -128 - addSeq(buffer, {0xFF, 0xFF, 0x03}); // zigzag(0xFFFF) = -32768 - addSeq(buffer, {0xFF, 0xFF, 0xFF, 0x07}); // zigzag(0xFFFFFF) = -8388608 - addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x07}); // zigzag(0x7FFFFFFE) = 0x3FFFFFFF - addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x0F}); // zigzag(0xFFFFFFFE) = 0x7FFFFFFF - addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x0F}); // zigzag(0xFFFFFFFF) = 0x80000000 + addInt8(buffer, 0); // unzigzag(0) = 0 + addInt8(buffer, 1); // unzigzag(1) = -1 + addInt8(buffer, 2); // unzigzag(2) = 1 + addSeq(buffer, {0xFE, 0x01}); // unzigzag(0xFE) = 127 + addSeq(buffer, {0xFF, 0x01}); // unzigzag(0xFF) = -128 + addSeq(buffer, {0xFF, 0xFF, 0x03}); // unzigzag(0xFFFF) = -32768 + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0x07}); // unzigzag(0xFFFFFF) = -8388608 + addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x07}); // unzigzag(0x7FFFFFFE) = 0x3FFFFFFF + addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x0F}); // unzigzag(0xFFFFFFFE) = 0x7FFFFFFF + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x0F}); // unzigzag(0xFFFFFFFF) = 0x80000000 int size = 0; EXPECT_EQ(BufferHelper::peekZigZagI32(buffer, 0, size), 0); @@ -414,19 +372,19 @@ TEST(BufferHelperTest, PeekZigZagI32BufferUnderflow) { TEST(BufferHelperTest, PeekZigZagI64) { Buffer::OwnedImpl buffer; - addInt8(buffer, 0); // zigzag(0) = 0 - addInt8(buffer, 1); // zigzag(1) = -1 - addInt8(buffer, 2); // zigzag(2) = 1 - addSeq(buffer, {0xFF, 0xFF, 0x03}); // zigzag(0xFFFF) = -32768 - addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x0F}); // zigzag(0xFFFF FFFE) = 0x7FFF FFFF + addInt8(buffer, 0); // unzigzag(0) = 0 + addInt8(buffer, 1); // unzigzag(1) = -1 + addInt8(buffer, 2); // unzigzag(2) = 1 + addSeq(buffer, {0xFF, 0xFF, 0x03}); // unzigzag(0xFFFF) = -32768 + addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x0F}); // unzigzag(0xFFFF FFFE) = 0x7FFF FFFF - // zigzag(0xFFFF FFFF FFFE) = 0x7FFF FFFF FFFF + // unzigzag(0xFFFF FFFF FFFE) = 0x7FFF FFFF FFFF addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x3F}); - // zigzag(0x7FFF FFFF FFFF FFFE) = 0x3FFF FFFF FFFF FFFF + // unzigzag(0x7FFF FFFF FFFF FFFE) = 0x3FFF FFFF FFFF FFFF addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F}); - // zigzag(0xFFFF FFFF FFFF FFFF) = 0x8000 0000 0000 0000 (-2^63) + // unzigzag(0xFFFF FFFF FFFF FFFF) = 0x8000 0000 0000 0000 (-2^63) addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}); int size = 0; @@ -470,6 +428,377 @@ TEST(BufferHelperTest, PeekZigZagI64BufferUnderflow) { "invalid compact protocol zig-zag i64"); } +TEST(BufferHelperTest, WriteI8) { + Buffer::OwnedImpl buffer; + BufferHelper::writeI8(buffer, -128); + BufferHelper::writeI8(buffer, -1); + BufferHelper::writeI8(buffer, 0); + BufferHelper::writeI8(buffer, 1); + BufferHelper::writeI8(buffer, 127); + + EXPECT_EQ(std::string("\x80\xFF\0\x1\x7F", 5), buffer.toString()); +} + +TEST(BufferHelperTest, WriteI16) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI16(buffer, std::numeric_limits::min()); + EXPECT_EQ(std::string("\x80\0", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI16(buffer, 0); + EXPECT_EQ(std::string("\0\0", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI16(buffer, 1); + EXPECT_EQ(std::string("\0\x1", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI16(buffer, std::numeric_limits::max()); + EXPECT_EQ("\x7F\xFF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteU16) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU16(buffer, 0); + EXPECT_EQ(std::string("\0\0", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU16(buffer, 1); + EXPECT_EQ(std::string("\0\x1", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU16(buffer, static_cast(std::numeric_limits::max()) + 1); + EXPECT_EQ(std::string("\x80\0", 2), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU16(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFF\xFF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteI32) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI32(buffer, std::numeric_limits::min()); + EXPECT_EQ(std::string("\x80\0\0\0", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI32(buffer, 0); + EXPECT_EQ(std::string("\0\0\0\0", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI32(buffer, 1); + EXPECT_EQ(std::string("\0\0\0\x1", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI32(buffer, std::numeric_limits::max()); + EXPECT_EQ("\x7F\xFF\xFF\xFF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteU32) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU32(buffer, 0); + EXPECT_EQ(std::string("\0\0\0\0", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU32(buffer, 1); + EXPECT_EQ(std::string("\0\0\0\x1", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU32(buffer, static_cast(std::numeric_limits::max()) + 1); + EXPECT_EQ(std::string("\x80\0\0\0", 4), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeU32(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFF\xFF\xFF\xFF", buffer.toString()); + } +} +TEST(BufferHelperTest, WriteI64) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI64(buffer, std::numeric_limits::min()); + EXPECT_EQ(std::string("\x80\0\0\0\0\0\0\0\0", 8), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI64(buffer, 1); + EXPECT_EQ(std::string("\0\0\0\0\0\0\0\x1", 8), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI64(buffer, 0); + EXPECT_EQ(std::string("\0\0\0\0\0\0\0\0", 8), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeI64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteDouble) { + // See the DrainDouble test. + { + Buffer::OwnedImpl buffer; + BufferHelper::writeDouble(buffer, 3.0); + EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + BufferHelper::writeDouble(buffer, std::numeric_limits::lowest()); + EXPECT_EQ("\xFF\xEF\xFF\xFF\xFF\xFF\xFF\xFF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteVarIntI32) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, 0); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, 1); + EXPECT_EQ("\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, 128); + EXPECT_EQ("\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, (1 << 14) + 1); + EXPECT_EQ("\x81\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, (1 << 28) + 1); + EXPECT_EQ("\x81\x80\x80\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\x7", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, -1); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xF", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI32(buffer, std::numeric_limits::min()); + EXPECT_EQ("\x80\x80\x80\x80\x8", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteVarIntI64) { + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, 0); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, 1); + EXPECT_EQ("\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, 128); + EXPECT_EQ("\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, (1 << 14) + 1); + EXPECT_EQ("\x81\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, (1 << 28) + 1); + EXPECT_EQ("\x81\x80\x80\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, (static_cast(1) << 56) + 1); + EXPECT_EQ("\x81\x80\x80\x80\x80\x80\x80\x80\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\x7", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, -1); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, std::numeric_limits::min()); + EXPECT_EQ("\x80\x80\x80\x80\xF8\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } + { + Buffer::OwnedImpl buffer; + BufferHelper::writeVarIntI64(buffer, std::numeric_limits::min()); + EXPECT_EQ("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x1", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteZigZagI32) { + // zigzag(0) = 0 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, 0); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + + // zigzag(-1) = 1 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, -1); + EXPECT_EQ("\x1", buffer.toString()); + } + + // zigzag(1) = 2 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, 1); + EXPECT_EQ("\x2", buffer.toString()); + } + + // zigzag(127) = 0xFE + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, 127); + EXPECT_EQ("\xFE\x1", buffer.toString()); + } + + // zigzag(128) = 0x100 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, 128); + EXPECT_EQ("\x80\x2", buffer.toString()); + } + + // zigzag(-128) = 0xFF + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, -128); + EXPECT_EQ("\xFF\x1", buffer.toString()); + } + + // zigzag(0x7FFFFFFF) = 0xFFFFFFFE + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\xFF\xFF\xF", buffer.toString()); + } + + // zigzag(0x80000000) = 0xFFFFFFFF + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI32(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xF", buffer.toString()); + } +} + +TEST(BufferHelperTest, WriteZigZagI64) { + // zigzag(0) = 0 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, 0); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + + // zigzag(-1) = 1 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, -1); + EXPECT_EQ("\x1", buffer.toString()); + } + + // zigzag(1) = 2 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, 1); + EXPECT_EQ("\x2", buffer.toString()); + } + + // zigzag(127) = 0xFE + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, 127); + EXPECT_EQ("\xFE\x1", buffer.toString()); + } + + // zigzag(128) = 0x100 + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, 128); + EXPECT_EQ("\x80\x2", buffer.toString()); + } + + // zigzag(-128) = 0xFF + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, -128); + EXPECT_EQ("\xFF\x1", buffer.toString()); + } + + // zigzag(0x7FFFFFFF) = 0xFFFFFFFE + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\xFF\xFF\xF", buffer.toString()); + } + + // zigzag(0x80000000) = 0xFFFFFFFF + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xF", buffer.toString()); + } + + // zigzag(0x7FFFFFFF FFFFFFFF) = 0xFFFFFFFFFFFFFFFE + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } + + // zigzag(0x8000000000000000) = 0xFFFFFFFFFFFFFFFF + { + Buffer::OwnedImpl buffer; + BufferHelper::writeZigZagI64(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc similarity index 66% rename from test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc rename to test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc index 7bc0fe7ab8a2c..79187821def52 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc @@ -2,17 +2,14 @@ #include "common/buffer/buffer_impl.h" -#include "extensions/filters/network/thrift_proxy/compact_protocol.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::NiceMock; -using testing::StrictMock; using testing::TestWithParam; using testing::Values; @@ -22,14 +19,12 @@ namespace NetworkFilters { namespace ThriftProxy { TEST(CompactProtocolTest, Name) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_EQ(proto.name(), "compact"); } TEST(CompactProtocolTest, ReadMessageBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -170,7 +165,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 32); addInt8(buffer, 0); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 32)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -228,7 +222,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 8); addString(buffer, "the_name"); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 0x0102)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -239,22 +232,19 @@ TEST(CompactProtocolTest, ReadMessageBegin) { TEST(CompactProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); - EXPECT_CALL(cb, messageComplete()); + CompactProtocolImpl proto; + EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(CompactProtocolTest, ReadStruct) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); - EXPECT_CALL(cb, structEnd()); EXPECT_TRUE(proto.readStructEnd(buffer)); EXPECT_THROW_WITH_MESSAGE(proto.readStructEnd(buffer), EnvoyException, @@ -262,8 +252,7 @@ TEST(CompactProtocolTest, ReadStruct) { } TEST(CompactProtocolTest, ReadFieldBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -287,7 +276,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF0); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -336,7 +324,7 @@ TEST(CompactProtocolTest, ReadFieldBegin) { EXPECT_EQ(buffer.length(), 6); } - // Long-form field header, field id out of range + // Long-form field header, field id > 32767 { Buffer::OwnedImpl buffer; std::string name = "-"; @@ -344,16 +332,34 @@ TEST(CompactProtocolTest, ReadFieldBegin) { int16_t field_id = 1; addInt8(buffer, 0x05); - addSeq(buffer, {0xFE, 0xFF, 0x7F}); // zigzag(0x1FFFFE) = 0xFFFFF + addSeq(buffer, {0x80, 0x80, 0x04}); // zigzag(0x10000) = 0x8000 EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), - EnvoyException, "invalid compact protocol field id 1048575"); + EnvoyException, "invalid compact protocol field id 32768"); EXPECT_EQ(name, "-"); EXPECT_EQ(field_type, FieldType::String); EXPECT_EQ(field_id, 1); EXPECT_EQ(buffer.length(), 4); } + // Long-form field header, field id < 0 + { + Buffer::OwnedImpl buffer; + std::string name = "-"; + FieldType field_type = FieldType::String; + int16_t field_id = 1; + + addInt8(buffer, 0x05); + addSeq(buffer, {0x01}); // zigzag(1) = -1 + + EXPECT_THROW_WITH_MESSAGE(proto.readFieldBegin(buffer, name, field_type, field_id), + EnvoyException, "invalid compact protocol field id -1"); + EXPECT_EQ(name, "-"); + EXPECT_EQ(field_type, FieldType::String); + EXPECT_EQ(field_id, 1); + EXPECT_EQ(buffer.length(), 2); + } + // Unknown compact protocol field type { Buffer::OwnedImpl buffer; @@ -382,7 +388,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0x05); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -399,7 +404,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF5); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 17)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -410,14 +414,12 @@ TEST(CompactProtocolTest, ReadFieldBegin) { TEST(CompactProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(CompactProtocolTest, ReadMapBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -557,14 +559,12 @@ TEST(CompactProtocolTest, ReadMapBegin) { TEST(CompactProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(CompactProtocolTest, ReadListBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -672,14 +672,12 @@ TEST(CompactProtocolTest, ReadListBegin) { TEST(CompactProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(CompactProtocolTest, ReadSetBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -696,14 +694,12 @@ TEST(CompactProtocolTest, ReadSetBegin) { TEST(CompactProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(CompactProtocolTest, ReadBool) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Bool field values are encoded in the field type { @@ -716,7 +712,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x01); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -733,7 +728,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x02); addInt8(buffer, 0x06); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 3)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -769,8 +763,7 @@ TEST(CompactProtocolTest, ReadBool) { } TEST(CompactProtocolTest, ReadIntegerTypes) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Byte { @@ -850,11 +843,11 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0x0F}); // zigzag(0xFFFFFFFE) = 0x7FFFFFFF EXPECT_TRUE(proto.readInt32(buffer, value)); - EXPECT_EQ(value, INT32_MAX); + EXPECT_EQ(value, std::numeric_limits::max()); addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x0F}); // zigzag(0xFFFFFFFF) = 0x80000000 EXPECT_TRUE(proto.readInt32(buffer, value)); - EXPECT_EQ(value, INT32_MIN); + EXPECT_EQ(value, std::numeric_limits::min()); // More than 32 bits value = 1; @@ -883,12 +876,12 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { // zigzag(0xFFFFFFFFFFFFFFFE) = 0x7FFFFFFFFFFFFFFF addSeq(buffer, {0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}); EXPECT_TRUE(proto.readInt64(buffer, value)); - EXPECT_EQ(value, INT64_MAX); + EXPECT_EQ(value, std::numeric_limits::max()); // zigzag(0xFFFFFFFFFFFFFFFF) = 0x8000000000000000 addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}); EXPECT_TRUE(proto.readInt64(buffer, value)); - EXPECT_EQ(value, INT64_MIN); + EXPECT_EQ(value, std::numeric_limits::min()); // More than 64 bits value = 1; @@ -901,8 +894,7 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { } TEST(CompactProtocolTest, ReadDouble) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -932,8 +924,7 @@ TEST(CompactProtocolTest, ReadDouble) { } TEST(CompactProtocolTest, ReadString) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -962,7 +953,7 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x8); // zigzag(8) = 4 + addInt8(buffer, 0x4); EXPECT_FALSE(proto.readString(buffer, value)); EXPECT_EQ(value, "-"); @@ -974,12 +965,12 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x01); // zigzag(1) = -1 + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 EXPECT_THROW_WITH_MESSAGE(proto.readString(buffer, value), EnvoyException, "negative compact protocol string/binary length -1"); EXPECT_EQ(value, "-"); - EXPECT_EQ(buffer.length(), 1); + EXPECT_EQ(buffer.length(), 5); } // empty string @@ -999,7 +990,7 @@ TEST(CompactProtocolTest, ReadString) { Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x0C); // zigzag(0x0C) = 0x06 + addInt8(buffer, 0x06); addString(buffer, "string"); EXPECT_TRUE(proto.readString(buffer, value)); @@ -1010,12 +1001,11 @@ TEST(CompactProtocolTest, ReadString) { TEST(CompactProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; - addInt8(buffer, 0x0C); // zigzag(0x0C) = 0x06 + addInt8(buffer, 0x06); addString(buffer, "string"); EXPECT_TRUE(proto.readBinary(buffer, value)); @@ -1028,24 +1018,498 @@ class CompactProtocolFieldTypeTest : public TestWithParam {}; TEST_P(CompactProtocolFieldTypeTest, ConvertsToFieldType) { uint8_t compact_field_type = GetParam(); - NiceMock cb; - CompactProtocolImpl proto(cb); - Buffer::OwnedImpl buffer; + CompactProtocolImpl proto; std::string name = "-"; int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; FieldType field_type = static_cast(invalid_field_type); int16_t field_id = 0; - addInt8(buffer, compact_field_type); - addInt8(buffer, 0x02); // zigzag(2) = 1 + { + Buffer::OwnedImpl buffer; + addInt8(buffer, compact_field_type); + addInt8(buffer, 0x02); // zigzag(2) = 1 - EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); - EXPECT_LE(field_type, FieldType::LastFieldType); + EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); + EXPECT_LE(field_type, FieldType::LastFieldType); + } + + { + // Long form field header + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "-", field_type, 100); + if (field_type == FieldType::Bool) { + proto.writeBool(buffer, compact_field_type == 1); + } + + uint8_t* data = static_cast(buffer.linearize(1)); + EXPECT_NE(nullptr, data); + EXPECT_EQ(compact_field_type, *data); + } } INSTANTIATE_TEST_CASE_P(CompactFieldTypes, CompactProtocolFieldTypeTest, Values(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); +TEST(CompactProtocolTest, WriteMessageBegin) { + CompactProtocolImpl proto; + + // Named call + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + EXPECT_EQ(std::string("\x82\x21\x1\x7message", 11), buffer.toString()); + } + + // Unnamed oneway + { + Buffer::OwnedImpl buffer; + proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + EXPECT_EQ(std::string("\x82\x81\x2\0", 4), buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteMessageEnd) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeMessageEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(CompactProtocolTest, WriteStruct) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + + proto.writeStructBegin(buffer, "unused"); + proto.writeStructEnd(buffer); + EXPECT_EQ(0, buffer.length()); + + // struct begin/end always appear in nested pairs + EXPECT_THROW_WITH_MESSAGE(proto.writeStructEnd(buffer), EnvoyException, + "invalid write of compact protocol struct end") +} + +TEST(CompactProtocolTest, WriteFieldBegin) { + // Stop field + { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Stop, 1); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + + { + CompactProtocolImpl proto; + + // Short form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::I32, 1); + EXPECT_EQ("\x15", buffer.toString()); + } + + // Long form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Struct, 17); + EXPECT_EQ(std::string("\xC\x22", 2), buffer.toString()); + } + + // Short form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Byte, 32); + EXPECT_EQ("\xF3", buffer.toString()); + } + + // Short form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::String, 33); + EXPECT_EQ("\x18", buffer.toString()); + } + } + + { + CompactProtocolImpl proto; + + // Long form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::I32, 16); + EXPECT_EQ(std::string("\x5\x20", 2), buffer.toString()); + } + + // Short form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Struct, 17); + EXPECT_EQ("\x1C", buffer.toString()); + } + + // Long form + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Byte, 33); + EXPECT_EQ(std::string("\x3\x42", 2), buffer.toString()); + } + + // Long form (3 bytes) + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::String, 64); + EXPECT_EQ(std::string("\x8\x80\x1", 3), buffer.toString()); + } + } + + // Unknown field type + { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + + int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; + FieldType field_type = static_cast(invalid_field_type); + + EXPECT_THROW_WITH_MESSAGE(proto.writeFieldBegin(buffer, "unused", field_type, 1), + EnvoyException, + fmt::format("unknown protocol field type {}", invalid_field_type)); + } +} + +TEST(CompactProtocolTest, WriteFieldEnd) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeFieldEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(CompactProtocolTest, WriteBoolField) { + // Boolean struct fields are encoded with custom types to save a byte + + // Short form field + { + CompactProtocolImpl proto; + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 8); + EXPECT_EQ(0, buffer.length()); + proto.writeBool(buffer, true); + EXPECT_EQ("\x81", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 12); + EXPECT_EQ(0, buffer.length()); + proto.writeBool(buffer, false); + EXPECT_EQ("\x42", buffer.toString()); + } + } + + // Long form field + { + CompactProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 16); + EXPECT_EQ(0, buffer.length()); + proto.writeBool(buffer, true); + EXPECT_EQ(std::string("\x1\x20", 2), buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 32); + EXPECT_EQ(0, buffer.length()); + proto.writeBool(buffer, false); + EXPECT_EQ(std::string("\x2\x40", 2), buffer.toString()); + } + } +} + +TEST(CompactProtocolTest, WriteMapBegin) { + CompactProtocolImpl proto; + + // Empty map + { + Buffer::OwnedImpl buffer; + proto.writeMapBegin(buffer, FieldType::I32, FieldType::Bool, 0); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } + + // Non-empty map + { + Buffer::OwnedImpl buffer; + proto.writeMapBegin(buffer, FieldType::I32, FieldType::Bool, 3); + EXPECT_EQ("\3\x51", buffer.toString()); + } + + // Oversized map + { + Buffer::OwnedImpl buffer; + EXPECT_THROW_WITH_MESSAGE( + proto.writeMapBegin(buffer, FieldType::I32, FieldType::Bool, 3000000000), EnvoyException, + "illegal compact protocol map size 3000000000"); + } +} + +TEST(CompactProtocolTest, WriteMapEnd) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeMapEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(CompactProtocolTest, WriteListBegin) { + CompactProtocolImpl proto; + + // Empty list + { + Buffer::OwnedImpl buffer; + proto.writeListBegin(buffer, FieldType::I32, 0); + EXPECT_EQ("\x5", buffer.toString()); + } + + // List (short form) + { + Buffer::OwnedImpl buffer; + proto.writeListBegin(buffer, FieldType::I32, 14); + EXPECT_EQ("\xE5", buffer.toString()); + } + + // List (long form) + { + Buffer::OwnedImpl buffer; + proto.writeListBegin(buffer, FieldType::Bool, 15); + EXPECT_EQ("\xF1\xF", buffer.toString()); + } + + // Oversized list + { + Buffer::OwnedImpl buffer; + EXPECT_THROW_WITH_MESSAGE(proto.writeListBegin(buffer, FieldType::I32, 3000000000), + EnvoyException, "illegal compact protocol list/set size 3000000000"); + } +} + +TEST(CompactProtocolTest, WriteListEnd) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeListEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(CompactProtocolTest, WriteSetBegin) { + CompactProtocolImpl proto; + + // Empty set only, as writeSetBegin delegates to writeListBegin. + Buffer::OwnedImpl buffer; + proto.writeSetBegin(buffer, FieldType::I32, 0); + EXPECT_EQ("\x5", buffer.toString()); +} + +TEST(CompactProtocolTest, WriteSetEnd) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeSetEnd(buffer); + EXPECT_EQ(0, buffer.length()); +} + +TEST(CompactProtocolTest, WriteBool) { + CompactProtocolImpl proto; + + // Non-field bools (see WriteBoolField test) + { + Buffer::OwnedImpl buffer; + proto.writeBool(buffer, true); + EXPECT_EQ("\x1", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeBool(buffer, false); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteByte) { + CompactProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeByte(buffer, -1); + EXPECT_EQ("\xFF", buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeByte(buffer, 127); + EXPECT_EQ("\x7F", buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteInt16) { + CompactProtocolImpl proto; + + // zigzag(1) = 2 + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, 1); + EXPECT_EQ("\x2", buffer.toString()); + } + + // zigzag(128) = 256 (0x200) + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, 128); + EXPECT_EQ("\x80\x2", buffer.toString()); + } + + // zigzag(-1) = 1 + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, -1); + EXPECT_EQ("\x1", buffer.toString()); + } + + // zigzag(32767) = 65534 (0xFFFE) + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\x3", buffer.toString()); + } + + // zigzag(-32768) = 65535 (0xFFFF) + { + Buffer::OwnedImpl buffer; + proto.writeInt16(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\x3", buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteInt32) { + CompactProtocolImpl proto; + + // zigzag(1) = 2 + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, 1); + EXPECT_EQ("\x2", buffer.toString()); + } + + // zigzag(128) = 256 (0x200) + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, 128); + EXPECT_EQ("\x80\x2", buffer.toString()); + } + + // zigzag(-1) = 1 + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, -1); + EXPECT_EQ("\x1", buffer.toString()); + } + + // zigzag(0x7FFFFFFF) = 0xFFFFFFFE + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\xFF\xFF\xF", buffer.toString()); + } + + // zigzag(0x80000000) = 0xFFFFFFFF + { + Buffer::OwnedImpl buffer; + proto.writeInt32(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xF", buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteInt64) { + CompactProtocolImpl proto; + + // zigzag(1) = 2 + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, 1); + EXPECT_EQ("\x2", buffer.toString()); + } + + // zigzag(128) = 256 (0x200) + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, 128); + EXPECT_EQ("\x80\x2", buffer.toString()); + } + + // zigzag(-1) = 1 + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, -1); + EXPECT_EQ("\x1", buffer.toString()); + } + + // zigzag(0x7FFFFFFF FFFFFFFF) = 0xFFFFFFFF FFFFFFFE + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, std::numeric_limits::max()); + EXPECT_EQ("\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } + + // zigzag(0x80000000 00000000) = 0xFFFFFFFF FFFFFFFF + { + Buffer::OwnedImpl buffer; + proto.writeInt64(buffer, std::numeric_limits::min()); + EXPECT_EQ("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x1", buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteDouble) { + CompactProtocolImpl proto; + Buffer::OwnedImpl buffer; + proto.writeDouble(buffer, 3.0); + EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); +} + +TEST(CompactProtocolTest, WriteString) { + CompactProtocolImpl proto; + + { + Buffer::OwnedImpl buffer; + proto.writeString(buffer, "abc"); + EXPECT_EQ(std::string("\x3" + "abc", + 4), + buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + std::string data(192, 'a'); + proto.writeString(buffer, data); + EXPECT_EQ(std::string("\xC0\x1") + data, buffer.toString()); + } + + { + Buffer::OwnedImpl buffer; + proto.writeString(buffer, ""); + EXPECT_EQ(std::string("\0", 1), buffer.toString()); + } +} + +TEST(CompactProtocolTest, WriteBinary) { + CompactProtocolImpl proto; + + // writeBinary is an alias for writeString + Buffer::OwnedImpl buffer; + proto.writeBinary(buffer, "abc"); + EXPECT_EQ(std::string("\x3" + "abc", + 4), + buffer.toString()); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/config_test.cc b/test/extensions/filters/network/thrift_proxy/config_test.cc index 3047282f9c771..220ce5c3fddf6 100644 --- a/test/extensions/filters/network/thrift_proxy/config_test.cc +++ b/test/extensions/filters/network/thrift_proxy/config_test.cc @@ -1,4 +1,4 @@ -#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" #include "extensions/filters/network/thrift_proxy/config.h" @@ -16,14 +16,13 @@ namespace ThriftProxy { TEST(ThriftFilterConfigTest, ValidateFail) { NiceMock context; - EXPECT_THROW( - ThriftProxyFilterConfigFactory().createFilterFactoryFromProto( - envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy(), context), - ProtoValidationException); + EXPECT_THROW(ThriftProxyFilterConfigFactory().createFilterFactoryFromProto( + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy(), context), + ProtoValidationException); } TEST(ThriftFilterConfigTest, ValidProtoConfiguration) { - envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy config{}; + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config{}; config.set_stat_prefix("my_stat_prefix"); @@ -31,21 +30,21 @@ TEST(ThriftFilterConfigTest, ValidProtoConfiguration) { ThriftProxyFilterConfigFactory factory; Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } TEST(ThriftFilterConfigTest, ThriftProxyWithEmptyProto) { NiceMock context; ThriftProxyFilterConfigFactory factory; - envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy config = - *dynamic_cast( + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config = + *dynamic_cast( factory.createEmptyConfigProto().get()); config.set_stat_prefix("my_stat_prefix"); Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc new file mode 100644 index 0000000000000..accae1a9e3c75 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -0,0 +1,759 @@ +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" + +#include "common/buffer/buffer_impl.h" +#include "common/stats/stats_impl.h" + +#include "extensions/filters/network/thrift_proxy/buffer_helper.h" +#include "extensions/filters/network/thrift_proxy/config.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class TestConfigImpl : public ConfigImpl { +public: + TestConfigImpl(envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy proto_config, + Server::Configuration::MockFactoryContext& context, + ThriftFilters::DecoderFilterSharedPtr decoder_filter, ThriftFilterStats& stats) + : ConfigImpl(proto_config, context), decoder_filter_(decoder_filter), stats_(stats) {} + + // ConfigImpl + ThriftFilterStats& stats() override { return stats_; } + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { + callbacks.addDecoderFilter(decoder_filter_); + } + +private: + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ThriftFilterStats& stats_; +}; + +class ThriftConnectionManagerTest : public testing::Test { +public: + ThriftConnectionManagerTest() : stats_(ThriftFilterStats::generateStats("test.", store_)) {} + ~ThriftConnectionManagerTest() { + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + void initializeFilter() { initializeFilter(""); } + + void initializeFilter(const std::string& yaml) { + // Destroy any existing filter first. + filter_ = nullptr; + + for (auto counter : store_.counters()) { + counter->reset(); + } + + if (yaml.empty()) { + proto_config_.set_stat_prefix("test"); + } else { + MessageUtil::loadFromYaml(yaml, proto_config_); + MessageUtil::validate(proto_config_); + } + + proto_config_.set_stat_prefix("test"); + + decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + + filter_.reset(new ConnectionManager(*config_)); + filter_->initializeReadFilterCallbacks(filter_callbacks_); + filter_->onNewConnection(); + + // NOP currently. + filter_->onAboveWriteBufferHighWatermark(); + filter_->onBelowWriteBufferLowWatermark(); + } + + void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "response"); + proto->writeFieldBegin(msg, "success", FieldType::String, 0); + proto->writeString(msg, "field"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeComplexFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "wrapper"); // call args struct or response struct + proto->writeFieldBegin(msg, "wrapper_field", FieldType::Struct, 0); // call arg/response success + + proto->writeStructBegin(msg, "payload"); + proto->writeFieldBegin(msg, "f1", FieldType::Bool, 1); + proto->writeBool(msg, true); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f2", FieldType::Byte, 2); + proto->writeByte(msg, 2); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f3", FieldType::Double, 3); + proto->writeDouble(msg, 3.0); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f4", FieldType::I16, 4); + proto->writeInt16(msg, 4); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f5", FieldType::I32, 5); + proto->writeInt32(msg, 5); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f6", FieldType::I64, 6); + proto->writeInt64(msg, 6); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f7", FieldType::String, 7); + proto->writeString(msg, "seven"); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f8", FieldType::Map, 8); + proto->writeMapBegin(msg, FieldType::I32, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeInt32(msg, 8); + proto->writeMapEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f9", FieldType::List, 9); + proto->writeListBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeListEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f10", FieldType::Set, 10); + proto->writeSetBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeSetEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // payload stop field + proto->writeStructEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // wrapper stop field + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id, bool start) { + Buffer::OwnedImpl frame; + writeFramedBinaryMessage(frame, msg_type, seq_id); + + if (start) { + buffer.move(frame, 27); + } else { + frame.drain(27); + buffer.move(frame); + } + } + + void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Exception, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "error"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::I32, 2); + proto->writeInt32(msg, 1); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Reply, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::Struct, 2); + + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "err"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + NiceMock context_; + std::shared_ptr decoder_filter_; + Stats::IsolatedStoreImpl store_; + ThriftFilterStats stats_; + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy proto_config_; + + std::unique_ptr config_; + + Buffer::OwnedImpl buffer_; + Buffer::OwnedImpl write_buffer_; + std::unique_ptr filter_; + NiceMock filter_callbacks_; +}; + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftOneWay) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesStopIterationAndResume) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + // Nothing further happens: we're stopped. + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1, callbacks->streamId()); + EXPECT_EQ(TransportType::Framed, callbacks->downstreamTransportType()); + EXPECT_EQ(ProtocolType::Binary, callbacks->downstreamProtocolType()); + EXPECT_EQ(&filter_callbacks_.connection_, callbacks->connection()); + + // Resume processing. + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesFrameSplitAcrossBuffers) { + initializeFilter(); + + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + // Complete the buffer + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesInvalidMsgType) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolError) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x01, // binary, call + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + std::string err = "invalid binary protocol field id -1"; + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x42, // framed: 66 bytes + 0x80, 0x01, 0x00, 0x03, // binary, exception + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x0b, 0x00, 0x01, // begin string field + }); + addInt32(write_buffer_, err.length()); + addString(write_buffer_, err); + addSeq(write_buffer_, { + 0x08, 0x00, 0x02, // begin i32 field + 0x00, 0x00, 0x00, 0x07, // protocol error + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ(bufferToString(write_buffer_), bufferToString(buffer)); + })); + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolErrorDuringMessageBegin) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0xff, // binary, invalid type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnEvent) { + // No active calls + { + initializeFilter(); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + } + + // Remote close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Remote close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } +} + +TEST_F(ThriftConnectionManagerTest, Routing) { + const std::string yaml = R"EOF( +transport: FRAMED +protocol: BINARY +stat_prefix: test +route_config: + name: "routes" + routes: + - match: + method: name + route: + cluster: cluster +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + Router::RouteConstSharedPtr route = callbacks->route(); + EXPECT_NE(nullptr, route); + EXPECT_NE(nullptr, route->routeEntry()); + EXPECT_EQ("cluster", route->routeEntry()->clusterName()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { + initializeFilter(); + writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryTApplicationException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryIDLException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(1U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // Call is not valid in a response + writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // illegal field id + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x02, // binary, reply + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(*decoder_filter_, resetUpstreamConnection()); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x01); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x02); + + std::list callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillRepeatedly(Invoke( + [&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks.push_back(&cb); })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(2U, store_.gauge("test.request_active").value()); + EXPECT_EQ(2U, store_.counter("test.request").value()); + EXPECT_EQ(2U, store_.counter("test.request_call").value()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(2U, store_.counter("test.response").value()); + EXPECT_EQ(2U, store_.counter("test.response_reply").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + callbacks->resetDownstreamConnection(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 751da9c8e9430..9762fddfaf007 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -3,9 +3,11 @@ #include "extensions/filters/network/thrift_proxy/decoder.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" +#include "absl/strings/string_view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -21,6 +23,7 @@ using testing::Return; using testing::ReturnRef; using testing::SetArgReferee; using testing::StrictMock; +using testing::TestParamInfo; using testing::TestWithParam; using testing::Values; using testing::_; @@ -31,75 +34,129 @@ namespace NetworkFilters { namespace ThriftProxy { namespace { -Expectation expectValue(NiceMock& proto, FieldType field_type, bool result = true) { +ExpectationSet expectValue(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, bool result = true) { + ExpectationSet s; switch (field_type) { case FieldType::Bool: - return EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, boolValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Byte: - return EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, byteValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Double: - return EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, doubleValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I16: - return EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int16Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I32: - return EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int32Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I64: - return EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int64Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::String: - return EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, stringValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } + return s; } -ExpectationSet expectContainerStart(NiceMock& proto, FieldType field_type, - FieldType inner_type) { +ExpectationSet expectContainerStart(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, FieldType inner_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, fieldBegin(absl::string_view(), inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, listBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, mapBegin(inner_type, inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, setBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } return s; } -ExpectationSet expectContainerEnd(NiceMock& proto, FieldType field_type) { +ExpectationSet expectContainerEnd(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readFieldEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); s += EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, listEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, mapEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, setEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } return s; } @@ -108,35 +165,51 @@ ExpectationSet expectContainerEnd(NiceMock& proto, FieldType field class DecoderStateMachineNonValueTest : public TestWithParam {}; +static std::string protoStateParamToString(const TestParamInfo& params) { + return ProtocolStateNameValues::name(params.param); +} + INSTANTIATE_TEST_CASE_P(NonValueProtocolStates, DecoderStateMachineNonValueTest, Values(ProtocolState::MessageBegin, ProtocolState::MessageEnd, ProtocolState::StructBegin, ProtocolState::StructEnd, ProtocolState::FieldBegin, ProtocolState::FieldEnd, ProtocolState::MapBegin, ProtocolState::MapEnd, ProtocolState::ListBegin, ProtocolState::ListEnd, - ProtocolState::SetBegin, ProtocolState::SetEnd)); + ProtocolState::SetBegin, ProtocolState::SetEnd), + protoStateParamToString); class DecoderStateMachineValueTest : public TestWithParam {}; INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, DecoderStateMachineValueTest, Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, - FieldType::I32, FieldType::I64, FieldType::String)); + FieldType::I32, FieldType::I64, FieldType::String), + fieldTypeParamToString); class DecoderStateMachineNestingTest : public TestWithParam> {}; +static std::string nestedFieldTypesParamToString( + const TestParamInfo>& params) { + FieldType outer_field_type, inner_type, value_type; + std::tie(outer_field_type, inner_type, value_type) = params.param; + return fmt::format("{}Of{}Of{}", fieldTypeToString(outer_field_type), + fieldTypeToString(inner_type), fieldTypeToString(value_type)); +} + INSTANTIATE_TEST_CASE_P( NestedTypes, DecoderStateMachineNestingTest, Combine(Values(FieldType::Struct, FieldType::List, FieldType::Map, FieldType::Set), Values(FieldType::Struct, FieldType::List, FieldType::Map, FieldType::Set), Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, - FieldType::I32, FieldType::I64, FieldType::String))); + FieldType::I32, FieldType::I64, FieldType::String)), + nestedFieldTypesParamToString); TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; - DecoderStateMachine dsm(proto); + StrictMock filter; + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(state); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), state); @@ -147,17 +220,18 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type, false); - expectValue(proto, field_type, true); + expectValue(proto, filter, field_type, false); + expectValue(proto, filter, field_type, true); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -171,18 +245,19 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -192,13 +267,14 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { TEST(DecoderStateMachineTest, NoListValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -208,13 +284,14 @@ TEST(DecoderStateMachineTest, NoListValueData) { TEST(DecoderStateMachineTest, EmptyList) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -225,16 +302,17 @@ TEST_P(DecoderStateMachineValueTest, ListValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -245,18 +323,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -266,6 +345,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { TEST(DecoderStateMachineTest, NoMapKeyData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -273,7 +353,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { SetArgReferee<3>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -283,6 +363,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { TEST(DecoderStateMachineTest, NoMapValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -291,7 +372,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(proto, readString(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -301,6 +382,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { TEST(DecoderStateMachineTest, EmptyMap) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -308,7 +390,7 @@ TEST(DecoderStateMachineTest, EmptyMap) { SetArgReferee<3>(0), Return(true))); EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -319,18 +401,19 @@ TEST_P(DecoderStateMachineValueTest, MapKeyValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); // key - expectValue(proto, FieldType::String); // value + expectValue(proto, filter, field_type); // key + expectValue(proto, filter, FieldType::String); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -341,18 +424,19 @@ TEST_P(DecoderStateMachineValueTest, MapValueValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -363,6 +447,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -370,13 +455,13 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { SetArgReferee<3>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value } EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -386,13 +471,14 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { TEST(DecoderStateMachineTest, NoSetValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -402,13 +488,14 @@ TEST(DecoderStateMachineTest, NoSetValueData) { TEST(DecoderStateMachineTest, EmptySet) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -419,16 +506,17 @@ TEST_P(DecoderStateMachineValueTest, SetValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -439,18 +527,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -460,6 +549,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { TEST(DecoderStateMachineTest, EmptyStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) @@ -471,7 +561,7 @@ TEST(DecoderStateMachineTest, EmptyStruct) { EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -481,24 +571,39 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -507,6 +612,7 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { TEST(DecoderStateMachineTest, MultiFieldStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; std::vector field_types = {FieldType::Bool, FieldType::Byte, FieldType::Double, @@ -516,24 +622,36 @@ TEST(DecoderStateMachineTest, MultiFieldStruct) { EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); int16_t field_id = 1; for (FieldType field_type : field_types) { EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id++), Return(true))); + .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, field_id)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + field_id++; - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); } EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -545,35 +663,41 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; // start of message and outermost struct EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); - expectContainerStart(proto, FieldType::Struct, outer_field_type); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + expectContainerStart(proto, filter, FieldType::Struct, outer_field_type); - expectContainerStart(proto, outer_field_type, inner_type); + expectContainerStart(proto, filter, outer_field_type, inner_type); int outer_reps = outer_field_type == FieldType::Map ? 2 : 1; for (int i = 0; i < outer_reps; i++) { - expectContainerStart(proto, inner_type, value_type); + expectContainerStart(proto, filter, inner_type, value_type); int inner_reps = inner_type == FieldType::Map ? 2 : 1; for (int j = 0; j < inner_reps; j++) { - expectValue(proto, value_type); + expectValue(proto, filter, value_type); } - expectContainerEnd(proto, inner_type); + expectContainerEnd(proto, filter, inner_type); } - expectContainerEnd(proto, outer_field_type); + expectContainerEnd(proto, filter, outer_field_type); // end of message and outermost struct - expectContainerEnd(proto, FieldType::Struct); + expectContainerEnd(proto, filter, FieldType::Struct); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -582,39 +706,68 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { TEST(DecoderTest, OnData) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - decoder.onData(buffer); + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } TEST(DecoderTest, OnDataResumes) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; + buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); - decoder.onData(buffer); + + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) @@ -622,23 +775,66 @@ TEST(DecoderTest, OnDataResumes) { EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); // buffer.length() == 1 +} + +TEST(DecoderTest, OnDataResumesTransportFrameStart) { + StrictMock* transport = new StrictMock(); + StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + EXPECT_CALL(*transport, name()).Times(AnyNumber()); + EXPECT_CALL(*proto, name()).Times(AnyNumber()); + + InSequence dummy; + + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = false; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(false))); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + + underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { StrictMock* transport = new StrictMock(); StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); EXPECT_CALL(*transport, name()).Times(AnyNumber()); EXPECT_CALL(*proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); @@ -648,11 +844,90 @@ TEST(DecoderTest, OnDataResumesTransportFrameEnd) { EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 +} + +TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + + StrictMock* transport = new StrictMock(); + EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); + + StrictMock* proto = new StrictMock(); + EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = true; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), FieldType::I32, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(filter, int32Value(_)).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } #define TEST_NAME(X) EXPECT_EQ(ProtocolStateNameValues::name(ProtocolState::X), #X); diff --git a/test/extensions/filters/network/thrift_proxy/driver/BUILD b/test/extensions/filters/network/thrift_proxy/driver/BUILD new file mode 100644 index 0000000000000..beed670c9e20e --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +filegroup( + name = "generate_fixture", + srcs = ["generate_fixture.sh"], + data = [ + ":client", + ":server", + ], +) + +py_binary( + name = "client", + srcs = ["client.py"], + deps = [ + "//test/extensions/filters/network/thrift_proxy/driver/fbthrift:fbthrift_lib", + "//test/extensions/filters/network/thrift_proxy/driver/finagle:finagle_lib", + "//test/extensions/filters/network/thrift_proxy/driver/generated/example:example_lib", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) + +py_binary( + name = "server", + srcs = ["server.py"], + deps = [ + "//test/extensions/filters/network/thrift_proxy/driver/fbthrift:fbthrift_lib", + "//test/extensions/filters/network/thrift_proxy/driver/finagle:finagle_lib", + "//test/extensions/filters/network/thrift_proxy/driver/generated/example:example_lib", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/README.md b/test/extensions/filters/network/thrift_proxy/driver/README.md new file mode 100644 index 0000000000000..251d1542abdc5 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/README.md @@ -0,0 +1,33 @@ +Thrift Integration Test Driver +============================== + +The code in this package provides `client.py` and `server.py` which +can be used as a thrift client and server pair. Both scripts support +all the Thrift transport and protocol variations that Envoy's Thrift +proxy supports (or will eventually support): + +Transports: framed, unframed, header +Protocols: binary, compact, json, ttwitter (e.g., finagle-thrift) + +The client script can be configured to write its request and the +server's response to a file. The server script can be configured to +return successful responses, IDL-defined exceptions, or server +(application) exceptions. + +Envoy's thrift_proxy integration tests use the `generate_fixtures.sh` +script to create request and response files for various combinations +of transport, protocol, service multiplexing. In addition, the +integration tests generate IDL and application exception responses. +The generated data is used with the Envoy's integration test +infrastructure to simulate downstream and upstream connections. +Generated files are used instead of running the client and server +scripts directly to eliminate the need to select a Thrift upstream +server port (or determine its self-selected port). + +Regenerating example.thrift +--------------------------- + +Install the Apache thrift library (from source or a package) so that +the `thrift` command is available. The `generate_bindings.sh` script +will regenerate the Python bindings which are checked into the +repository. diff --git a/test/extensions/filters/network/thrift_proxy/driver/client.py b/test/extensions/filters/network/thrift_proxy/driver/client.py new file mode 100755 index 0000000000000..bbc1293cee55b --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/client.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python + +import argparse +import io +import sys + +from generated.example import Example +from generated.example.ttypes import ( + Param, TheWorks, AppException +) + +from thrift import Thrift +from thrift.protocol import ( + TBinaryProtocol, TCompactProtocol, TJSONProtocol, TMultiplexedProtocol +) +from thrift.transport import TSocket +from thrift.transport import TTransport +from fbthrift import THeaderTransport +from twitter.common.rpc.finagle.protocol import TFinagleProtocol + + +class TRecordingTransport(TTransport.TTransportBase): + def __init__(self, underlying, writehandle, readhandle): + self._underlying = underlying + self._whandle = writehandle + self._rhandle = readhandle + + def isOpen(self): + return self._underlying.isOpen() + + def open(self): + if not self._underlying.isOpen(): + self._underlying.open() + + def close(self): + self._underlying.close() + self._whandle.close() + self._rhandle.close() + + def read(self, sz): + buf = self._underlying.read(sz) + if len(buf) != 0: + self._rhandle.write(buf) + return buf + + def write(self, buf): + if len(buf) != 0: + self._whandle.write(buf) + self._underlying.write(buf) + + def flush(self): + self._underlying.flush() + self._whandle.flush() + self._rhandle.flush() + + +def main(cfg, reqhandle, resphandle): + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid unix domain socket: {}".format(cfg.addr)) + socket = TSocket.TSocket(unix_socket=cfg.addr) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + if host == "": + host = "localhost" + socket = TSocket.TSocket(host=host, port=int(port)) + except ValueError: + sys.exit("invalid address: {}".format(cfg.addr)) + + transport = TRecordingTransport(socket, reqhandle, resphandle) + + if cfg.transport == "framed": + transport = TTransport.TFramedTransport(transport) + elif cfg.transport == "unframed": + transport = TTransport.TBufferedTransport(transport) + elif cfg.transport == "header": + transport = THeaderTransport.THeaderTransport( + transport, + client_type=THeaderTransport.CLIENT_TYPE.HEADER, + ) + else: + sys.exit("unknown transport {0}".format(cfg.transport)) + + transport.open() + + if cfg.protocol == "binary": + protocol = TBinaryProtocol.TBinaryProtocol(transport) + elif cfg.protocol == "compact": + protocol = TCompactProtocol.TCompactProtocol(transport) + elif cfg.protocol == "json": + protocol = TJSONProtocol.TJSONProtocol(transport) + elif cfg.protocol == "finagle": + protocol = TFinagleProtocol(transport, client_id="thrift-playground") + else: + sys.exit("unknown protocol {0}".format(cfg.protocol)) + + if cfg.service is not None: + protocol = TMultiplexedProtocol.TMultiplexedProtocol(protocol, cfg.service) + + client = Example.Client(protocol) + + try: + if cfg.method == "ping": + client.ping() + print("client: pinged") + elif cfg.method == "poke": + client.poke() + print("client: poked") + elif cfg.method == "add": + if len(cfg.params) != 2: + sys.exit("add takes 2 arguments, got: {0}".format(cfg.params)) + + a = int(cfg.params[0]) + b = int(cfg.params[1]) + v = client.add(a, b) + print("client: added {0} + {1} = {2}".format(a, b, v)) + elif cfg.method == "execute": + param = Param( + return_fields=cfg.params, + the_works=TheWorks( + field_1=True, + field_2=0x7f, + field_3=0x7fff, + field_4=0x7fffffff, + field_5=0x7fffffffffffffff, + field_6=-1.5, + field_7=u"string is UTF-8: \U0001f60e", + field_8=b"binary is bytes: \x80\x7f\x00\x01", + field_9={1: "one", 2: "two", 3: "three"}, + field_10=[1, 2, 4, 8], + field_11=set(["a", "b", "c"]), + field_12=False, + ) + ) + + try: + result = client.execute(param) + print("client: executed {0}: {1}".format(param, result)) + except AppException as e: + print("client: execute failed with IDL Exception: {0}".format(e.why)) + else: + sys.exit("unknown method {0}".format(cfg.method)) + except Thrift.TApplicationException as e: + print("client exception: {0}: {1}".format(e.type, e.message)) + + if cfg.request is None: + req = "".join( [ "%02X " % ord( x ) for x in reqhandle.getvalue() ] ).strip() + print("request: {}".format(req)) + if cfg.response is None: + resp = "".join( [ "%02X " % ord( x ) for x in resphandle.getvalue() ] ).strip() + print("response: {}".format(resp)) + + transport.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Thrift client tool.", + ) + parser.add_argument( + "method", + metavar="METHOD", + help="Name of the service method to invoke.", + ) + parser.add_argument( + "params", + metavar="PARAMS", + nargs="*", + help="Method parameters", + ) + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + required=True, + help="Target address for requests in the form host:port. The host is optional. If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + help="selects a protocol.", + ) + parser.add_argument( + "--request", + metavar="FILE", + dest="request", + help="Writes the Thrift request to a file.", + ) + parser.add_argument( + "--response", + metavar="FILE", + dest="response", + help="Writes the Thrift response to a file.", + ) + parser.add_argument( + "-t", + "--transport", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + help="selects a transport.", + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + cfg = parser.parse_args() + + reqhandle = io.BytesIO() + resphandle = io.BytesIO() + if cfg.request is not None: + try: + reqhandle = io.open(cfg.request, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) + if cfg.response is not None: + try: + resphandle = io.open(cfg.response, "wb") + except IOError as e: + sys.exit("I/O error({0}): {1}".format(e.errno, e.strerror)) + try: + main(cfg, reqhandle, resphandle) + except Thrift.TException as tx: + sys.exit("Unhandled Thrift Exception: {0}".format(tx.message)) diff --git a/test/extensions/filters/network/thrift_proxy/driver/example.thrift b/test/extensions/filters/network/thrift_proxy/driver/example.thrift new file mode 100644 index 0000000000000..eda22a8d7d1ee --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/example.thrift @@ -0,0 +1,40 @@ +// TheWorks contains one instance of each type of field. Envoy does not +// concern itself with the optionality of fields, so we leave it +// defaulted. +struct TheWorks { + 1: bool field_1, + 2: i8 field_2, + 3: i16 field_3, + 4: i32 field_4, + 5: i64 field_5, + 6: double field_6, + 7: string field_7, + 8: binary field_8, + 9: map field_9, + 10: list field_10, + 11: set field_11, + 12: bool field_12, +} + +struct Param { + 1: list return_fields, + 2: TheWorks the_works, +} + +struct Result { + 1: TheWorks the_works, +} + +exception AppException { + 1: string why, +} + +service Example { + void ping(), + + oneway void poke(), + + i32 add(1:i32 a, 2:i32 b), + + Result execute(1:Param input) throws (1:AppException appex), +} diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD new file mode 100644 index 0000000000000..a1b33006f10f3 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/BUILD @@ -0,0 +1,16 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "fbthrift_lib", + srcs = [ + "THeaderTransport.py", + "__init__.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py new file mode 100644 index 0000000000000..cba5ec0651d9b --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py @@ -0,0 +1,662 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# INFO:(zuercher): Adapted from +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/transport/THeaderTransport.py + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import sys +if sys.version_info[0] >= 3: + from http import server + BaseHTTPServer = server + xrange = range + from io import BytesIO as StringIO + PY3 = True +else: + import BaseHTTPServer + from cStringIO import StringIO + PY3 = False + +from struct import pack, unpack +import zlib + +from thrift.Thrift import TApplicationException +from thrift.transport.TTransport import TTransportException, TTransportBase, CReadableTransport + +# INFO:(zuercher): Instead of importing these constants from TBinaryProtocol and TCompactProtocol +BINARY_PROTO_ID = 0x80 +COMPACT_PROTO_ID = 0x82 + + +# INFO:(zuercher): Copied from: +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py +def getVarint(n): + out = [] + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + if sys.version_info[0] >= 3: + return bytes(out) + else: + return b''.join(map(chr, out)) + + +# INFO:(zuercher): Copied from +# https://github.com/facebook/fbthrift/blob/b090870/thrift/lib/py/protocol/TCompactProtocol.py +def readVarint(trans): + result = 0 + shift = 0 + while True: + x = trans.read(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 + + +# Import the snappy module if it is available +try: + import snappy +except ImportError: + # If snappy is not available, don't fail immediately. + # Only raise an error if we actually ever need to perform snappy + # compression. + class DummySnappy(object): + def compress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') + + def decompress(self, buf): + raise TTransportException(TTransportException.INVALID_TRANSFORM, + 'snappy module not available') + snappy = DummySnappy() # type: ignore + + +# Definitions from THeader.h + + +class CLIENT_TYPE: + HEADER = 0 + FRAMED_DEPRECATED = 1 + UNFRAMED_DEPRECATED = 2 + HTTP_SERVER = 3 + HTTP_CLIENT = 4 + FRAMED_COMPACT = 5 + HEADER_SASL = 6 + HTTP_GET = 7 + UNKNOWN = 8 + UNFRAMED_COMPACT_DEPRECATED = 9 + + +class HEADER_FLAG: + SUPPORT_OUT_OF_ORDER = 0x01 + DUPLEX_REVERSE = 0x08 + SASL = 0x10 + + +class TRANSFORM: + NONE = 0x00 + ZLIB = 0x01 + HMAC = 0x02 + SNAPPY = 0x03 + QLZ = 0x04 + ZSTD = 0x05 + + +class INFO: + NORMAL = 1 + PERSISTENT = 2 + + +T_BINARY_PROTOCOL = 0 +T_COMPACT_PROTOCOL = 2 +HEADER_MAGIC = 0x0FFF0000 +PACKED_HEADER_MAGIC = pack(b'!H', HEADER_MAGIC >> 16) +HEADER_MASK = 0xFFFF0000 +FLAGS_MASK = 0x0000FFFF +HTTP_SERVER_MAGIC = 0x504F5354 # POST +HTTP_CLIENT_MAGIC = 0x48545450 # HTTP +HTTP_GET_CLIENT_MAGIC = 0x47455420 # GET +HTTP_HEAD_CLIENT_MAGIC = 0x48454144 # HEAD +BIG_FRAME_MAGIC = 0x42494746 # BIGF +MAX_FRAME_SIZE = 0x3FFFFFFF +MAX_BIG_FRAME_SIZE = 2 ** 61 - 1 + + +class THeaderTransport(TTransportBase, CReadableTransport): + """Transport that sends headers. Also understands framed/unframed/HTTP + transports and will do the right thing""" + + __max_frame_size = MAX_FRAME_SIZE + + # Defaults to current user, but there is also a setter below. + __identity = None + IDENTITY_HEADER = "identity" + ID_VERSION_HEADER = "id_version" + ID_VERSION = "1" + + def __init__(self, trans, client_types=None, client_type=None): + self.__trans = trans + self.__rbuf = StringIO() + self.__rbuf_frame = False + self.__wbuf = StringIO() + self.seq_id = 0 + self.__flags = 0 + self.__read_transforms = [] + self.__write_transforms = [] + self.__supported_client_types = set(client_types or + (CLIENT_TYPE.HEADER,)) + self.__proto_id = T_COMPACT_PROTOCOL # default to compact like c++ + self.__client_type = client_type or CLIENT_TYPE.HEADER + self.__read_headers = {} + self.__read_persistent_headers = {} + self.__write_headers = {} + self.__write_persistent_headers = {} + + self.__supported_client_types.add(self.__client_type) + + # If we support unframed binary / framed binary also support compact + if CLIENT_TYPE.UNFRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add( + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED) + if CLIENT_TYPE.FRAMED_DEPRECATED in self.__supported_client_types: + self.__supported_client_types.add( + CLIENT_TYPE.FRAMED_COMPACT) + + def set_header_flag(self, flag): + self.__flags |= flag + + def clear_header_flag(self, flag): + self.__flags &= ~ flag + + def header_flags(self): + return self.__flags + + def set_max_frame_size(self, size): + if size > MAX_BIG_FRAME_SIZE: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s" % + MAX_BIG_FRAME_SIZE) + if size > MAX_FRAME_SIZE and self.__client_type != CLIENT_TYPE.HEADER: + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "Cannot set max frame size > %s for clients other than HEADER" + % MAX_FRAME_SIZE) + self.__max_frame_size = size + + def get_peer_identity(self): + if self.IDENTITY_HEADER in self.__read_headers: + if self.__read_headers[self.ID_VERSION_HEADER] == self.ID_VERSION: + return self.__read_headers[self.IDENTITY_HEADER] + return None + + def set_identity(self, identity): + self.__identity = identity + + def get_protocol_id(self): + return self.__proto_id + + def set_protocol_id(self, proto_id): + self.__proto_id = proto_id + + def set_header(self, str_key, str_value): + self.__write_headers[str_key] = str_value + + def get_write_headers(self): + return self.__write_headers + + def get_headers(self): + return self.__read_headers + + def clear_headers(self): + self.__write_headers.clear() + + def set_persistent_header(self, str_key, str_value): + self.__write_persistent_headers[str_key] = str_value + + def get_write_persistent_headers(self): + return self.__write_persistent_headers + + def clear_persistent_headers(self): + self.__write_persistent_headers.clear() + + def add_transform(self, trans_id): + self.__write_transforms.append(trans_id) + + def _reset_protocol(self): + # HTTP calls that are one way need to flush here. + if self.__client_type == CLIENT_TYPE.HTTP_SERVER: + self.flush() + # set to anything except unframed + self.__client_type = CLIENT_TYPE.UNKNOWN + # Read header bytes to check which protocol to decode + self.readFrame(0) + + def getTransport(self): + return self.__trans + + def isOpen(self): + return self.getTransport().isOpen() + + def open(self): + return self.getTransport().open() + + def close(self): + return self.getTransport().close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) == sz: + return ret + + if self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + return ret + self.getTransport().readAll(sz - len(ret)) + + self.readFrame(sz - len(ret)) + return ret + self.__rbuf.read(sz - len(ret)) + + readAll = read # TTransportBase.readAll does a needless copy here. + + def readFrame(self, req_sz): + self.__rbuf_frame = True + word1 = self.getTransport().readAll(4) + sz = unpack('!I', word1)[0] + proto_id = word1[0] if PY3 else ord(word1[0]) + if proto_id == BINARY_PROTO_ID: + # unframed + self.__client_type = CLIENT_TYPE.UNFRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read( + req_sz - 4)) + elif proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED + self.__proto_id = T_COMPACT_PROTOCOL + if req_sz <= 4: # check for reads < 0. + self.__rbuf = StringIO(word1) + else: + self.__rbuf = StringIO(word1 + self.getTransport().read( + req_sz - 4)) + elif sz == HTTP_SERVER_MAGIC: + self.__client_type = CLIENT_TYPE.HTTP_SERVER + mf = self.getTransport().handle.makefile('rb', -1) + + self.handler = RequestHandler(mf, + 'client_address:port', '') + self.header = self.handler.wfile + self.__rbuf = StringIO(self.handler.data) + else: + if sz == BIG_FRAME_MAGIC: + sz = unpack('!Q', self.getTransport().readAll(8))[0] + # could be header format or framed. Check next two bytes. + magic = self.getTransport().readAll(2) + proto_id = magic[0] if PY3 else ord(magic[0]) + if proto_id == COMPACT_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_COMPACT + self.__proto_id = T_COMPACT_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll( + sz - 2)) + elif proto_id == BINARY_PROTO_ID: + self.__client_type = CLIENT_TYPE.FRAMED_DEPRECATED + self.__proto_id = T_BINARY_PROTOCOL + _frame_size_check(sz, self.__max_frame_size, header=False) + self.__rbuf = StringIO(magic + self.getTransport().readAll( + sz - 2)) + elif magic == PACKED_HEADER_MAGIC: + self.__client_type = CLIENT_TYPE.HEADER + _frame_size_check(sz, self.__max_frame_size) + # flags(2), seq_id(4), header_size(2) + n_header_meta = self.getTransport().readAll(8) + self.__flags, self.seq_id, header_size = unpack('!HIH', + n_header_meta) + data = StringIO() + data.write(magic) + data.write(n_header_meta) + data.write(self.getTransport().readAll(sz - 10)) + data.seek(10) + self.read_header_format(sz - 10, header_size, data) + else: + self.__client_type = CLIENT_TYPE.UNKNOWN + raise TTransportException( + TTransportException.INVALID_CLIENT_TYPE, + "Could not detect client transport type") + + if self.__client_type not in self.__supported_client_types: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Client type {} not supported on server" + .format(self.__client_type)) + + def read_header_format(self, sz, header_size, data): + # clear out any previous transforms + self.__read_transforms = [] + + header_size = header_size * 4 + if header_size > sz: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "Header size is larger than frame") + end_header = header_size + data.tell() + + self.__proto_id = readVarint(data) + num_headers = readVarint(data) + + if self.__proto_id == 1 and self.__client_type != \ + CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to recv JSON encoding over binary") + + # Read the headers. Data for each header varies. + for _ in range(0, num_headers): + trans_id = readVarint(data) + if trans_id == TRANSFORM.ZLIB: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.SNAPPY: + self.__read_transforms.insert(0, trans_id) + elif trans_id == TRANSFORM.HMAC: + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Hmac transform is no longer supported: %i" % trans_id) + else: + # TApplicationException will be sent back to client + raise TApplicationException( + TApplicationException.INVALID_TRANSFORM, + "Unknown transform in client request: %i" % trans_id) + + # Clear out previous info headers. + self.__read_headers.clear() + + # Read the info headers. + while data.tell() < end_header: + info_id = readVarint(data) + if info_id == INFO.NORMAL: + _read_info_headers( + data, end_header, self.__read_headers) + elif info_id == INFO.PERSISTENT: + _read_info_headers( + data, end_header, self.__read_persistent_headers) + else: + break # Unknown header. Stop info processing. + + if self.__read_persistent_headers: + self.__read_headers.update(self.__read_persistent_headers) + + # Skip the rest of the header + data.seek(end_header) + + payload = data.read(sz - header_size) + + # Read the data section. + self.__rbuf = StringIO(self.untransform(payload)) + + def write(self, buf): + self.__wbuf.write(buf) + + def transform(self, buf): + for trans_id in self.__write_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.compress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.compress(buf) + else: + raise TTransportException(TTransportException.INVALID_TRANSFORM, + "Unknown transform during send") + return buf + + def untransform(self, buf): + for trans_id in self.__read_transforms: + if trans_id == TRANSFORM.ZLIB: + buf = zlib.decompress(buf) + elif trans_id == TRANSFORM.SNAPPY: + buf = snappy.decompress(buf) + if trans_id not in self.__write_transforms: + self.__write_transforms.append(trans_id) + return buf + + def flush(self): + self.flushImpl(False) + + def onewayFlush(self): + self.flushImpl(True) + + def _flushHeaderMessage(self, buf, wout, wsz): + """Write a message for CLIENT_TYPE.HEADER + + @param buf(StringIO): Buffer to write message to + @param wout(str): Payload + @param wsz(int): Payload length + """ + transform_data = StringIO() + # For now, all transforms don't require data. + num_transforms = len(self.__write_transforms) + for trans_id in self.__write_transforms: + transform_data.write(getVarint(trans_id)) + + # Add in special flags. + if self.__identity: + self.__write_headers[self.ID_VERSION_HEADER] = self.ID_VERSION + self.__write_headers[self.IDENTITY_HEADER] = self.__identity + + info_data = StringIO() + + # Write persistent kv-headers + _flush_info_headers(info_data, + self.get_write_persistent_headers(), + INFO.PERSISTENT) + + # Write non-persistent kv-headers + _flush_info_headers(info_data, + self.__write_headers, + INFO.NORMAL) + + header_data = StringIO() + header_data.write(getVarint(self.__proto_id)) + header_data.write(getVarint(num_transforms)) + + header_size = transform_data.tell() + header_data.tell() + \ + info_data.tell() + + padding_size = 4 - (header_size % 4) + header_size = header_size + padding_size + + # MAGIC(2) | FLAGS(2) + SEQ_ID(4) + HEADER_SIZE(2) + wsz += header_size + 10 + if wsz > MAX_FRAME_SIZE: + buf.write(pack("!I", BIG_FRAME_MAGIC)) + buf.write(pack("!Q", wsz)) + else: + buf.write(pack("!I", wsz)) + buf.write(pack("!HH", HEADER_MAGIC >> 16, self.__flags)) + buf.write(pack("!I", self.seq_id)) + buf.write(pack("!H", header_size // 4)) + + buf.write(header_data.getvalue()) + buf.write(transform_data.getvalue()) + buf.write(info_data.getvalue()) + + # Pad out the header with 0x00 + for _ in range(0, padding_size, 1): + buf.write(pack("!c", b'\0')) + + # Send data section + buf.write(wout) + + def flushImpl(self, oneway): + wout = self.__wbuf.getvalue() + wout = self.transform(wout) + wsz = len(wout) + + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf.seek(0) + self.__wbuf.truncate() + + if self.__proto_id == 1 and self.__client_type != CLIENT_TYPE.HTTP_SERVER: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Trying to send JSON encoding over binary") + + buf = StringIO() + if self.__client_type == CLIENT_TYPE.HEADER: + self._flushHeaderMessage(buf, wout, wsz) + elif self.__client_type in (CLIENT_TYPE.FRAMED_DEPRECATED, + CLIENT_TYPE.FRAMED_COMPACT): + buf.write(pack("!i", wsz)) + buf.write(wout) + elif self.__client_type in (CLIENT_TYPE.UNFRAMED_DEPRECATED, + CLIENT_TYPE.UNFRAMED_COMPACT_DEPRECATED): + buf.write(wout) + elif self.__client_type == CLIENT_TYPE.HTTP_SERVER: + # Reset the client type if we sent something - + # oneway calls via HTTP expect a status response otherwise + buf.write(self.header.getvalue()) + buf.write(wout) + self.__client_type == CLIENT_TYPE.HEADER + elif self.__client_type == CLIENT_TYPE.UNKNOWN: + raise TTransportException(TTransportException.INVALID_CLIENT_TYPE, + "Unknown client type") + + # We don't include the framing bytes as part of the frame size check + frame_size = buf.tell() - (4 if wsz < MAX_FRAME_SIZE else 12) + _frame_size_check(frame_size, + self.__max_frame_size, + header=self.__client_type == CLIENT_TYPE.HEADER) + self.getTransport().write(buf.getvalue()) + if oneway: + self.getTransport().onewayFlush() + else: + self.getTransport().flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + if not self.__rbuf_frame: + self.readFrame(0) + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastproto doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + + # On unframed clients, there is a chance there is something left + # in rbuf, and the read pointer is not advanced by fastproto + # so seek to the end to be safe + self.__rbuf.seek(0, 2) + while len(prefix) < reqlen: + prefix += self.read(reqlen) + self.__rbuf = StringIO(prefix) + return self.__rbuf + + +def _serialize_string(str_): + if PY3 and not isinstance(str_, bytes): + str_ = str_.encode() + return getVarint(len(str_)) + str_ + + +def _flush_info_headers(info_data, write_headers, type): + if (len(write_headers) > 0): + info_data.write(getVarint(type)) + info_data.write(getVarint(len(write_headers))) + write_headers_iter = write_headers.items() + for str_key, str_value in write_headers_iter: + info_data.write(_serialize_string(str_key)) + info_data.write(_serialize_string(str_value)) + write_headers.clear() + + +def _read_string(bufio, buflimit): + str_sz = readVarint(bufio) + if str_sz + bufio.tell() > buflimit: + raise TTransportException(TTransportException.INVALID_FRAME_SIZE, + "String read too big") + return bufio.read(str_sz) + + +def _read_info_headers(data, end_header, read_headers): + num_keys = readVarint(data) + for _ in xrange(num_keys): + str_key = _read_string(data, end_header) + str_value = _read_string(data, end_header) + read_headers[str_key] = str_value + + +def _frame_size_check(sz, set_max_size, header=True): + if sz > set_max_size or (not header and sz > MAX_FRAME_SIZE): + raise TTransportException( + TTransportException.INVALID_FRAME_SIZE, + "%s transport frame was too large" % 'Header' if header else 'Framed' + ) + + +class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): + + # Same as superclass function, but append 'POST' because we + # stripped it in the calling function. Would be nice if + # we had an ungetch instead + def handle_one_request(self): + self.raw_requestline = self.rfile.readline() + if not self.raw_requestline: + self.close_connection = 1 + return + self.raw_requestline = "POST" + self.raw_requestline + if not self.parse_request(): + # An error code has been sent, just exit + return + mname = 'do_' + self.command + if not hasattr(self, mname): + self.send_error(501, "Unsupported method (%r)" % self.command) + return + method = getattr(self, mname) + method() + + def setup(self): + self.rfile = self.request + self.wfile = StringIO() # New output buffer + + def finish(self): + if not self.rfile.closed: + self.rfile.close() + # leave wfile open for reading. + + def do_POST(self): + if int(self.headers['Content-Length']) > 0: + self.data = self.rfile.read(int(self.headers['Content-Length'])) + else: + self.data = "" + + # Prepare a response header, to be sent later. + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + +# INFO:(zuercher): Added to simplify usage +class THeaderTransportFactory: + def getTransport(self, trans): + return THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD b/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD new file mode 100644 index 0000000000000..71fa29d640635 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/BUILD @@ -0,0 +1,19 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "finagle_lib", + srcs = [ + "TFinagleServerProcessor.py", + "TFinagleServerProtocol.py", + "__init__.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + "@com_github_twitter_common_finagle_thrift//:twitter_common_finagle_thrift", + "@com_github_twitter_common_rpc//:twitter_common_rpc", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py new file mode 100644 index 0000000000000..3b207152ea21b --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProcessor.py @@ -0,0 +1,56 @@ +import logging + +from thrift.Thrift import TProcessor, TMessageType, TException +from thrift.protocol import TProtocolDecorator +from gen.twitter.finagle.thrift.ttypes import (ConnectionOptions, UpgradeReply) + +# Matches twitter/common/rpc/finagle/protocol.py +UPGRADE_METHOD = "__can__finagle__trace__v3__" + +# Twitter's TFinagleProcessor only works for the client side of an RPC. +class TFinagleServerProcessor(TProcessor): + def __init__(self, underlying): + self._underlying = underlying + + def process(self, iprot, oprot): + try: + if iprot.upgraded() is not None: + return self._underlying.process(iprot, oprot) + except AttributeError as e: + logging.exception("underlying protocol object is not a TFinagleServerProtocol", e) + return self._underlying.process(iprot, oprot) + + (name, ttype, seqid) = iprot.readMessageBegin() + if ttype != TMessageType.CALL and ttype != TMessageType.ONEWAY: + raise TException("TFinagle protocol only supports CALL & ONEWAY") + + # Check if this is an upgrade request. + if name == UPGRADE_METHOD: + connection_options = ConnectionOptions() + connection_options.read(iprot) + iprot.readMessageEnd() + + oprot.writeMessageBegin(UPGRADE_METHOD, TMessageType.REPLY, seqid) + upgrade_reply = UpgradeReply() + upgrade_reply.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + iprot.set_upgraded(True) + oprot.set_upgraded(True) + return True + + # Not upgraded. Replay the message begin to the underlying processor. + iprot.set_upgraded(False) + oprot.set_upgraded(False) + msg = (name, ttype, seqid) + return self._underlying.process(StoredMessageProtocol(iprot, msg), oprot) + + +class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator): + def __init__(self, protocol, messageBegin): + TProtocolDecorator.TProtocolDecorator.__init__(self, protocol) + self.messageBegin = messageBegin + + def readMessageBegin(self): + return self.messageBegin diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py new file mode 100644 index 0000000000000..dcdad5122e0e8 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/finagle/TFinagleServerProtocol.py @@ -0,0 +1,33 @@ +from thrift.protocol import TBinaryProtocol +from gen.twitter.finagle.thrift.ttypes import (RequestHeader, ResponseHeader) + + +class TFinagleServerProtocolFactory(object): + def getProtocol(self, trans): + return TFinagleServerProtocol(trans) + + +class TFinagleServerProtocol(TBinaryProtocol.TBinaryProtocol): + def __init__(self, *args, **kw): + self._last_request = None + self._upgraded = None + TBinaryProtocol.TBinaryProtocol.__init__(self, *args, **kw) + + def upgraded(self): + return self._upgraded + + def set_upgraded(self, upgraded): + self._upgraded = upgraded + + def writeMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = ResponseHeader() # .. TODO set some fields + header.write(self) + return TBinaryProtocol.TBinaryProtocol.writeMessageBegin(self, *args, **kwargs) + + def readMessageBegin(self, *args, **kwargs): + if self._upgraded: + header = RequestHeader() + header.read(self) + self._last_request = header + return TBinaryProtocol.TBinaryProtocol.readMessageBegin(self, *args, **kwargs) diff --git a/test/extensions/filters/network/thrift_proxy/driver/finagle/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/finagle/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh b/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh new file mode 100755 index 0000000000000..6b65871512c0c --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generate_bindings.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +# Generates the thrift bindings for example.thrift. Requires that +# apache-thrift's thrift generator is installed and on the path. + +DIR=$(cd `dirname $0` && pwd) +cd "${DIR}" + +thrift --gen py --out ./generated example.thrift diff --git a/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh b/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh new file mode 100755 index 0000000000000..be83b3eb65992 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +# Generates request and response fixtures for integration tests. + +# Usage: generate_fixture.sh [multiplex-service] -- method [param...] + +set -e + +function usage() { + echo "Usage: $0 [multiplex-service] -- method [param...]" + echo "where mode is success, exception, or idl-exception" + exit 1 +} + +FIXTURE_DIR="${TEST_TMPDIR}" +mkdir -p "${FIXTURE_DIR}" + +DRIVER_DIR="${TEST_RUNDIR}/test/extensions/filters/network/thrift_proxy/driver" + +if [[ -z "${TEST_UDSDIR}" ]]; then + TEST_UDSDIR=`mktemp -d /tmp/envoy_test_thrift.XXXXXX` +fi + +MODE="$1" +TRANSPORT="$2" +PROTOCOL="$3" +MULTIPLEX="$4" +if ! shift 4; then + usage +fi + +if [[ -z "${MODE}" || -z "${TRANSPORT}" || -z "${PROTOCOL}" || -z "${MULTIPLEX}" ]]; then + usage +fi + +if [[ "${MULTIPLEX}" != "--" ]]; then + if [[ "$1" != "--" ]]; then + echo "expected -- after multiplex service name" + exit 1 + fi + shift +else + MULTIPLEX="" +fi + +METHOD="$1" +if [[ "${METHOD}" == "" ]]; then + usage +fi +shift + +SOCKET="${TEST_UDSDIR}/fixture.sock" +rm -f "${SOCKET}" + +SERVICE_FLAGS=("--addr" "${SOCKET}" + "--unix" + "--response" "${MODE}" + "--transport" "${TRANSPORT}" + "--protocol" "${PROTOCOL}") + +if [[ -n "$MULTIPLEX" ]]; then + SERVICE_FLAGS[9]="--multiplex" + SERVICE_FLAGS[10]="${MULTIPLEX}" + + REQUEST_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MULTIPLEX}-${MODE}.request" + RESPONSE_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MULTIPLEX}-${MODE}.response" +else + REQUEST_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MODE}.request" + RESPONSE_FILE="${FIXTURE_DIR}/${TRANSPORT}-${PROTOCOL}-${MODE}.response" +fi + +# start server +"${DRIVER_DIR}/server" "${SERVICE_FLAGS[@]}" & +SERVER_PID="$!" + +trap "kill ${SERVER_PID}" EXIT; + +while [[ ! -a "${SOCKET}" ]]; do + sleep 0.1 + + if ! kill -0 "${SERVER_PID}"; then + echo "server failed to start" + exit 1 + fi +done + +"${DRIVER_DIR}/client" "${SERVICE_FLAGS[@]}" \ + --request "${REQUEST_FILE}" \ + --response "${RESPONSE_FILE}" \ + "${METHOD}" "$@" diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/generated/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD b/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD new file mode 100644 index 0000000000000..6c9595737b16f --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/BUILD @@ -0,0 +1,18 @@ +licenses(["notice"]) # Apache 2 + +load("//bazel:envoy_build_system.bzl", "envoy_package") + +envoy_package() + +py_library( + name = "example_lib", + srcs = [ + "Example.py", + "__init__.py", + "constants.py", + "ttypes.py", + ], + deps = [ + "@com_github_apache_thrift//:apache_thrift", + ], +) diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote new file mode 100755 index 0000000000000..11d032908d651 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example-remote @@ -0,0 +1,138 @@ +#!/usr/bin/env python +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +import sys +import pprint +if sys.version_info[0] > 2: + from urllib.parse import urlparse +else: + from urlparse import urlparse +from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient +from thrift.protocol.TBinaryProtocol import TBinaryProtocol + +from example import Example +from example.ttypes import * + +if len(sys.argv) <= 1 or sys.argv[1] == '--help': + print('') + print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') + print('') + print('Functions:') + print(' void ping()') + print(' void poke()') + print(' i32 add(i32 a, i32 b)') + print(' Result execute(Param input)') + print('') + sys.exit(0) + +pp = pprint.PrettyPrinter(indent=2) +host = 'localhost' +port = 9090 +uri = '' +framed = False +ssl = False +validate = True +ca_certs = None +keyfile = None +certfile = None +http = False +argi = 1 + +if sys.argv[argi] == '-h': + parts = sys.argv[argi + 1].split(':') + host = parts[0] + if len(parts) > 1: + port = int(parts[1]) + argi += 2 + +if sys.argv[argi] == '-u': + url = urlparse(sys.argv[argi + 1]) + parts = url[1].split(':') + host = parts[0] + if len(parts) > 1: + port = int(parts[1]) + else: + port = 80 + uri = url[2] + if url[4]: + uri += '?%s' % url[4] + http = True + argi += 2 + +if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed': + framed = True + argi += 1 + +if sys.argv[argi] == '-s' or sys.argv[argi] == '-ssl': + ssl = True + argi += 1 + +if sys.argv[argi] == '-novalidate': + validate = False + argi += 1 + +if sys.argv[argi] == '-ca_certs': + ca_certs = sys.argv[argi+1] + argi += 2 + +if sys.argv[argi] == '-keyfile': + keyfile = sys.argv[argi+1] + argi += 2 + +if sys.argv[argi] == '-certfile': + certfile = sys.argv[argi+1] + argi += 2 + +cmd = sys.argv[argi] +args = sys.argv[argi + 1:] + +if http: + transport = THttpClient.THttpClient(host, port, uri) +else: + if ssl: + socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile) + else: + socket = TSocket.TSocket(host, port) + if framed: + transport = TTransport.TFramedTransport(socket) + else: + transport = TTransport.TBufferedTransport(socket) +protocol = TBinaryProtocol(transport) +client = Example.Client(protocol) +transport.open() + +if cmd == 'ping': + if len(args) != 0: + print('ping requires 0 args') + sys.exit(1) + pp.pprint(client.ping()) + +elif cmd == 'poke': + if len(args) != 0: + print('poke requires 0 args') + sys.exit(1) + pp.pprint(client.poke()) + +elif cmd == 'add': + if len(args) != 2: + print('add requires 2 args') + sys.exit(1) + pp.pprint(client.add(eval(args[0]), eval(args[1]),)) + +elif cmd == 'execute': + if len(args) != 1: + print('execute requires 1 args') + sys.exit(1) + pp.pprint(client.execute(eval(args[0]),)) + +else: + print('Unrecognized method %s' % cmd) + sys.exit(1) + +transport.close() diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py new file mode 100644 index 0000000000000..325cbff2bae35 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/Example.py @@ -0,0 +1,660 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys +import logging +from .ttypes import * +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +all_structs = [] + + +class Iface(object): + def ping(self): + pass + + def poke(self): + pass + + def add(self, a, b): + """ + Parameters: + - a + - b + """ + pass + + def execute(self, input): + """ + Parameters: + - input + """ + pass + + +class Client(Iface): + def __init__(self, iprot, oprot=None): + self._iprot = self._oprot = iprot + if oprot is not None: + self._oprot = oprot + self._seqid = 0 + + def ping(self): + self.send_ping() + self.recv_ping() + + def send_ping(self): + self._oprot.writeMessageBegin('ping', TMessageType.CALL, self._seqid) + args = ping_args() + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_ping(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = ping_result() + result.read(iprot) + iprot.readMessageEnd() + return + + def poke(self): + self.send_poke() + + def send_poke(self): + self._oprot.writeMessageBegin('poke', TMessageType.ONEWAY, self._seqid) + args = poke_args() + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def add(self, a, b): + """ + Parameters: + - a + - b + """ + self.send_add(a, b) + return self.recv_add() + + def send_add(self, a, b): + self._oprot.writeMessageBegin('add', TMessageType.CALL, self._seqid) + args = add_args() + args.a = a + args.b = b + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_add(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = add_result() + result.read(iprot) + iprot.readMessageEnd() + if result.success is not None: + return result.success + raise TApplicationException(TApplicationException.MISSING_RESULT, "add failed: unknown result") + + def execute(self, input): + """ + Parameters: + - input + """ + self.send_execute(input) + return self.recv_execute() + + def send_execute(self, input): + self._oprot.writeMessageBegin('execute', TMessageType.CALL, self._seqid) + args = execute_args() + args.input = input + args.write(self._oprot) + self._oprot.writeMessageEnd() + self._oprot.trans.flush() + + def recv_execute(self): + iprot = self._iprot + (fname, mtype, rseqid) = iprot.readMessageBegin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + x.read(iprot) + iprot.readMessageEnd() + raise x + result = execute_result() + result.read(iprot) + iprot.readMessageEnd() + if result.success is not None: + return result.success + if result.appex is not None: + raise result.appex + raise TApplicationException(TApplicationException.MISSING_RESULT, "execute failed: unknown result") + + +class Processor(Iface, TProcessor): + def __init__(self, handler): + self._handler = handler + self._processMap = {} + self._processMap["ping"] = Processor.process_ping + self._processMap["poke"] = Processor.process_poke + self._processMap["add"] = Processor.process_add + self._processMap["execute"] = Processor.process_execute + + def process(self, iprot, oprot): + (name, type, seqid) = iprot.readMessageBegin() + if name not in self._processMap: + iprot.skip(TType.STRUCT) + iprot.readMessageEnd() + x = TApplicationException(TApplicationException.UNKNOWN_METHOD, 'Unknown function %s' % (name)) + oprot.writeMessageBegin(name, TMessageType.EXCEPTION, seqid) + x.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + return + else: + self._processMap[name](self, seqid, iprot, oprot) + return True + + def process_ping(self, seqid, iprot, oprot): + args = ping_args() + args.read(iprot) + iprot.readMessageEnd() + result = ping_result() + try: + self._handler.ping() + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("ping", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + def process_poke(self, seqid, iprot, oprot): + args = poke_args() + args.read(iprot) + iprot.readMessageEnd() + try: + self._handler.poke() + except TTransport.TTransportException: + raise + except Exception: + logging.exception('Exception in oneway handler') + + def process_add(self, seqid, iprot, oprot): + args = add_args() + args.read(iprot) + iprot.readMessageEnd() + result = add_result() + try: + result.success = self._handler.add(args.a, args.b) + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("add", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + + def process_execute(self, seqid, iprot, oprot): + args = execute_args() + args.read(iprot) + iprot.readMessageEnd() + result = execute_result() + try: + result.success = self._handler.execute(args.input) + msg_type = TMessageType.REPLY + except TTransport.TTransportException: + raise + except AppException as appex: + msg_type = TMessageType.REPLY + result.appex = appex + except TApplicationException as ex: + logging.exception('TApplication exception in handler') + msg_type = TMessageType.EXCEPTION + result = ex + except Exception: + logging.exception('Unexpected exception in handler') + msg_type = TMessageType.EXCEPTION + result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') + oprot.writeMessageBegin("execute", msg_type, seqid) + result.write(oprot) + oprot.writeMessageEnd() + oprot.trans.flush() + +# HELPER FUNCTIONS AND STRUCTURES + + +class ping_args(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('ping_args') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(ping_args) +ping_args.thrift_spec = ( +) + + +class ping_result(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('ping_result') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(ping_result) +ping_result.thrift_spec = ( +) + + +class poke_args(object): + + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('poke_args') + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(poke_args) +poke_args.thrift_spec = ( +) + + +class add_args(object): + """ + Attributes: + - a + - b + """ + + + def __init__(self, a=None, b=None,): + self.a = a + self.b = b + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.I32: + self.a = iprot.readI32() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.b = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('add_args') + if self.a is not None: + oprot.writeFieldBegin('a', TType.I32, 1) + oprot.writeI32(self.a) + oprot.writeFieldEnd() + if self.b is not None: + oprot.writeFieldBegin('b', TType.I32, 2) + oprot.writeI32(self.b) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(add_args) +add_args.thrift_spec = ( + None, # 0 + (1, TType.I32, 'a', None, None, ), # 1 + (2, TType.I32, 'b', None, None, ), # 2 +) + + +class add_result(object): + """ + Attributes: + - success + """ + + + def __init__(self, success=None,): + self.success = success + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 0: + if ftype == TType.I32: + self.success = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('add_result') + if self.success is not None: + oprot.writeFieldBegin('success', TType.I32, 0) + oprot.writeI32(self.success) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(add_result) +add_result.thrift_spec = ( + (0, TType.I32, 'success', None, None, ), # 0 +) + + +class execute_args(object): + """ + Attributes: + - input + """ + + + def __init__(self, input=None,): + self.input = input + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRUCT: + self.input = Param() + self.input.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('execute_args') + if self.input is not None: + oprot.writeFieldBegin('input', TType.STRUCT, 1) + self.input.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(execute_args) +execute_args.thrift_spec = ( + None, # 0 + (1, TType.STRUCT, 'input', [Param, None], None, ), # 1 +) + + +class execute_result(object): + """ + Attributes: + - success + - appex + """ + + + def __init__(self, success=None, appex=None,): + self.success = success + self.appex = appex + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 0: + if ftype == TType.STRUCT: + self.success = Result() + self.success.read(iprot) + else: + iprot.skip(ftype) + elif fid == 1: + if ftype == TType.STRUCT: + self.appex = AppException() + self.appex.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('execute_result') + if self.success is not None: + oprot.writeFieldBegin('success', TType.STRUCT, 0) + self.success.write(oprot) + oprot.writeFieldEnd() + if self.appex is not None: + oprot.writeFieldBegin('appex', TType.STRUCT, 1) + self.appex.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(execute_result) +execute_result.thrift_spec = ( + (0, TType.STRUCT, 'success', [Result, None], None, ), # 0 + (1, TType.STRUCT, 'appex', [AppException, None], None, ), # 1 +) +fix_spec(all_structs) +del all_structs + diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py new file mode 100644 index 0000000000000..a53ccc6084eeb --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/__init__.py @@ -0,0 +1 @@ +__all__ = ['ttypes', 'constants', 'Example'] diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py new file mode 100644 index 0000000000000..0c217ceda6915 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/constants.py @@ -0,0 +1,14 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys +from .ttypes import * diff --git a/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py b/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py new file mode 100644 index 0000000000000..89aa4a9f62334 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/generated/example/ttypes.py @@ -0,0 +1,445 @@ +# +# Autogenerated by Thrift Compiler (0.11.0) +# +# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING +# +# options string: py +# + +from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException +from thrift.protocol.TProtocol import TProtocolException +from thrift.TRecursive import fix_spec + +import sys + +from thrift.transport import TTransport +all_structs = [] + + +class TheWorks(object): + """ + Attributes: + - field_1 + - field_2 + - field_3 + - field_4 + - field_5 + - field_6 + - field_7 + - field_8 + - field_9 + - field_10 + - field_11 + - field_12 + """ + + + def __init__(self, field_1=None, field_2=None, field_3=None, field_4=None, field_5=None, field_6=None, field_7=None, field_8=None, field_9=None, field_10=None, field_11=None, field_12=None,): + self.field_1 = field_1 + self.field_2 = field_2 + self.field_3 = field_3 + self.field_4 = field_4 + self.field_5 = field_5 + self.field_6 = field_6 + self.field_7 = field_7 + self.field_8 = field_8 + self.field_9 = field_9 + self.field_10 = field_10 + self.field_11 = field_11 + self.field_12 = field_12 + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.BOOL: + self.field_1 = iprot.readBool() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.BYTE: + self.field_2 = iprot.readByte() + else: + iprot.skip(ftype) + elif fid == 3: + if ftype == TType.I16: + self.field_3 = iprot.readI16() + else: + iprot.skip(ftype) + elif fid == 4: + if ftype == TType.I32: + self.field_4 = iprot.readI32() + else: + iprot.skip(ftype) + elif fid == 5: + if ftype == TType.I64: + self.field_5 = iprot.readI64() + else: + iprot.skip(ftype) + elif fid == 6: + if ftype == TType.DOUBLE: + self.field_6 = iprot.readDouble() + else: + iprot.skip(ftype) + elif fid == 7: + if ftype == TType.STRING: + self.field_7 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + else: + iprot.skip(ftype) + elif fid == 8: + if ftype == TType.STRING: + self.field_8 = iprot.readBinary() + else: + iprot.skip(ftype) + elif fid == 9: + if ftype == TType.MAP: + self.field_9 = {} + (_ktype1, _vtype2, _size0) = iprot.readMapBegin() + for _i4 in range(_size0): + _key5 = iprot.readI32() + _val6 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.field_9[_key5] = _val6 + iprot.readMapEnd() + else: + iprot.skip(ftype) + elif fid == 10: + if ftype == TType.LIST: + self.field_10 = [] + (_etype10, _size7) = iprot.readListBegin() + for _i11 in range(_size7): + _elem12 = iprot.readI32() + self.field_10.append(_elem12) + iprot.readListEnd() + else: + iprot.skip(ftype) + elif fid == 11: + if ftype == TType.SET: + self.field_11 = set() + (_etype16, _size13) = iprot.readSetBegin() + for _i17 in range(_size13): + _elem18 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.field_11.add(_elem18) + iprot.readSetEnd() + else: + iprot.skip(ftype) + elif fid == 12: + if ftype == TType.BOOL: + self.field_12 = iprot.readBool() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('TheWorks') + if self.field_1 is not None: + oprot.writeFieldBegin('field_1', TType.BOOL, 1) + oprot.writeBool(self.field_1) + oprot.writeFieldEnd() + if self.field_2 is not None: + oprot.writeFieldBegin('field_2', TType.BYTE, 2) + oprot.writeByte(self.field_2) + oprot.writeFieldEnd() + if self.field_3 is not None: + oprot.writeFieldBegin('field_3', TType.I16, 3) + oprot.writeI16(self.field_3) + oprot.writeFieldEnd() + if self.field_4 is not None: + oprot.writeFieldBegin('field_4', TType.I32, 4) + oprot.writeI32(self.field_4) + oprot.writeFieldEnd() + if self.field_5 is not None: + oprot.writeFieldBegin('field_5', TType.I64, 5) + oprot.writeI64(self.field_5) + oprot.writeFieldEnd() + if self.field_6 is not None: + oprot.writeFieldBegin('field_6', TType.DOUBLE, 6) + oprot.writeDouble(self.field_6) + oprot.writeFieldEnd() + if self.field_7 is not None: + oprot.writeFieldBegin('field_7', TType.STRING, 7) + oprot.writeString(self.field_7.encode('utf-8') if sys.version_info[0] == 2 else self.field_7) + oprot.writeFieldEnd() + if self.field_8 is not None: + oprot.writeFieldBegin('field_8', TType.STRING, 8) + oprot.writeBinary(self.field_8) + oprot.writeFieldEnd() + if self.field_9 is not None: + oprot.writeFieldBegin('field_9', TType.MAP, 9) + oprot.writeMapBegin(TType.I32, TType.STRING, len(self.field_9)) + for kiter19, viter20 in self.field_9.items(): + oprot.writeI32(kiter19) + oprot.writeString(viter20.encode('utf-8') if sys.version_info[0] == 2 else viter20) + oprot.writeMapEnd() + oprot.writeFieldEnd() + if self.field_10 is not None: + oprot.writeFieldBegin('field_10', TType.LIST, 10) + oprot.writeListBegin(TType.I32, len(self.field_10)) + for iter21 in self.field_10: + oprot.writeI32(iter21) + oprot.writeListEnd() + oprot.writeFieldEnd() + if self.field_11 is not None: + oprot.writeFieldBegin('field_11', TType.SET, 11) + oprot.writeSetBegin(TType.STRING, len(self.field_11)) + for iter22 in self.field_11: + oprot.writeString(iter22.encode('utf-8') if sys.version_info[0] == 2 else iter22) + oprot.writeSetEnd() + oprot.writeFieldEnd() + if self.field_12 is not None: + oprot.writeFieldBegin('field_12', TType.BOOL, 12) + oprot.writeBool(self.field_12) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class Param(object): + """ + Attributes: + - return_fields + - the_works + """ + + + def __init__(self, return_fields=None, the_works=None,): + self.return_fields = return_fields + self.the_works = the_works + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.LIST: + self.return_fields = [] + (_etype26, _size23) = iprot.readListBegin() + for _i27 in range(_size23): + _elem28 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + self.return_fields.append(_elem28) + iprot.readListEnd() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.STRUCT: + self.the_works = TheWorks() + self.the_works.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('Param') + if self.return_fields is not None: + oprot.writeFieldBegin('return_fields', TType.LIST, 1) + oprot.writeListBegin(TType.STRING, len(self.return_fields)) + for iter29 in self.return_fields: + oprot.writeString(iter29.encode('utf-8') if sys.version_info[0] == 2 else iter29) + oprot.writeListEnd() + oprot.writeFieldEnd() + if self.the_works is not None: + oprot.writeFieldBegin('the_works', TType.STRUCT, 2) + self.the_works.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class Result(object): + """ + Attributes: + - the_works + """ + + + def __init__(self, the_works=None,): + self.the_works = the_works + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRUCT: + self.the_works = TheWorks() + self.the_works.read(iprot) + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('Result') + if self.the_works is not None: + oprot.writeFieldBegin('the_works', TType.STRUCT, 1) + self.the_works.write(oprot) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) + + +class AppException(TException): + """ + Attributes: + - why + """ + + + def __init__(self, why=None,): + self.why = why + + def read(self, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: + iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec]) + return + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.why = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('AppException') + if self.why is not None: + oprot.writeFieldBegin('why', TType.STRING, 1) + oprot.writeString(self.why.encode('utf-8') if sys.version_info[0] == 2 else self.why) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + def validate(self): + return + + def __str__(self): + return repr(self) + + def __repr__(self): + L = ['%s=%r' % (key, value) + for key, value in self.__dict__.items()] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not (self == other) +all_structs.append(TheWorks) +TheWorks.thrift_spec = ( + None, # 0 + (1, TType.BOOL, 'field_1', None, None, ), # 1 + (2, TType.BYTE, 'field_2', None, None, ), # 2 + (3, TType.I16, 'field_3', None, None, ), # 3 + (4, TType.I32, 'field_4', None, None, ), # 4 + (5, TType.I64, 'field_5', None, None, ), # 5 + (6, TType.DOUBLE, 'field_6', None, None, ), # 6 + (7, TType.STRING, 'field_7', 'UTF8', None, ), # 7 + (8, TType.STRING, 'field_8', 'BINARY', None, ), # 8 + (9, TType.MAP, 'field_9', (TType.I32, None, TType.STRING, 'UTF8', False), None, ), # 9 + (10, TType.LIST, 'field_10', (TType.I32, None, False), None, ), # 10 + (11, TType.SET, 'field_11', (TType.STRING, 'UTF8', False), None, ), # 11 + (12, TType.BOOL, 'field_12', None, None, ), # 12 +) +all_structs.append(Param) +Param.thrift_spec = ( + None, # 0 + (1, TType.LIST, 'return_fields', (TType.STRING, 'UTF8', False), None, ), # 1 + (2, TType.STRUCT, 'the_works', [TheWorks, None], None, ), # 2 +) +all_structs.append(Result) +Result.thrift_spec = ( + None, # 0 + (1, TType.STRUCT, 'the_works', [TheWorks, None], None, ), # 1 +) +all_structs.append(AppException) +AppException.thrift_spec = ( + None, # 0 + (1, TType.STRING, 'why', 'UTF8', None, ), # 1 +) +fix_spec(all_structs) +del all_structs diff --git a/test/extensions/filters/network/thrift_proxy/driver/server.py b/test/extensions/filters/network/thrift_proxy/driver/server.py new file mode 100755 index 0000000000000..094a8d2338bfd --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/driver/server.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python + +import argparse +import logging +import sys + +from generated.example import Example +from generated.example.ttypes import ( + Result, TheWorks, AppException +) + +from thrift import Thrift, TMultiplexedProcessor +from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol +from thrift.server import TServer +from thrift.transport import TSocket +from thrift.transport import TTransport +from fbthrift import THeaderTransport +from finagle import TFinagleServerProcessor, TFinagleServerProtocol + + +class SuccessHandler: + def ping(self): + print("server: ping()") + + def poke(self): + print("server: poke()") + + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result + + def execute(self, param): + print("server: execute({0})".format(param)) + if "all" in param.return_fields: + return Result(param.the_works) + elif "none" in param.return_fields: + return Result(TheWorks()) + the_works = TheWorks() + for field, value in vars(param.the_works).items(): + if field in param.return_fields: + setattr(the_works, field, value) + return Result(the_works) + + +class IDLExceptionHandler: + def ping(self): + print("server: ping()") + + def poke(self): + print("server: poke()") + + def add(self, a, b): + result = a + b + print("server: add({0}, {1}) = {2}".format(a, b, result)) + return result + + def execute(self, param): + print("server: app error: execute failed") + raise AppException("execute failed") + + +class ExceptionHandler: + def ping(self): + print("server: ping failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for ping", + ) + + def poke(self): + print("server: poke failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for poke", + ) + + def add(self, a, b): + print("server: add failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for add", + ) + + def execute(self, param): + print("server: execute failure") + raise Thrift.TApplicationException( + type=Thrift.TApplicationException.INTERNAL_ERROR, + message="for execute", + ) + + +def main(cfg): + if cfg.unix: + if cfg.addr == "": + sys.exit("invalid listener unix domain socket: {}".format(cfg.addr)) + else: + try: + (host, port) = cfg.addr.rsplit(":", 1) + port = int(port) + except ValueError: + sys.exit("invalid listener address: {}".format(cfg.addr)) + + if cfg.response == "success": + handler = SuccessHandler() + elif cfg.response == "idl-exception": + handler = IDLExceptionHandler() + elif cfg.response == "exception": + # squelch traceback for the exception we throw + logging.getLogger().setLevel(logging.CRITICAL) + handler = ExceptionHandler() + else: + sys.exit("unknown server response mode {0}".format(cfg.response)) + + processor = Example.Processor(handler) + if cfg.service is not None: + # wrap processor with multiplexor + multi = TMultiplexedProcessor.TMultiplexedProcessor() + multi.registerProcessor(cfg.service, processor) + processor = multi + + if cfg.protocol == "finagle": + # wrap processor with finagle request/response header handler + processor = TFinagleServerProcessor.TFinagleServerProcessor(processor) + + if cfg.unix: + transport = TSocket.TServerSocket(unix_socket=cfg.addr) + else: + transport = TSocket.TServerSocket(host=host, port=port) + + if cfg.transport == "framed": + transport_factory = TTransport.TFramedTransportFactory() + elif cfg.transport == "unframed": + transport_factory = TTransport.TBufferedTransportFactory() + elif cfg.transport == "header": + transport_factory = THeaderTransport.THeaderTransportFactory() + else: + sys.exit("unknown transport {0}".format(cfg.transport)) + + if cfg.protocol == "binary": + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + elif cfg.protocol == "compact": + protocol_factory = TCompactProtocol.TCompactProtocolFactory() + elif cfg.protocol == "json": + protocol_factory = TJSONProtocol.TJSONProtocolFactory() + elif cfg.protocol == "finagle": + protocol_factory = TFinagleServerProtocol.TFinagleServerProtocolFactory() + else: + sys.exit("unknown protocol {0}".format(cfg.protocol)) + + print("Thrift Server listening on {0} for {1} {2} requests".format( + cfg.addr, cfg.transport, cfg.protocol)) + if cfg.service is not None: + print("Thrift Server service name {0}".format(cfg.service)) + if cfg.response == "idl-exception": + print("Thrift Server will throw IDL exceptions when defined") + elif cfg.response == "exception": + print("Thrift Server will throw Thrift exceptions for all messages") + + server = TServer.TSimpleServer(processor, transport, transport_factory, protocol_factory) + try: + server.serve() + except KeyboardInterrupt: + print + + +if __name__ == "__main__": + logging.basicConfig() + parser = argparse.ArgumentParser(description="Thrift server to match client.py.") + parser.add_argument( + "-a", + "--addr", + metavar="ADDR", + dest="addr", + default=":0", + help="Listener address for server in the form host:port. The host is optional. If --unix" + + " is set, the address is the socket name.", + ) + parser.add_argument( + "-m", + "--multiplex", + metavar="SERVICE", + dest="service", + help="Enable service multiplexing and set the service name.", + ) + parser.add_argument( + "-p", + "--protocol", + help="Selects a protocol.", + dest="protocol", + default="binary", + choices=["binary", "compact", "json", "finagle"], + ) + parser.add_argument( + "-r", + "--response", + dest="response", + default="success", + choices=["success", "idl-exception", "exception"], + help="Controls how the server responds to requests", + ) + parser.add_argument( + "-t", + "--transport", + help="Selects a transport.", + dest="transport", + default="framed", + choices=["framed", "unframed", "header"], + ) + parser.add_argument( + "-u", + "--unix", + dest="unix", + action="store_true", + ) + cfg = parser.parse_args() + + try: + main(cfg) + except Thrift.TException as tx: + sys.exit("Thrift exception: {0}".format(tx.message)) diff --git a/test/extensions/filters/network/thrift_proxy/filter_test.cc b/test/extensions/filters/network/thrift_proxy/filter_test.cc deleted file mode 100644 index d192e47372098..0000000000000 --- a/test/extensions/filters/network/thrift_proxy/filter_test.cc +++ /dev/null @@ -1,559 +0,0 @@ -#include "common/buffer/buffer_impl.h" -#include "common/stats/stats_impl.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/mocks/network/mocks.h" -#include "test/test_common/printers.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::NiceMock; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -class ThriftFilterTest : public testing::Test { -public: - ThriftFilterTest() {} - - void initializeFilter() { - for (auto counter : store_.counters()) { - counter->reset(); - } - - filter_.reset(new Filter("test.", store_)); - filter_->initializeReadFilterCallbacks(read_filter_callbacks_); - filter_->onNewConnection(); - - // NOP currently. - filter_->onAboveWriteBufferHighWatermark(); - filter_->onBelowWriteBufferLowWatermark(); - } - - void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0b, 0x00, 0x00, // begin string field - 0x00, 0x00, 0x00, 0x05, 'f', 'i', 'e', 'l', 'd', // string - 0x00, // stop field - }); - } - - void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, - int32_t seq_id, bool start) { - if (start) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x2d, // framed: 45 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0c, 0x00, 0x00, // begin struct field - 0x0b, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05 // string length only - }); - } else { - addSeq(buffer, { - 'f', 'i', 'e', 'l', 'd', // string data - 0x0b, 0x00, 0x02, // begin string field - 0x00, 0x00, 0x00, 0x05, 'x', 'x', 'x', 'x', 'x', // string - 0x00, // stop field - 0x00, // stop field - }); - } - } - - void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x24, // framed: 36 bytes - 0x80, 0x01, 0x00, 0x03, // binary, exception - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05, 'e', 'r', 'r', 'o', 'r', // string - 0x08, 0x00, 0x02, // begin i32 field - 0x00, 0x00, 0x00, 0x01, // exception type 1 - 0x00, // stop field - }); - } - - void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x23, // framed: 35 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0C, 0x00, 0x02, // begin exception struct - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x03, 'e', 'r', 'r', // string - 0x00, // exception struct stop - 0x00, // reply struct stop field - }); - } - - Buffer::OwnedImpl buffer_; - Buffer::OwnedImpl write_buffer_; - Stats::IsolatedStoreImpl store_; - std::unique_ptr filter_; - NiceMock read_filter_callbacks_; -}; - -TEST_F(ThriftFilterTest, OnDataHandlesThriftCall) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesThriftOneWay) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - std::string expected_contents = bufferToString(buffer_); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(buffer_); - EXPECT_EQ(len - 4, buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - buffer_.drain(buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(buffer_); - len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(buffer_); - EXPECT_EQ(len + 4, buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolError) { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - buffer_.drain(buffer_.length()); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(0U, store_.counter("test.request").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolErrorOnWrite) { - initializeFilter(); - - // Start the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - - // Disable sniffing - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - len += buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnDataStopsSniffingWithTooManyPendingCalls) { - initializeFilter(); - for (int i = 0; i < 64; i++) { - writeFramedBinaryMessage(buffer_, MessageType::Call, i); - } - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - buffer_.drain(buffer_.length()); - - // Sniffing is now disabled. - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 100); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesThriftReply) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesOutOrOrderThriftReply) { - initializeFilter(); - - // set up two requests - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - writeFramedBinaryMessage(buffer_, MessageType::Call, 2); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(2U, store_.counter("test.request").value()); - EXPECT_EQ(2U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 2); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(2U, store_.counter("test.response").value()); - EXPECT_EQ(2U, store_.counter("test.response_reply").value()); - EXPECT_EQ(2U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, true); - std::string expected_contents = bufferToString(write_buffer_); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(write_buffer_); - EXPECT_EQ(len - 4, write_buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - write_buffer_.drain(write_buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(write_buffer_); - len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(write_buffer_); - EXPECT_EQ(len + 4, write_buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesTApplicationException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesIDLException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryIDLException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(1U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); // call is invalid for response - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolError) { - initializeFilter(); - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolErrorOnData) { - initializeFilter(); - - // Set up a request for the partial write - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - buffer_.drain(buffer_.length()); - - // Start the write buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, true); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - - // Force an error on the next request. - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x02, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, false); - len += write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnEvent) { - // No active calls - { - initializeFilter(); - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - } - - // Close mid-request - { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x01, // binary proto, call type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close before response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close mid-response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - write_buffer_.drain(write_buffer_.length()); - } -} - -TEST_F(ThriftFilterTest, ResponseWithUnknownSequenceID) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x10); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc new file mode 100644 index 0000000000000..2999de7edbf53 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc @@ -0,0 +1,115 @@ +#include "envoy/common/exception.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +TEST(FramedTransportTest, Name) { + FramedTransportImpl transport; + EXPECT_EQ(transport.name(), "framed"); +} + +TEST(FramedTransportTest, Type) { + FramedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Framed); +} + +TEST(FramedTransportTest, NotEnoughData) { + Buffer::OwnedImpl buffer; + FramedTransportImpl transport; + absl::optional size = 1; + + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); + + addRepeated(buffer, 3, 0); + + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); +} + +TEST(FramedTransportTest, InvalidFrameSize) { + FramedTransportImpl transport; + + { + Buffer::OwnedImpl buffer; + addInt32(buffer, -1); + + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + "invalid thrift framed transport frame size -1"); + EXPECT_EQ(absl::optional(1), size); + } + + { + Buffer::OwnedImpl buffer; + addInt32(buffer, 0x7fffffff); + + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + "invalid thrift framed transport frame size 2147483647"); + EXPECT_EQ(absl::optional(1), size); + } +} + +TEST(FramedTransportTest, DecodeFrameStart) { + FramedTransportImpl transport; + + Buffer::OwnedImpl buffer; + addInt32(buffer, 100); + EXPECT_EQ(buffer.length(), 4); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100U), size); + EXPECT_EQ(buffer.length(), 0); +} + +TEST(FramedTransportTest, DecodeFrameEnd) { + FramedTransportImpl transport; + + Buffer::OwnedImpl buffer; + + EXPECT_TRUE(transport.decodeFrameEnd(buffer)); +} + +TEST(FramedTransportTest, EncodeFrame) { + FramedTransportImpl transport; + + { + Buffer::OwnedImpl message; + message.add("fake message"); + + Buffer::OwnedImpl buffer; + transport.encodeFrame(buffer, message); + + EXPECT_EQ(0, message.length()); + EXPECT_EQ(std::string("\0\0\0\xC" + "fake message", + 16), + buffer.toString()); + } + + { + Buffer::OwnedImpl message; + Buffer::OwnedImpl buffer; + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, message), EnvoyException, + "invalid thrift framed transport frame size 0"); + } +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/integration_test.cc b/test/extensions/filters/network/thrift_proxy/integration_test.cc new file mode 100644 index 0000000000000..b15ea8fb4e820 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/integration_test.cc @@ -0,0 +1,278 @@ +#include + +#include + +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "test/integration/integration.h" +#include "test/test_common/environment.h" +#include "test/test_common/network_utility.h" + +#include "gtest/gtest.h" + +using testing::Combine; +using testing::TestParamInfo; +using testing::TestWithParam; +using testing::Values; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +std::string thrift_config; + +enum class CallResult { + Success, + IDLException, + Exception, +}; + +class ThriftConnManagerIntegrationTest + : public BaseIntegrationTest, + public TestWithParam> { +public: + ThriftConnManagerIntegrationTest() + : BaseIntegrationTest(Network::Address::IpVersion::v4, thrift_config) {} + + static void SetUpTestCase() { + thrift_config = ConfigHelper::BASE_CONFIG + R"EOF( + filter_chains: + filters: + - name: envoy.filters.network.thrift_proxy + config: + stat_prefix: thrift_stats + route_config: + name: "routes" + routes: + - match: + method_name: "execute" + route: + cluster: "cluster_0" + - match: + method_name: "poke" + route: + cluster: "cluster_0" + )EOF"; + } + + void initializeCall(CallResult result) { + std::tie(transport_, protocol_, multiplexed_) = GetParam(); + + std::string result_mode; + switch (result) { + case CallResult::Success: + result_mode = "success"; + break; + case CallResult::IDLException: + result_mode = "idl-exception"; + break; + case CallResult::Exception: + result_mode = "exception"; + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + + preparePayloads(result_mode, "execute"); + ASSERT(request_bytes_.length() > 0); + ASSERT(response_bytes_.length() > 0); + + BaseIntegrationTest::initialize(); + } + + void initializeOneway() { + std::tie(transport_, protocol_, multiplexed_) = GetParam(); + + preparePayloads("success", "poke"); + ASSERT(request_bytes_.length() > 0); + ASSERT(response_bytes_.length() == 0); + + BaseIntegrationTest::initialize(); + } + + void preparePayloads(std::string result_mode, std::string method) { + std::vector args = { + TestEnvironment::runfilesPath( + "test/extensions/filters/network/thrift_proxy/driver/generate_fixture.sh"), + result_mode, + transport_, + protocol_, + }; + + if (multiplexed_) { + args.push_back("svcname"); + } + args.push_back("--"); + args.push_back(method); + + TestEnvironment::exec(args); + + std::stringstream file_base; + file_base << "{{ test_tmpdir }}/" << transport_ << "-" << protocol_ << "-"; + if (multiplexed_) { + file_base << "svcname-"; + } + file_base << result_mode; + + readAll(file_base.str() + ".request", request_bytes_); + readAll(file_base.str() + ".response", response_bytes_); + } + + void TearDown() override { + test_server_.reset(); + fake_upstreams_.clear(); + } + +protected: + void readAll(std::string file, Buffer::OwnedImpl& buffer) { + file = TestEnvironment::substitute(file, version_); + + std::ifstream is(file, std::ios::binary | std::ios::ate); + RELEASE_ASSERT(!is.fail(), ""); + + std::ifstream::pos_type len = is.tellg(); + if (len > 0) { + std::vector bytes(len, 0); + is.seekg(0, std::ios::beg); + RELEASE_ASSERT(!is.fail(), ""); + + is.read(bytes.data(), len); + RELEASE_ASSERT(!is.fail(), ""); + + buffer.add(bytes.data(), len); + } + } + + std::string transport_; + std::string protocol_; + bool multiplexed_; + + std::string result_; + + Buffer::OwnedImpl request_bytes_; + Buffer::OwnedImpl response_bytes_; +}; + +static std::string +paramToString(const TestParamInfo>& params) { + std::string transport, protocol; + bool multiplexed; + std::tie(transport, protocol, multiplexed) = params.param; + transport = StringUtil::toUpper(absl::string_view(transport).substr(0, 1)) + transport.substr(1); + protocol = StringUtil::toUpper(absl::string_view(protocol).substr(0, 1)) + protocol.substr(1); + if (multiplexed) { + return fmt::format("{}{}Multiplexed", transport, protocol); + } + return fmt::format("{}{}", transport, protocol); +} + +INSTANTIATE_TEST_CASE_P( + TransportAndProtocol, ThriftConnManagerIntegrationTest, + Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED), + Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), Values(false, true)), + paramToString); + +TEST_P(ThriftConnManagerIntegrationTest, Success) { + initializeCall(CallResult::Success); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(request_bytes_.toString()); + + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(request_bytes_.length(), &data)); + Buffer::OwnedImpl upstream_request(data); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); + + ASSERT_TRUE(fake_upstream_connection->write(response_bytes_.toString())); + + tcp_client->waitForData(response_bytes_.toString()); + tcp_client->close(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_success"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftConnManagerIntegrationTest, IDLException) { + initializeCall(CallResult::IDLException); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(request_bytes_.toString()); + + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(request_bytes_.length(), &data)); + Buffer::OwnedImpl upstream_request(data); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); + + ASSERT_TRUE(fake_upstream_connection->write(response_bytes_.toString())); + + tcp_client->waitForData(response_bytes_.toString()); + tcp_client->close(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_error"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftConnManagerIntegrationTest, Exception) { + initializeCall(CallResult::Exception); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(request_bytes_.toString()); + + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(request_bytes_.length(), &data)); + Buffer::OwnedImpl upstream_request(data); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); + + ASSERT_TRUE(fake_upstream_connection->write(response_bytes_.toString())); + + tcp_client->waitForData(response_bytes_.toString()); + tcp_client->close(); + + EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_call"); + EXPECT_EQ(1U, counter->value()); + counter = test_server_->counter("thrift.thrift_stats.response_exception"); + EXPECT_EQ(1U, counter->value()); +} + +TEST_P(ThriftConnManagerIntegrationTest, Oneway) { + initializeOneway(); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(request_bytes_.toString()); + + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(request_bytes_.length(), &data)); + Buffer::OwnedImpl upstream_request(data); + EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); + + tcp_client->close(); + + Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_oneway"); + EXPECT_EQ(1U, counter->value()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index b44d9b95dab57..caa93654233e8 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -2,25 +2,77 @@ #include "gtest/gtest.h" +using testing::Return; using testing::ReturnRef; +using testing::_; namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -MockTransportCallbacks::MockTransportCallbacks() {} -MockTransportCallbacks::~MockTransportCallbacks() {} +MockConfig::MockConfig() {} +MockConfig::~MockConfig() {} -MockTransport::MockTransport() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockTransport::MockTransport() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockTransport::~MockTransport() {} -MockProtocolCallbacks::MockProtocolCallbacks() {} -MockProtocolCallbacks::~MockProtocolCallbacks() {} - -MockProtocol::MockProtocol() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockProtocol::MockProtocol() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockProtocol::~MockProtocol() {} +MockDecoderCallbacks::MockDecoderCallbacks() {} +MockDecoderCallbacks::~MockDecoderCallbacks() {} + +namespace ThriftFilters { + +MockDecoderFilter::MockDecoderFilter() { + ON_CALL(*this, transportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, transportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, boolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, byteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, doubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, stringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setEnd()).WillByDefault(Return(FilterStatus::Continue)); +} +MockDecoderFilter::~MockDecoderFilter() {} + +MockDecoderFilterCallbacks::MockDecoderFilterCallbacks() { + ON_CALL(*this, streamId()).WillByDefault(Return(stream_id_)); + ON_CALL(*this, connection()).WillByDefault(Return(&connection_)); +} +MockDecoderFilterCallbacks::~MockDecoderFilterCallbacks() {} + +} // namespace ThriftFilters + +namespace Router { + +MockRouteEntry::MockRouteEntry() {} +MockRouteEntry::~MockRouteEntry() {} + +MockRoute::MockRoute() {} +MockRoute::~MockRoute() {} + +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 6d8e6b98685c9..f932bc808d418 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,25 +1,33 @@ #pragma once +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" #include "extensions/filters/network/thrift_proxy/transport.h" +#include "test/mocks/network/mocks.h" #include "test/test_common/printers.h" #include "gmock/gmock.h" +using testing::NiceMock; + namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -class MockTransportCallbacks : public TransportCallbacks { +class MockConfig : public Config { public: - MockTransportCallbacks(); - ~MockTransportCallbacks(); + MockConfig(); + ~MockConfig(); - // ThriftProxy::TransportCallbacks - MOCK_METHOD1(transportFrameStart, void(absl::optional size)); - MOCK_METHOD0(transportFrameComplete, void()); + // ThriftProxy::Config + MOCK_METHOD0(filterFactory, ThriftFilters::FilterChainFactory&()); + MOCK_METHOD0(stats, ThriftFilterStats&()); + MOCK_METHOD1(createDecoder, DecoderPtr(DecoderCallbacks&)); + MOCK_METHOD0(routerConfig, Router::Config&()); }; class MockTransport : public Transport { @@ -29,23 +37,13 @@ class MockTransport : public Transport { // ThriftProxy::Transport MOCK_CONST_METHOD0(name, const std::string&()); - MOCK_METHOD1(decodeFrameStart, bool(Buffer::Instance&)); + MOCK_CONST_METHOD0(type, TransportType()); + MOCK_METHOD2(decodeFrameStart, bool(Buffer::Instance&, absl::optional&)); MOCK_METHOD1(decodeFrameEnd, bool(Buffer::Instance&)); + MOCK_METHOD2(encodeFrame, void(Buffer::Instance&, Buffer::Instance&)); std::string name_{"mock"}; -}; - -class MockProtocolCallbacks : public ProtocolCallbacks { -public: - MockProtocolCallbacks(); - ~MockProtocolCallbacks(); - - // ThriftProxy::ProtocolCallbacks - MOCK_METHOD3(messageStart, void(const absl::string_view, MessageType, int32_t)); - MOCK_METHOD1(structBegin, void(const absl::string_view)); - MOCK_METHOD3(structField, void(const absl::string_view, FieldType, int16_t)); - MOCK_METHOD0(structEnd, void()); - MOCK_METHOD0(messageComplete, void()); + TransportType type_{TransportType::Auto}; }; class MockProtocol : public Protocol { @@ -55,6 +53,7 @@ class MockProtocol : public Protocol { // ThriftProxy::Protocol MOCK_CONST_METHOD0(name, const std::string&()); + MOCK_CONST_METHOD0(type, ProtocolType()); MOCK_METHOD4(readMessageBegin, bool(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id)); MOCK_METHOD1(readMessageEnd, bool(Buffer::Instance& buffer)); @@ -79,9 +78,125 @@ class MockProtocol : public Protocol { MOCK_METHOD2(readString, bool(Buffer::Instance& buffer, std::string& value)); MOCK_METHOD2(readBinary, bool(Buffer::Instance& buffer, std::string& value)); + MOCK_METHOD4(writeMessageBegin, void(Buffer::Instance& buffer, const std::string& name, + MessageType msg_type, int32_t seq_id)); + MOCK_METHOD1(writeMessageEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD2(writeStructBegin, void(Buffer::Instance& buffer, const std::string& name)); + MOCK_METHOD1(writeStructEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD4(writeFieldBegin, void(Buffer::Instance& buffer, const std::string& name, + FieldType field_type, int16_t field_id)); + MOCK_METHOD1(writeFieldEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD4(writeMapBegin, void(Buffer::Instance& buffer, FieldType key_type, + FieldType value_type, uint32_t size)); + MOCK_METHOD1(writeMapEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD3(writeListBegin, void(Buffer::Instance& buffer, FieldType elem_type, uint32_t size)); + MOCK_METHOD1(writeListEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD3(writeSetBegin, void(Buffer::Instance& buffer, FieldType elem_type, uint32_t size)); + MOCK_METHOD1(writeSetEnd, void(Buffer::Instance& buffer)); + MOCK_METHOD2(writeBool, void(Buffer::Instance& buffer, bool value)); + MOCK_METHOD2(writeByte, void(Buffer::Instance& buffer, uint8_t value)); + MOCK_METHOD2(writeInt16, void(Buffer::Instance& buffer, int16_t value)); + MOCK_METHOD2(writeInt32, void(Buffer::Instance& buffer, int32_t value)); + MOCK_METHOD2(writeInt64, void(Buffer::Instance& buffer, int64_t value)); + MOCK_METHOD2(writeDouble, void(Buffer::Instance& buffer, double value)); + MOCK_METHOD2(writeString, void(Buffer::Instance& buffer, const std::string& value)); + MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); + std::string name_{"mock"}; + ProtocolType type_{ProtocolType::Auto}; +}; + +class MockDecoderCallbacks : public DecoderCallbacks { +public: + MockDecoderCallbacks(); + ~MockDecoderCallbacks(); + + // ThriftProxy::DecoderCallbacks + MOCK_METHOD0(newDecoderFilter, ThriftFilters::DecoderFilter&()); +}; + +namespace ThriftFilters { + +class MockDecoderFilter : public DecoderFilter { +public: + MockDecoderFilter(); + ~MockDecoderFilter(); + + // ThriftProxy::ThriftFilters::DecoderFilter + MOCK_METHOD0(onDestroy, void()); + MOCK_METHOD1(setDecoderFilterCallbacks, void(DecoderFilterCallbacks& callbacks)); + MOCK_METHOD0(resetUpstreamConnection, void()); + MOCK_METHOD1(transportBegin, FilterStatus(absl::optional size)); + MOCK_METHOD0(transportEnd, FilterStatus()); + MOCK_METHOD3(messageBegin, + FilterStatus(const absl::string_view name, MessageType msg_type, int32_t seq_id)); + MOCK_METHOD0(messageEnd, FilterStatus()); + MOCK_METHOD1(structBegin, FilterStatus(const absl::string_view name)); + MOCK_METHOD0(structEnd, FilterStatus()); + MOCK_METHOD3(fieldBegin, + FilterStatus(const absl::string_view name, FieldType msg_type, int16_t field_id)); + MOCK_METHOD0(fieldEnd, FilterStatus()); + MOCK_METHOD1(boolValue, FilterStatus(bool value)); + MOCK_METHOD1(byteValue, FilterStatus(uint8_t value)); + MOCK_METHOD1(int16Value, FilterStatus(int16_t value)); + MOCK_METHOD1(int32Value, FilterStatus(int32_t value)); + MOCK_METHOD1(int64Value, FilterStatus(int64_t value)); + MOCK_METHOD1(doubleValue, FilterStatus(double value)); + MOCK_METHOD1(stringValue, FilterStatus(absl::string_view value)); + MOCK_METHOD3(mapBegin, FilterStatus(FieldType key_type, FieldType value_type, uint32_t size)); + MOCK_METHOD0(mapEnd, FilterStatus()); + MOCK_METHOD2(listBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(listEnd, FilterStatus()); + MOCK_METHOD2(setBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(setEnd, FilterStatus()); +}; + +class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { +public: + MockDecoderFilterCallbacks(); + ~MockDecoderFilterCallbacks(); + + // ThriftProxy::ThriftFilters::DecoderFilterCallbacks + MOCK_CONST_METHOD0(streamId, uint64_t()); + MOCK_CONST_METHOD0(connection, const Network::Connection*()); + MOCK_METHOD0(continueDecoding, void()); + MOCK_METHOD0(route, Router::RouteConstSharedPtr()); + MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); + MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); + void sendLocalReply(DirectResponsePtr&& response) override { sendLocalReply_(response); } + MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); + MOCK_METHOD0(resetDownstreamConnection, void()); + + MOCK_METHOD1(sendLocalReply_, void(DirectResponsePtr&)); + + uint64_t stream_id_{1}; + NiceMock connection_; +}; + +} // namespace ThriftFilters + +namespace Router { + +class MockRouteEntry : public RouteEntry { +public: + MockRouteEntry(); + ~MockRouteEntry(); + + // ThriftProxy::Router::RouteEntry + MOCK_CONST_METHOD0(clusterName, const std::string&()); +}; + +class MockRoute : public Route { +public: + MockRoute(); + ~MockRoute(); + + // ThriftProxy::Router::Route + MOCK_CONST_METHOD0(routeEntry, const RouteEntry*()); }; +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/protocol_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc similarity index 66% rename from test/extensions/filters/network/thrift_proxy/protocol_test.cc rename to test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index 43d0417f7bb05..7a8fef74a1490 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -2,9 +2,9 @@ #include "common/buffer/buffer_impl.h" -#include "extensions/filters/network/thrift_proxy/binary_protocol.h" -#include "extensions/filters/network/thrift_proxy/compact_protocol.h" -#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" @@ -24,10 +24,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +TEST(ProtocolNames, FromType) { + for (int i = 0; i <= static_cast(ProtocolType::LastProtocolType); i++) { + ProtocolType type = static_cast(i); + EXPECT_NE("", ProtocolNames::get().fromType(type)); + } +} + TEST(AutoProtocolTest, NotEnoughData) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -41,8 +47,7 @@ TEST(AutoProtocolTest, NotEnoughData) { TEST(AutoProtocolTest, UnknownProtocol) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -59,8 +64,7 @@ TEST(AutoProtocolTest, UnknownProtocol) { TEST(AutoProtocolTest, ReadMessageBegin) { // Binary Protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -79,12 +83,12 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 1); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "binary(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Binary); } // Compact protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = 1; @@ -101,13 +105,13 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 0x0102); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "compact(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Compact); } } -TEST(AutoProtocolTest, Delegation) { +TEST(AutoProtocolTest, ReadDelegation) { NiceMock* proto = new NiceMock(); - NiceMock dummy_cb; - AutoProtocolImpl auto_proto(dummy_cb); + AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // readMessageBegin @@ -230,12 +234,103 @@ TEST(AutoProtocolTest, Delegation) { } } +TEST(AutoProtocolTest, WriteDelegation) { + NiceMock* proto = new NiceMock(); + AutoProtocolImpl auto_proto; + auto_proto.setProtocol(ProtocolPtr{proto}); + + // writeMessageBegin + Buffer::OwnedImpl buffer; + EXPECT_CALL(*proto, writeMessageBegin(Ref(buffer), "name", MessageType::Call, 100)); + auto_proto.writeMessageBegin(buffer, "name", MessageType::Call, 100); + + // writeMessageEnd + EXPECT_CALL(*proto, writeMessageEnd(Ref(buffer))); + auto_proto.writeMessageEnd(buffer); + + // writeStructBegin + EXPECT_CALL(*proto, writeStructBegin(Ref(buffer), "name")); + auto_proto.writeStructBegin(buffer, "name"); + + // writeStructEnd + EXPECT_CALL(*proto, writeStructEnd(Ref(buffer))); + auto_proto.writeStructEnd(buffer); + + // writeFieldBegin + EXPECT_CALL(*proto, writeFieldBegin(Ref(buffer), "name", FieldType::Stop, 100)); + auto_proto.writeFieldBegin(buffer, "name", FieldType::Stop, 100); + + // writeFieldEnd + EXPECT_CALL(*proto, writeFieldEnd(Ref(buffer))); + auto_proto.writeFieldEnd(buffer); + + // writeMapBegin + EXPECT_CALL(*proto, writeMapBegin(Ref(buffer), FieldType::I32, FieldType::String, 100)); + auto_proto.writeMapBegin(buffer, FieldType::I32, FieldType::String, 100); + + // writeMapEnd + EXPECT_CALL(*proto, writeMapEnd(Ref(buffer))); + auto_proto.writeMapEnd(buffer); + + // writeListBegin + EXPECT_CALL(*proto, writeListBegin(Ref(buffer), FieldType::String, 100)); + auto_proto.writeListBegin(buffer, FieldType::String, 100); + + // writeListEnd + EXPECT_CALL(*proto, writeListEnd(Ref(buffer))); + auto_proto.writeListEnd(buffer); + + // writeSetBegin + EXPECT_CALL(*proto, writeSetBegin(Ref(buffer), FieldType::String, 100)); + auto_proto.writeSetBegin(buffer, FieldType::String, 100); + + // writeSetEnd + EXPECT_CALL(*proto, writeSetEnd(Ref(buffer))); + auto_proto.writeSetEnd(buffer); + + // writeBool + EXPECT_CALL(*proto, writeBool(Ref(buffer), true)); + auto_proto.writeBool(buffer, true); + + // writeByte + EXPECT_CALL(*proto, writeByte(Ref(buffer), 100)); + auto_proto.writeByte(buffer, 100); + + // writeInt16 + EXPECT_CALL(*proto, writeInt16(Ref(buffer), 100)); + auto_proto.writeInt16(buffer, 100); + + // writeInt32 + EXPECT_CALL(*proto, writeInt32(Ref(buffer), 100)); + auto_proto.writeInt32(buffer, 100); + + // writeInt64 + EXPECT_CALL(*proto, writeInt64(Ref(buffer), 100)); + auto_proto.writeInt64(buffer, 100); + + // writeDouble + EXPECT_CALL(*proto, writeDouble(Ref(buffer), 10.0)); + auto_proto.writeDouble(buffer, 10.0); + + // writeString + EXPECT_CALL(*proto, writeString(Ref(buffer), "string")); + auto_proto.writeString(buffer, "string"); + + // writeBinary + EXPECT_CALL(*proto, writeBinary(Ref(buffer), "binary")); + auto_proto.writeBinary(buffer, "binary"); +} + TEST(AutoProtocolTest, Name) { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; EXPECT_EQ(proto.name(), "auto"); } +TEST(AutoProtocolTest, Type) { + AutoProtocolImpl proto; + EXPECT_EQ(proto.type(), ProtocolType::Auto); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc new file mode 100644 index 0000000000000..0fe593ed2cddd --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -0,0 +1,619 @@ +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/route.pb.h" +#include "envoy/config/filter/network/thrift_proxy/v2alpha1/route.pb.validate.h" +#include "envoy/tcp/conn_pool.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/router/config.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" +#include "test/test_common/registry.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ContainsRegex; +using testing::Invoke; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +namespace { + +envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration +parseRouteConfigurationFromV2Yaml(const std::string& yaml) { + envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration route_config; + MessageUtil::loadFromYaml(yaml, route_config); + MessageUtil::validate(route_config); + return route_config; +} + +class TestNamedTransportConfigFactory : public NamedTransportConfigFactory { +public: + TestNamedTransportConfigFactory(std::function f) : f_(f) {} + + TransportPtr createTransport() override { return TransportPtr{f_()}; } + std::string name() override { return TransportNames::get().FRAMED; } + + std::function f_; +}; + +class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { +public: + TestNamedProtocolConfigFactory(std::function f) : f_(f) {} + + ProtocolPtr createProtocol() override { return ProtocolPtr{f_()}; } + std::string name() override { return ProtocolNames::get().BINARY; } + + std::function f_; +}; + +} // namespace + +class ThriftRouterTestBase { +public: + ThriftRouterTestBase() + : transport_factory_([&]() -> MockTransport* { return transport_; }), + protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} + + void initializeRouter() { + route_ = new NiceMock(); + route_ptr_.reset(route_); + + router_.reset(new Router(context_.clusterManager())); + + EXPECT_EQ(nullptr, router_->downstreamConnection()); + + router_->setDecoderFilterCallbacks(callbacks_); + } + + void startRequest(MessageType msg_type) { + msg_type_ = msg_type; + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin({})); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, msg_type_, seq_id_)); + + NiceMock connection; + EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection)); + EXPECT_EQ(&connection, router_->downstreamConnection()); + + // Not yet implemented: + EXPECT_EQ(absl::optional(), router_->computeHashKey()); + EXPECT_EQ(nullptr, router_->metadataMatchCriteria()); + EXPECT_EQ(nullptr, router_->downstreamHeaders()); + } + + void connectUpstream() { + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { + upstream_callbacks_ = &cb; + })); + + EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); + transport_ = new NiceMock(); + ON_CALL(*transport_, type()).WillByDefault(Return(TransportType::Framed)); + + EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); + protocol_ = new NiceMock(); + ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol_, writeMessageBegin(_, method_name_, msg_type_, seq_id_)); + + EXPECT_CALL(callbacks_, continueDecoding()); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + } + + void sendTrivialStruct(FieldType field_type) { + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + sendTrivialValue(field_type); + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + } + + void sendTrivialValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + EXPECT_CALL(*protocol_, writeBool(_, true)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->boolValue(true)); + break; + case FieldType::Byte: + EXPECT_CALL(*protocol_, writeByte(_, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->byteValue(2)); + break; + case FieldType::I16: + EXPECT_CALL(*protocol_, writeInt16(_, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int16Value(3)); + break; + case FieldType::I32: + EXPECT_CALL(*protocol_, writeInt32(_, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(4)); + break; + case FieldType::I64: + EXPECT_CALL(*protocol_, writeInt64(_, 5)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int64Value(5)); + break; + case FieldType::Double: + EXPECT_CALL(*protocol_, writeDouble(_, 6.0)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->doubleValue(6.0)); + break; + case FieldType::String: + EXPECT_CALL(*protocol_, writeString(_, "seven")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->stringValue("seven")); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void completeRequest() { + EXPECT_CALL(*protocol_, writeMessageEnd(_)); + EXPECT_CALL(*transport_, encodeFrame(_, _)); + EXPECT_CALL(upstream_connection_, write(_, false)); + + if (msg_type_ == MessageType::Oneway) { + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); + } + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->messageEnd()); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportEnd()); + } + + void returnResponse() { + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); + upstream_callbacks_->onUpstreamData(buffer, false); + } + + void destroyRouter() { + router_->onDestroy(); + router_.reset(); + } + + TestNamedTransportConfigFactory transport_factory_; + TestNamedProtocolConfigFactory protocol_factory_; + Registry::InjectFactory transport_register_; + Registry::InjectFactory protocol_register_; + + NiceMock context_; + NiceMock callbacks_; + NiceMock* transport_{}; + NiceMock* protocol_{}; + NiceMock* route_{}; + NiceMock route_entry_; + NiceMock* host_{}; + + RouteConstSharedPtr route_ptr_; + std::unique_ptr router_; + + std::string cluster_name_{"cluster"}; + + std::string method_name_{"method"}; + MessageType msg_type_{MessageType::Call}; + int32_t seq_id_{1}; + + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks_{}; + NiceMock upstream_connection_; +}; + +class ThriftRouterTest : public ThriftRouterTestBase, public Test { +public: + ThriftRouterTest() {} +}; + +class ThriftRouterFieldTypeTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterFieldTypeTest() {} +}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftRouterFieldTypeTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::I16, FieldType::I32, + FieldType::I64, FieldType::Double, FieldType::String), + fieldTypeParamToString); + +class ThriftRouterContainerTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterContainerTest() {} +}; + +INSTANTIATE_TEST_CASE_P(ContainerFieldTypes, ThriftRouterContainerTest, + Values(FieldType::Map, FieldType::List, FieldType::Set), + fieldTypeParamToString); + +TEST_F(ThriftRouterTest, PoolRemoteConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + context_.cluster_manager_.tcp_conn_pool_.poolFailure( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); +} + +TEST_F(ThriftRouterTest, PoolLocalConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + context_.cluster_manager_.tcp_conn_pool_.poolFailure( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure); +} + +TEST_F(ThriftRouterTest, PoolTimeout) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + context_.cluster_manager_.tcp_conn_pool_.poolFailure( + Tcp::ConnectionPool::PoolFailureReason::Timeout); +} + +TEST_F(ThriftRouterTest, PoolOverflowFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*too many connections.*")); + })); + context_.cluster_manager_.tcp_conn_pool_.poolFailure( + Tcp::ConnectionPool::PoolFailureReason::Overflow); +} + +TEST_F(ThriftRouterTest, NoRoute) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + if (app_ex != nullptr) { + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::UnknownMethod, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no route.*")); + } + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoCluster) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, get(cluster_name_)).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*unknown cluster.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, ClusterMaintenanceMode) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.cluster_.info_, maintenanceMode()) + .WillOnce(Return(true)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*maintenance mode.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoHealthyHosts) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _)) + .WillOnce(Return(nullptr)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no healthy upstream.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, TruncatedResponse) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); + EXPECT_CALL(callbacks_, resetDownstreamConnection()); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) + .WillOnce(Invoke([&](Buffer::Instance&) -> bool { + router_->resetUpstreamConnection(); + return true; + })); + EXPECT_CALL(upstream_connection_, close(Network::ConnectionCloseType::NoFlush)); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroyBeforeUpstreamConnect) { + initializeRouter(); + startRequest(MessageType::Call); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + EXPECT_CALL(upstream_connection_, close(Network::ConnectionCloseType::NoFlush)); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, OneWay) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Oneway); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, Call) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) { + FieldType field_type = GetParam(); + + initializeRouter(); + + startRequest(MessageType::Oneway); + connectUpstream(); + + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + switch (field_type) { + case FieldType::Map: + EXPECT_CALL(*protocol_, writeMapBegin(_, FieldType::I32, FieldType::I32, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, + router_->mapBegin(FieldType::I32, FieldType::I32, 2)); + for (int i = 0; i < 2; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + EXPECT_CALL(*protocol_, writeInt32(_, i + 100)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i + 100)); + } + EXPECT_CALL(*protocol_, writeMapEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->mapEnd()); + break; + case FieldType::List: + EXPECT_CALL(*protocol_, writeListBegin(_, FieldType::I32, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listBegin(FieldType::I32, 3)); + for (int i = 0; i < 3; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeListEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listEnd()); + break; + case FieldType::Set: + EXPECT_CALL(*protocol_, writeSetBegin(_, FieldType::I32, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setBegin(FieldType::I32, 4)); + for (int i = 0; i < 4; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeSetEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setEnd()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, _, FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + + completeRequest(); + destroyRouter(); +} + +TEST(RouteMatcherTest, Route) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: + method: "method2" + route: + cluster: "cluster2" +)EOF"; + + envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + EXPECT_EQ(nullptr, matcher.route("unknown")); + EXPECT_EQ(nullptr, matcher.route("METHOD1")); + + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("method2"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +TEST(RouteMatcherTest, RouteMatchAny) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: {} + route: + cluster: "cluster2" +)EOF"; + + envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("anything"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc new file mode 100644 index 0000000000000..64ab3cce319be --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc @@ -0,0 +1,176 @@ +#include "envoy/common/exception.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::NiceMock; +using testing::Ref; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +TEST(TransportNames, FromType) { + for (int i = 0; i <= static_cast(TransportType::LastTransportType); i++) { + TransportType type = static_cast(i); + EXPECT_NE("", TransportNames::get().fromType(type)); + } +} + +TEST(AutoTransportTest, NotEnoughData) { + Buffer::OwnedImpl buffer; + AutoTransportImpl transport; + absl::optional size = 100; + + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); + + addRepeated(buffer, 7, 0); + + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); +} + +TEST(AutoTransportTest, UnknownTransport) { + AutoTransportImpl transport; + + // Looks like unframed, but fails protocol check. + { + Buffer::OwnedImpl buffer; + addInt32(buffer, 0); + addInt32(buffer, 0); + + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + "unknown thrift auto transport frame start 00 00 00 00 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); + } + + // Looks like framed, but fails protocol check. + { + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFF); + addInt32(buffer, 0); + + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + "unknown thrift auto transport frame start 00 00 00 ff 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); + } +} + +TEST(AutoTransportTest, DecodeFrameStart) { + // Framed transport + binary protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFF); + addInt16(buffer, 0x8001); + addInt16(buffer, 0); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(255), size); + EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); + EXPECT_EQ(buffer.length(), 4); + } + + // Framed transport + compact protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFFF); + addInt16(buffer, 0x8201); + addInt16(buffer, 0); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(4095), size); + EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); + EXPECT_EQ(buffer.length(), 4); + } + + // Unframed transport + binary protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt16(buffer, 0x8001); + addRepeated(buffer, 6, 0); + + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); + EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); + EXPECT_EQ(buffer.length(), 8); + } + + // Unframed transport + compact protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt16(buffer, 0x8201); + addRepeated(buffer, 6, 0); + + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); + EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); + EXPECT_EQ(buffer.length(), 8); + } +} + +TEST(AutoTransportTest, DecodeFrameEnd) { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFF); + addInt16(buffer, 0x8001); + addInt16(buffer, 0); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(buffer.length(), 4); + + EXPECT_TRUE(transport.decodeFrameEnd(buffer)); +} + +TEST(AutoTransportTest, EncodeFrame) { + MockTransport* mock_transport = new NiceMock(); + + AutoTransportImpl transport; + transport.setTransport(TransportPtr{mock_transport}); + + Buffer::OwnedImpl buffer; + Buffer::OwnedImpl message; + + EXPECT_CALL(*mock_transport, encodeFrame(Ref(buffer), Ref(message))); + transport.encodeFrame(buffer, message); +} + +TEST(AutoTransportTest, Name) { + AutoTransportImpl transport; + EXPECT_EQ(transport.name(), "auto"); +} + +TEST(AutoTransportTest, Type) { + AutoTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Auto); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/transport_test.cc b/test/extensions/filters/network/thrift_proxy/transport_test.cc deleted file mode 100644 index 185ac52af66ca..0000000000000 --- a/test/extensions/filters/network/thrift_proxy/transport_test.cc +++ /dev/null @@ -1,237 +0,0 @@ -#include "envoy/common/exception.h" - -#include "common/buffer/buffer_impl.h" - -#include "extensions/filters/network/thrift_proxy/transport.h" - -#include "test/extensions/filters/network/thrift_proxy/mocks.h" -#include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/test_common/printers.h" -#include "test/test_common/utility.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::NiceMock; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -TEST(FramedTransportTest, Name) { - NiceMock cb; - FramedTransportImpl transport(cb); - EXPECT_EQ(transport.name(), "framed"); -} - -TEST(FramedTransportTest, NotEnoughData) { - Buffer::OwnedImpl buffer; - NiceMock cb; - FramedTransportImpl transport(cb); - - EXPECT_FALSE(transport.decodeFrameStart(buffer)); - - addRepeated(buffer, 3, 0); - - EXPECT_FALSE(transport.decodeFrameStart(buffer)); -} - -TEST(FramedTransportTest, InvalidFrameSize) { - NiceMock cb; - FramedTransportImpl transport(cb); - - { - Buffer::OwnedImpl buffer; - addInt32(buffer, -1); - - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, - "invalid thrift framed transport frame size -1"); - } - - { - Buffer::OwnedImpl buffer; - addInt32(buffer, 0x7fffffff); - - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, - "invalid thrift framed transport frame size 2147483647"); - } -} - -TEST(FramedTransportTest, DecodeFrameStart) { - MockTransportCallbacks cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional(100U))); - - FramedTransportImpl transport(cb); - - Buffer::OwnedImpl buffer; - addInt32(buffer, 100); - - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(buffer.length(), 0); -} - -TEST(FramedTransportTest, DecodeFrameEnd) { - MockTransportCallbacks cb; - EXPECT_CALL(cb, transportFrameComplete()); - - FramedTransportImpl transport(cb); - - Buffer::OwnedImpl buffer; - - EXPECT_TRUE(transport.decodeFrameEnd(buffer)); -} - -TEST(UnframedTransportTest, Name) { - NiceMock cb; - UnframedTransportImpl transport(cb); - EXPECT_EQ(transport.name(), "unframed"); -} - -TEST(UnframedTransportTest, DecodeFrameStart) { - MockTransportCallbacks cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - - UnframedTransportImpl transport(cb); - - Buffer::OwnedImpl buffer; - addInt32(buffer, 0xDEADBEEF); - - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(buffer.length(), 4); -} - -TEST(UnframedTransportTest, DecodeFrameEnd) { - MockTransportCallbacks cb; - EXPECT_CALL(cb, transportFrameComplete()); - - UnframedTransportImpl transport(cb); - - Buffer::OwnedImpl buffer; - EXPECT_TRUE(transport.decodeFrameEnd(buffer)); -} - -TEST(AutoTransportTest, NotEnoughData) { - Buffer::OwnedImpl buffer; - NiceMock cb; - AutoTransportImpl transport(cb); - - EXPECT_FALSE(transport.decodeFrameStart(buffer)); - - addRepeated(buffer, 7, 0); - - EXPECT_FALSE(transport.decodeFrameStart(buffer)); -} - -TEST(AutoTransportTest, UnknownTransport) { - NiceMock cb; - AutoTransportImpl transport(cb); - - // Looks like unframed, but fails protocol check. - { - Buffer::OwnedImpl buffer; - addInt32(buffer, 0); - addInt32(buffer, 0); - - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, - "unknown thrift auto transport frame start 00 00 00 00 00 00 00 00"); - } - - // Looks like framed, but fails protocol check. - { - Buffer::OwnedImpl buffer; - addInt32(buffer, 0xFF); - addInt32(buffer, 0); - - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, - "unknown thrift auto transport frame start 00 00 00 ff 00 00 00 00"); - } -} - -TEST(AutoTransportTest, DecodeFrameStart) { - NiceMock cb; - - // Framed transport + binary protocol - { - AutoTransportImpl transport(cb); - Buffer::OwnedImpl buffer; - addInt32(buffer, 0xFF); - addInt16(buffer, 0x8001); - addInt16(buffer, 0); - - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(transport.name(), "framed(auto)"); - EXPECT_EQ(buffer.length(), 4); - } - - // Framed transport + compact protocol - { - AutoTransportImpl transport(cb); - Buffer::OwnedImpl buffer; - addInt32(buffer, 0xFFF); - addInt16(buffer, 0x8201); - addInt16(buffer, 0); - - EXPECT_CALL(cb, transportFrameStart(absl::optional(4095U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(transport.name(), "framed(auto)"); - EXPECT_EQ(buffer.length(), 4); - } - - // Unframed transport + binary protocol - { - AutoTransportImpl transport(cb); - Buffer::OwnedImpl buffer; - addInt16(buffer, 0x8001); - addRepeated(buffer, 6, 0); - - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(transport.name(), "unframed(auto)"); - EXPECT_EQ(buffer.length(), 8); - } - - // Unframed transport + compact protocol - { - AutoTransportImpl transport(cb); - Buffer::OwnedImpl buffer; - addInt16(buffer, 0x8201); - addRepeated(buffer, 6, 0); - - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(transport.name(), "unframed(auto)"); - EXPECT_EQ(buffer.length(), 8); - } -} - -TEST(AutoTransportTest, DecodeFrameEnd) { - NiceMock cb; - - AutoTransportImpl transport(cb); - Buffer::OwnedImpl buffer; - addInt32(buffer, 0xFF); - addInt16(buffer, 0x8001); - addInt16(buffer, 0); - - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); - EXPECT_EQ(buffer.length(), 4); - - EXPECT_CALL(cb, transportFrameComplete()); - EXPECT_TRUE(transport.decodeFrameEnd(buffer)); -} - -TEST(AutoTransportTest, Name) { - NiceMock cb; - AutoTransportImpl transport(cb); - EXPECT_EQ(transport.name(), "auto"); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc new file mode 100644 index 0000000000000..f83119ffaf383 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc @@ -0,0 +1,62 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +TEST(UnframedTransportTest, Name) { + UnframedTransportImpl transport; + EXPECT_EQ(transport.name(), "unframed"); +} + +TEST(UnframedTransportTest, Type) { + UnframedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Unframed); +} + +TEST(UnframedTransportTest, DecodeFrameStart) { + UnframedTransportImpl transport; + + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xDEADBEEF); + EXPECT_EQ(buffer.length(), 4); + + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); + EXPECT_EQ(buffer.length(), 4); +} + +TEST(UnframedTransportTest, DecodeFrameEnd) { + UnframedTransportImpl transport; + + Buffer::OwnedImpl buffer; + EXPECT_TRUE(transport.decodeFrameEnd(buffer)); +} + +TEST(UnframedTransportTest, EncodeFrame) { + UnframedTransportImpl transport; + + Buffer::OwnedImpl message; + message.add("fake message"); + + Buffer::OwnedImpl buffer; + transport.encodeFrame(buffer, message); + + EXPECT_EQ(0, message.length()); + EXPECT_EQ("fake message", buffer.toString()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/utility.h b/test/extensions/filters/network/thrift_proxy/utility.h index 5b85d4f9533ae..8058f9b337999 100644 --- a/test/extensions/filters/network/thrift_proxy/utility.h +++ b/test/extensions/filters/network/thrift_proxy/utility.h @@ -7,6 +7,10 @@ #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "gtest/gtest.h" + +using testing::TestParamInfo; + namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -52,6 +56,43 @@ inline std::string bufferToString(Buffer::Instance& buffer) { return std::string(data, buffer.length()); } +inline std::string fieldTypeToString(const FieldType& field_type) { + switch (field_type) { + case FieldType::Stop: + return "Stop"; + case FieldType::Void: + return "Void"; + case FieldType::Bool: + return "Bool"; + case FieldType::Byte: + return "Byte"; + case FieldType::Double: + return "Double"; + case FieldType::I16: + return "I16"; + case FieldType::I32: + return "I32"; + case FieldType::I64: + return "I64"; + case FieldType::String: + return "String"; + case FieldType::Struct: + return "Struct"; + case FieldType::Map: + return "Map"; + case FieldType::Set: + return "Set"; + case FieldType::List: + return "List"; + default: + return "UnknownFieldType"; + } +} + +inline std::string fieldTypeParamToString(const TestParamInfo& params) { + return fieldTypeToString(params.param); +} + } // namespace } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/test/extensions/grpc_credentials/file_based_metadata/file_based_metadata_grpc_credentials_test.cc b/test/extensions/grpc_credentials/file_based_metadata/file_based_metadata_grpc_credentials_test.cc index dec34bb1c9ab5..dee133eaf6062 100644 --- a/test/extensions/grpc_credentials/file_based_metadata/file_based_metadata_grpc_credentials_test.cc +++ b/test/extensions/grpc_credentials/file_based_metadata/file_based_metadata_grpc_credentials_test.cc @@ -19,7 +19,8 @@ namespace { class GrpcFileBasedMetadataClientIntegrationTest : public GrpcSslClientIntegrationTest { public: void expectExtraHeaders(FakeStream& fake_stream) override { - fake_stream.waitForHeadersComplete(); + AssertionResult result = fake_stream.waitForHeadersComplete(); + RELEASE_ASSERT(result, result.message()); Http::TestHeaderMapImpl stream_headers(fake_stream.headers()); if (!header_value_1_.empty()) { EXPECT_EQ(header_prefix_1_ + header_value_1_, stream_headers.get_(header_key_1_)); @@ -86,7 +87,7 @@ TEST_P(GrpcFileBasedMetadataClientIntegrationTest, FileBasedMetadataGrpcAuthRequ header_prefix_1_ = "prefix1"; header_value_1_ = "secretvalue"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().FILE_BASED_METADATA; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().FileBasedMetadata; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -101,7 +102,7 @@ TEST_P(GrpcFileBasedMetadataClientIntegrationTest, DoubleFileBasedMetadataGrpcAu header_value_1_ = "secretvalue"; header_value_2_ = "secret2"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().FILE_BASED_METADATA; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().FileBasedMetadata; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -112,7 +113,7 @@ TEST_P(GrpcFileBasedMetadataClientIntegrationTest, DoubleFileBasedMetadataGrpcAu TEST_P(GrpcFileBasedMetadataClientIntegrationTest, EmptyFileBasedMetadataGrpcAuthRequest) { SKIP_IF_GRPC_CLIENT(ClientType::EnvoyGrpc); credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().FILE_BASED_METADATA; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().FileBasedMetadata; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); @@ -127,7 +128,7 @@ TEST_P(GrpcFileBasedMetadataClientIntegrationTest, ExtraConfigFileBasedMetadataG header_prefix_1_ = "prefix1"; header_value_1_ = "secretvalue"; credentials_factory_name_ = - Extensions::GrpcCredentials::GrpcCredentialsNames::get().FILE_BASED_METADATA; + Extensions::GrpcCredentials::GrpcCredentialsNames::get().FileBasedMetadata; initialize(); auto request = createRequest(empty_metadata_); request->sendReply(); diff --git a/test/extensions/health_checkers/redis/config_test.cc b/test/extensions/health_checkers/redis/config_test.cc index 656aa61eb7856..1f70f8115b030 100644 --- a/test/extensions/health_checkers/redis/config_test.cc +++ b/test/extensions/health_checkers/redis/config_test.cc @@ -39,7 +39,7 @@ TEST(HealthCheckerFactoryTest, createRedis) { .get())); } -TEST(HealthCheckerFactoryTest, createRedisViaUpstreamHealthCheckerFactory) { +TEST(HealthCheckerFactoryTest, createRedisWithoutKey) { const std::string yaml = R"EOF( timeout: 1s interval: 1s @@ -50,21 +50,19 @@ TEST(HealthCheckerFactoryTest, createRedisViaUpstreamHealthCheckerFactory) { custom_health_check: name: envoy.health_checkers.redis config: - key: foo )EOF"; - NiceMock cluster; - Runtime::MockLoader runtime; - Runtime::MockRandomGenerator random; - Event::MockDispatcher dispatcher; - EXPECT_NE(nullptr, - dynamic_cast( - Upstream::HealthCheckerFactory::create(Upstream::parseHealthCheckFromV2Yaml(yaml), - cluster, runtime, random, dispatcher) - .get())); + NiceMock context; + + RedisHealthCheckerFactory factory; + EXPECT_NE( + nullptr, + dynamic_cast( + factory.createCustomHealthChecker(Upstream::parseHealthCheckFromV2Yaml(yaml), context) + .get())); } -TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedConfig) { +TEST(HealthCheckerFactoryTest, createRedisViaUpstreamHealthCheckerFactory) { const std::string yaml = R"EOF( timeout: 1s interval: 1s @@ -72,22 +70,22 @@ TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedConfig) { interval_jitter: 1s unhealthy_threshold: 1 healthy_threshold: 1 - # Using the deprecated redis_health_check should work. - redis_health_check: - key: foo + custom_health_check: + name: envoy.health_checkers.redis + config: + key: foo )EOF"; NiceMock cluster; Runtime::MockLoader runtime; Runtime::MockRandomGenerator random; Event::MockDispatcher dispatcher; - EXPECT_NE(nullptr, - dynamic_cast( - // Always use Upstream's HealthCheckerFactory when creating instance using - // deprecated config. - Upstream::HealthCheckerFactory::create(Upstream::parseHealthCheckFromV2Yaml(yaml), - cluster, runtime, random, dispatcher) - .get())); + AccessLog::MockAccessLogManager log_manager; + EXPECT_NE(nullptr, dynamic_cast( + Upstream::HealthCheckerFactory::create( + Upstream::parseHealthCheckFromV2Yaml(yaml), cluster, runtime, random, + dispatcher, log_manager) + .get())); } TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedV1JsonConfig) { @@ -105,13 +103,14 @@ TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedV1JsonConfig) { Runtime::MockLoader runtime; Runtime::MockRandomGenerator random; Event::MockDispatcher dispatcher; - EXPECT_NE(nullptr, - dynamic_cast( - // Always use Upstream's HealthCheckerFactory when creating instance using - // deprecated config. - Upstream::HealthCheckerFactory::create(Upstream::parseHealthCheckFromV1Json(json), - cluster, runtime, random, dispatcher) - .get())); + AccessLog::MockAccessLogManager log_manager; + EXPECT_NE(nullptr, dynamic_cast( + // Always use Upstream's HealthCheckerFactory when creating instance using + // deprecated config. + Upstream::HealthCheckerFactory::create( + Upstream::parseHealthCheckFromV1Json(json), cluster, runtime, random, + dispatcher, log_manager) + .get())); } TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedV1JsonConfigWithKey) { @@ -130,16 +129,17 @@ TEST(HealthCheckerFactoryTest, createRedisWithDeprecatedV1JsonConfigWithKey) { Runtime::MockLoader runtime; Runtime::MockRandomGenerator random; Event::MockDispatcher dispatcher; - EXPECT_NE(nullptr, - dynamic_cast( - // Always use Upstream's HealthCheckerFactory when creating instance using - // deprecated config. - Upstream::HealthCheckerFactory::create(Upstream::parseHealthCheckFromV1Json(json), - cluster, runtime, random, dispatcher) - .get())); + AccessLog::MockAccessLogManager log_manager; + EXPECT_NE(nullptr, dynamic_cast( + // Always use Upstream's HealthCheckerFactory when creating instance using + // deprecated config. + Upstream::HealthCheckerFactory::create( + Upstream::parseHealthCheckFromV1Json(json), cluster, runtime, random, + dispatcher, log_manager) + .get())); } } // namespace RedisHealthChecker } // namespace HealthCheckers } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/test/extensions/health_checkers/redis/redis_test.cc b/test/extensions/health_checkers/redis/redis_test.cc index 9ade2e2e63801..015a7c08c9b5b 100644 --- a/test/extensions/health_checkers/redis/redis_test.cc +++ b/test/extensions/health_checkers/redis/redis_test.cc @@ -26,27 +26,9 @@ class RedisHealthCheckerTest : public testing::Test, public Extensions::NetworkFilters::RedisProxy::ConnPool::ClientFactory { public: - RedisHealthCheckerTest() : cluster_(new NiceMock()) {} - - void setupExistsHealthcheckDeprecated() { - const std::string yaml = R"EOF( - timeout: 1s - interval: 1s - no_traffic_interval: 5s - interval_jitter: 1s - unhealthy_threshold: 1 - healthy_threshold: 1 - # Using the deprecated redis_health_check should work. - redis_health_check: - key: foo - )EOF"; - - const auto& hc_config = Upstream::parseHealthCheckFromV2Yaml(yaml); - const auto& redis_config = getRedisHealthCheckConfig(hc_config); - - health_checker_.reset(new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, - runtime_, random_, *this)); - } + RedisHealthCheckerTest() + : cluster_(new NiceMock()), + event_logger_(new Upstream::MockHealthCheckEventLogger()) {} void setup() { const std::string yaml = R"EOF( @@ -64,8 +46,9 @@ class RedisHealthCheckerTest const auto& hc_config = Upstream::parseHealthCheckFromV2Yaml(yaml); const auto& redis_config = getRedisHealthCheckConfig(hc_config); - health_checker_.reset(new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, - runtime_, random_, *this)); + health_checker_.reset( + new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, runtime_, random_, + Upstream::HealthCheckEventLoggerPtr(event_logger_), *this)); } void setupExistsHealthcheck() { @@ -85,8 +68,9 @@ class RedisHealthCheckerTest const auto& hc_config = Upstream::parseHealthCheckFromV2Yaml(yaml); const auto& redis_config = getRedisHealthCheckConfig(hc_config); - health_checker_.reset(new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, - runtime_, random_, *this)); + health_checker_.reset( + new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, runtime_, random_, + Upstream::HealthCheckEventLoggerPtr(event_logger_), *this)); } void setupDontReuseConnection() { @@ -106,8 +90,9 @@ class RedisHealthCheckerTest const auto& hc_config = Upstream::parseHealthCheckFromV2Yaml(yaml); const auto& redis_config = getRedisHealthCheckConfig(hc_config); - health_checker_.reset(new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, - runtime_, random_, *this)); + health_checker_.reset( + new RedisHealthChecker(*cluster_, hc_config, redis_config, dispatcher_, runtime_, random_, + Upstream::HealthCheckEventLoggerPtr(event_logger_), *this)); } Extensions::NetworkFilters::RedisProxy::ConnPool::ClientPtr @@ -145,6 +130,7 @@ class RedisHealthCheckerTest NiceMock dispatcher_; NiceMock runtime_; NiceMock random_; + Upstream::MockHealthCheckEventLogger* event_logger_{}; Event::MockTimer* timeout_timer_{}; Event::MockTimer* interval_timer_{}; Extensions::NetworkFilters::RedisProxy::ConnPool::MockClient* client_{}; @@ -181,6 +167,7 @@ TEST_F(RedisHealthCheckerTest, PingAndVariousFailures) { interval_timer_->callback_(); // Failure + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*timeout_timer_, disableTimer()); EXPECT_CALL(*interval_timer_, enableTimer(_)); response.reset(new Extensions::NetworkFilters::RedisProxy::RespValue()); @@ -248,57 +235,7 @@ TEST_F(RedisHealthCheckerTest, Exists) { interval_timer_->callback_(); // Failure, exists - EXPECT_CALL(*timeout_timer_, disableTimer()); - EXPECT_CALL(*interval_timer_, enableTimer(_)); - response.reset(new Extensions::NetworkFilters::RedisProxy::RespValue()); - response->type(Extensions::NetworkFilters::RedisProxy::RespType::Integer); - response->asInteger() = 1; - pool_callbacks_->onResponse(std::move(response)); - - expectExistsRequestCreate(); - interval_timer_->callback_(); - - // Failure, no value - EXPECT_CALL(*timeout_timer_, disableTimer()); - EXPECT_CALL(*interval_timer_, enableTimer(_)); - response.reset(new Extensions::NetworkFilters::RedisProxy::RespValue()); - pool_callbacks_->onResponse(std::move(response)); - - EXPECT_CALL(*client_, close()); - - EXPECT_EQ(3UL, cluster_->info_->stats_store_.counter("health_check.attempt").value()); - EXPECT_EQ(1UL, cluster_->info_->stats_store_.counter("health_check.success").value()); - EXPECT_EQ(2UL, cluster_->info_->stats_store_.counter("health_check.failure").value()); -} - -TEST_F(RedisHealthCheckerTest, ExistsDeprecated) { - InSequence s; - setupExistsHealthcheckDeprecated(); - - cluster_->prioritySet().getMockHostSet(0)->hosts_ = { - Upstream::makeTestHost(cluster_->info_, "tcp://127.0.0.1:80")}; - - expectSessionCreate(); - expectClientCreate(); - expectExistsRequestCreate(); - health_checker_->start(); - - client_->runHighWatermarkCallbacks(); - client_->runLowWatermarkCallbacks(); - - // Success - EXPECT_CALL(*timeout_timer_, disableTimer()); - EXPECT_CALL(*interval_timer_, enableTimer(_)); - Extensions::NetworkFilters::RedisProxy::RespValuePtr response( - new Extensions::NetworkFilters::RedisProxy::RespValue()); - response->type(Extensions::NetworkFilters::RedisProxy::RespType::Integer); - response->asInteger() = 0; - pool_callbacks_->onResponse(std::move(response)); - - expectExistsRequestCreate(); - interval_timer_->callback_(); - - // Failure, exists + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*timeout_timer_, disableTimer()); EXPECT_CALL(*interval_timer_, enableTimer(_)); response.reset(new Extensions::NetworkFilters::RedisProxy::RespValue()); @@ -350,6 +287,7 @@ TEST_F(RedisHealthCheckerTest, NoConnectionReuse) { interval_timer_->callback_(); // The connection will close on failure. + EXPECT_CALL(*event_logger_, logEjectUnhealthy(_, _, _)); EXPECT_CALL(*timeout_timer_, disableTimer()); EXPECT_CALL(*interval_timer_, enableTimer(_)); EXPECT_CALL(*client_, close()); @@ -395,4 +333,4 @@ TEST_F(RedisHealthCheckerTest, NoConnectionReuse) { } // namespace RedisHealthChecker } // namespace HealthCheckers } // namespace Extensions -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/test/extensions/resource_monitors/fixed_heap/BUILD b/test/extensions/resource_monitors/fixed_heap/BUILD new file mode 100644 index 0000000000000..3d1c8eff0ab3b --- /dev/null +++ b/test/extensions/resource_monitors/fixed_heap/BUILD @@ -0,0 +1,35 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_package", +) +load( + "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_test", +) + +envoy_package() + +envoy_extension_cc_test( + name = "fixed_heap_monitor_test", + srcs = ["fixed_heap_monitor_test.cc"], + extension_name = "envoy.resource_monitors.fixed_heap", + external_deps = ["abseil_optional"], + deps = [ + "//source/extensions/resource_monitors/fixed_heap:fixed_heap_monitor", + ], +) + +envoy_extension_cc_test( + name = "config_test", + srcs = ["config_test.cc"], + extension_name = "envoy.resource_monitors.fixed_heap", + deps = [ + "//include/envoy/registry", + "//source/extensions/resource_monitors/fixed_heap:config", + "//source/server:resource_monitor_config_lib", + "//test/mocks/event:event_mocks", + "@envoy_api//envoy/config/resource_monitor/fixed_heap/v2alpha:fixed_heap_cc", + ], +) diff --git a/test/extensions/resource_monitors/fixed_heap/config_test.cc b/test/extensions/resource_monitors/fixed_heap/config_test.cc new file mode 100644 index 0000000000000..5bdd672804d2c --- /dev/null +++ b/test/extensions/resource_monitors/fixed_heap/config_test.cc @@ -0,0 +1,34 @@ +#include "envoy/config/resource_monitor/fixed_heap/v2alpha/fixed_heap.pb.validate.h" +#include "envoy/registry/registry.h" + +#include "server/resource_monitor_config_impl.h" + +#include "extensions/resource_monitors/fixed_heap/config.h" + +#include "test/mocks/event/mocks.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +TEST(FixedHeapMonitorFactoryTest, CreateMonitor) { + auto factory = + Registry::FactoryRegistry::getFactory( + "envoy.resource_monitors.fixed_heap"); + EXPECT_NE(factory, nullptr); + + envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig config; + config.set_max_heap_size_bytes(std::numeric_limits::max()); + Event::MockDispatcher dispatcher; + Server::Configuration::ResourceMonitorFactoryContextImpl context(dispatcher); + auto monitor = factory->createResourceMonitor(config, context); + EXPECT_NE(monitor, nullptr); +} + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/resource_monitors/fixed_heap/fixed_heap_monitor_test.cc b/test/extensions/resource_monitors/fixed_heap/fixed_heap_monitor_test.cc new file mode 100644 index 0000000000000..637d5fe2a9433 --- /dev/null +++ b/test/extensions/resource_monitors/fixed_heap/fixed_heap_monitor_test.cc @@ -0,0 +1,56 @@ +#include "extensions/resource_monitors/fixed_heap/fixed_heap_monitor.h" + +#include "absl/types/optional.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace ResourceMonitors { +namespace FixedHeapMonitor { + +class MockMemoryStatsReader : public MemoryStatsReader { +public: + MockMemoryStatsReader() {} + + MOCK_METHOD0(reservedHeapBytes, uint64_t()); + MOCK_METHOD0(unmappedHeapBytes, uint64_t()); +}; + +class ResourcePressure : public Server::ResourceMonitor::Callbacks { +public: + void onSuccess(const Server::ResourceUsage& usage) override { + pressure_ = usage.resource_pressure_; + } + + void onFailure(const EnvoyException& error) override { error_ = error; } + + bool hasPressure() const { return pressure_.has_value(); } + bool hasError() const { return error_.has_value(); } + + double pressure() const { return *pressure_; } + +private: + absl::optional pressure_; + absl::optional error_; +}; + +TEST(FixedHeapMonitorTest, ComputesCorrectUsage) { + envoy::config::resource_monitor::fixed_heap::v2alpha::FixedHeapConfig config; + config.set_max_heap_size_bytes(1000); + auto stats_reader = std::make_unique(); + EXPECT_CALL(*stats_reader, reservedHeapBytes()).WillOnce(testing::Return(800)); + EXPECT_CALL(*stats_reader, unmappedHeapBytes()).WillOnce(testing::Return(100)); + std::unique_ptr monitor(new FixedHeapMonitor(config, std::move(stats_reader))); + + ResourcePressure resource; + monitor->updateResourceUsage(resource); + EXPECT_TRUE(resource.hasPressure()); + EXPECT_FALSE(resource.hasError()); + EXPECT_EQ(resource.pressure(), 0.7); +} + +} // namespace FixedHeapMonitor +} // namespace ResourceMonitors +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/stats_sinks/common/statsd/statsd_test.cc b/test/extensions/stats_sinks/common/statsd/statsd_test.cc index 02786f7e406ed..82015050207fe 100644 --- a/test/extensions/stats_sinks/common/statsd/statsd_test.cc +++ b/test/extensions/stats_sinks/common/statsd/statsd_test.cc @@ -118,7 +118,7 @@ TEST_F(TcpStatsdSinkTest, BufferReallocate) { for (int i = 0; i < 2000; i++) { compare += "envoy.test_counter:1|c\n"; } - EXPECT_EQ(compare, TestUtility::bufferToString(buffer)); + EXPECT_EQ(compare, buffer.toString()); })); sink_->flush(source_); } diff --git a/test/extensions/stats_sinks/dog_statsd/config_test.cc b/test/extensions/stats_sinks/dog_statsd/config_test.cc index d3afecfa8f99b..c54a1687772f1 100644 --- a/test/extensions/stats_sinks/dog_statsd/config_test.cc +++ b/test/extensions/stats_sinks/dog_statsd/config_test.cc @@ -32,7 +32,7 @@ INSTANTIATE_TEST_CASE_P(IpVersions, DogStatsdConfigLoopbackTest, TestUtility::ipTestParamsToString); TEST_P(DogStatsdConfigLoopbackTest, ValidUdpIp) { - const std::string name = StatsSinkNames::get().DOG_STATSD; + const std::string name = StatsSinkNames::get().DogStatsd; envoy::config::metrics::v2::DogStatsdSink sink_config; envoy::api::v2::core::Address& address = *sink_config.mutable_address(); diff --git a/test/extensions/stats_sinks/hystrix/BUILD b/test/extensions/stats_sinks/hystrix/BUILD new file mode 100644 index 0000000000000..4d50225268359 --- /dev/null +++ b/test/extensions/stats_sinks/hystrix/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_package", +) +load( + "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_test", +) + +envoy_package() + +envoy_extension_cc_test( + name = "config_test", + srcs = ["config_test.cc"], + extension_name = "envoy.stat_sinks.hystrix", + deps = [ + "//include/envoy/registry", + "//source/common/protobuf:utility_lib", + "//source/extensions/stat_sinks/hystrix:config", + "//test/mocks/server:server_mocks", + "//test/test_common:environment_lib", + "//test/test_common:network_utility_lib", + "//test/test_common:utility_lib", + ], +) + +envoy_extension_cc_test( + name = "hystrix_test", + srcs = ["hystrix_test.cc"], + extension_name = "envoy.stat_sinks.hystrix", + deps = [ + "//source/common/stats:stats_lib", + "//source/extensions/stat_sinks/hystrix:hystrix_lib", + "//test/mocks/server:server_mocks", + "//test/mocks/stats:stats_mocks", + "//test/mocks/upstream:upstream_mocks", + ], +) diff --git a/test/extensions/stats_sinks/hystrix/config_test.cc b/test/extensions/stats_sinks/hystrix/config_test.cc new file mode 100644 index 0000000000000..ec224c4ddb914 --- /dev/null +++ b/test/extensions/stats_sinks/hystrix/config_test.cc @@ -0,0 +1,49 @@ +#include "envoy/config/bootstrap/v2/bootstrap.pb.h" +#include "envoy/registry/registry.h" + +#include "common/protobuf/utility.h" + +#include "extensions/stat_sinks/hystrix/config.h" +#include "extensions/stat_sinks/hystrix/hystrix.h" +#include "extensions/stat_sinks/well_known_names.h" + +#include "test/mocks/server/mocks.h" +#include "test/test_common/environment.h" +#include "test/test_common/network_utility.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +TEST(StatsConfigTest, ValidHystrixSink) { + const std::string name = StatsSinkNames::get().Hystrix; + + envoy::config::metrics::v2::HystrixSink sink_config; + + Server::Configuration::StatsSinkFactory* factory = + Registry::FactoryRegistry::getFactory(name); + ASSERT_NE(factory, nullptr); + + ProtobufTypes::MessagePtr message = factory->createEmptyConfigProto(); + MessageUtil::jsonConvert(sink_config, *message); + + NiceMock server; + Stats::SinkPtr sink = factory->createStatsSink(*message, server); + EXPECT_NE(sink, nullptr); + EXPECT_NE(dynamic_cast(sink.get()), nullptr); +} + +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/stats_sinks/hystrix/hystrix_test.cc b/test/extensions/stats_sinks/hystrix/hystrix_test.cc new file mode 100644 index 0000000000000..13981082f83bd --- /dev/null +++ b/test/extensions/stats_sinks/hystrix/hystrix_test.cc @@ -0,0 +1,436 @@ +#include +#include +#include + +#include "common/stats/stats_impl.h" + +#include "extensions/stat_sinks/hystrix/hystrix.h" + +#include "test/mocks/server/mocks.h" +#include "test/mocks/stats/mocks.h" +#include "test/mocks/upstream/mocks.h" + +#include "absl/strings/str_split.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace StatSinks { +namespace Hystrix { + +class ClusterTestInfo { + +public: + ClusterTestInfo(const std::string cluster_name) : cluster_name_(cluster_name) { + ON_CALL(cluster_, info()).WillByDefault(Return(cluster_info_ptr_)); + ON_CALL(*cluster_info_, name()).WillByDefault(testing::ReturnRefOfCopy(cluster_name_)); + ON_CALL(*cluster_info_, statsScope()).WillByDefault(ReturnRef(cluster_stats_scope_)); + + // Set gauge value. + membership_total_gauge_.name_ = "membership_total"; + ON_CALL(cluster_stats_scope_, gauge("membership_total")) + .WillByDefault(ReturnRef(membership_total_gauge_)); + ON_CALL(membership_total_gauge_, value()).WillByDefault(Return(5)); + + // Attach counters. + setCounterForTest(success_counter_, "upstream_rq_2xx"); + setCounterForTest(error_5xx_counter_, "upstream_rq_5xx"); + setCounterForTest(retry_5xx_counter_, "retry.upstream_rq_5xx"); + setCounterForTest(error_4xx_counter_, "upstream_rq_4xx"); + setCounterForTest(retry_4xx_counter_, "retry.upstream_rq_4xx"); + setCountersToZero(); + } + + // Attach the counter to cluster_stat_scope and set default value. + void setCounterForTest(NiceMock& counter, std::string counter_name) { + counter.name_ = counter_name; + ON_CALL(cluster_stats_scope_, counter(counter_name)).WillByDefault(ReturnRef(counter)); + } + + void setCountersToZero() { + ON_CALL(error_5xx_counter_, value()).WillByDefault(Return(0)); + ON_CALL(retry_5xx_counter_, value()).WillByDefault(Return(0)); + ON_CALL(error_4xx_counter_, value()).WillByDefault(Return(0)); + ON_CALL(retry_4xx_counter_, value()).WillByDefault(Return(0)); + ON_CALL(success_counter_, value()).WillByDefault(Return(0)); + } + + // Set counter return values to simulate traffic + void setCounterReturnValues(const uint64_t i, const uint64_t success_step, + const uint64_t error_4xx_step, const uint64_t error_4xx_retry_step, + const uint64_t error_5xx_step, const uint64_t error_5xx_retry_step, + const uint64_t timeout_step, const uint64_t timeout_retry_step, + const uint64_t rejected_step) { + ON_CALL(error_5xx_counter_, value()).WillByDefault(Return((i + 1) * error_5xx_step)); + ON_CALL(retry_5xx_counter_, value()).WillByDefault(Return((i + 1) * error_5xx_retry_step)); + ON_CALL(error_4xx_counter_, value()).WillByDefault(Return((i + 1) * error_4xx_step)); + ON_CALL(retry_4xx_counter_, value()).WillByDefault(Return((i + 1) * error_4xx_retry_step)); + ON_CALL(success_counter_, value()).WillByDefault(Return((i + 1) * success_step)); + cluster_info_->stats().upstream_rq_timeout_.add(timeout_step); + cluster_info_->stats().upstream_rq_per_try_timeout_.add(timeout_retry_step); + cluster_info_->stats().upstream_rq_pending_overflow_.add(rejected_step); + } + + NiceMock cluster_; + Upstream::MockClusterInfo* cluster_info_ = new NiceMock(); + Upstream::ClusterInfoConstSharedPtr cluster_info_ptr_{cluster_info_}; + + NiceMock stats_store_; + NiceMock cluster_stats_scope_; + const std::string cluster_name_; + + NiceMock membership_total_gauge_; + NiceMock success_counter_; + NiceMock error_5xx_counter_; + NiceMock retry_5xx_counter_; + NiceMock error_4xx_counter_; + NiceMock retry_4xx_counter_; +}; + +class HystrixSinkTest : public testing::Test { +public: + HystrixSinkTest() { sink_.reset(new HystrixSink(server_, window_size_)); } + + Buffer::OwnedImpl createClusterAndCallbacks() { + // Set cluster. + cluster_map_.emplace(cluster1_name_, cluster1_.cluster_); + ON_CALL(server_, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); + ON_CALL(cluster_manager_, clusters()).WillByDefault(Return(cluster_map_)); + + Buffer::OwnedImpl buffer; + auto encode_callback = [&buffer](Buffer::Instance& data, bool) { + // Set callbacks to send data to buffer. This will append to the end of the buffer, so + // multiple calls will all be dumped one after another into this buffer. + buffer.add(data); + }; + ON_CALL(callbacks_, encodeData(_, _)).WillByDefault(Invoke(encode_callback)); + return buffer; + } + + void addClusterToMap(const std::string& cluster_name, NiceMock& cluster) { + cluster_map_.emplace(cluster_name, cluster); + // Redefining since cluster_map_ is returned by value. + ON_CALL(cluster_manager_, clusters()).WillByDefault(Return(cluster_map_)); + } + + void removeClusterFromMap(const std::string& cluster_name) { + cluster_map_.erase(cluster_name); + // Redefining since cluster_map_ is returned by value. + ON_CALL(cluster_manager_, clusters()).WillByDefault(Return(cluster_map_)); + } + + void addSecondClusterHelper(Buffer::OwnedImpl& buffer) { + buffer.drain(buffer.length()); + cluster2_.setCountersToZero(); + addClusterToMap(cluster2_name_, cluster2_.cluster_); + } + + std::unordered_map + addSecondClusterAndSendDataHelper(Buffer::OwnedImpl& buffer, const uint64_t success_step, + const uint64_t error_step, const uint64_t timeout_step, + const uint64_t success_step2, const uint64_t error_step2, + const uint64_t timeout_step2) { + + // Add new cluster. + addSecondClusterHelper(buffer); + + // Generate data to both clusters. + for (uint64_t i = 0; i < (window_size_ + 1); i++) { + buffer.drain(buffer.length()); + cluster1_.setCounterReturnValues(i, success_step, error_step, 0, 0, 0, timeout_step, 0, 0); + cluster2_.setCounterReturnValues(i, success_step2, error_step2, 0, 0, 0, timeout_step2, 0, 0); + sink_->flush(source_); + } + + return buildClusterMap(buffer.toString()); + } + + void removeSecondClusterHelper(Buffer::OwnedImpl& buffer) { + buffer.drain(buffer.length()); + removeClusterFromMap(cluster2_name_); + sink_->flush(source_); + } + + void validateResults(const std::string& data_message, uint64_t success_step, uint64_t error_step, + uint64_t timeout_step, uint64_t timeout_retry_step, uint64_t rejected_step, + uint64_t window_size) { + // Convert to json object. + Json::ObjectSharedPtr json_data_message = Json::Factory::loadFromString(data_message); + EXPECT_EQ(json_data_message->getInteger("rollingCountSemaphoreRejected"), + (window_size * rejected_step)) + << "window_size=" << window_size << ", rejected_step=" << rejected_step; + EXPECT_EQ(json_data_message->getInteger("rollingCountSuccess"), (window_size * success_step)) + << "window_size=" << window_size << ", success_step=" << success_step; + EXPECT_EQ(json_data_message->getInteger("rollingCountTimeout"), + (window_size * (timeout_step + timeout_retry_step))) + << "window_size=" << window_size << ", timeout_step=" << timeout_step + << ", timeout_retry_step=" << timeout_retry_step; + EXPECT_EQ(json_data_message->getInteger("errorCount"), + (window_size * (error_step - timeout_step))) + << "window_size=" << window_size << ", error_step=" << error_step + << ", timeout_step=" << timeout_step; + uint64_t total = error_step + success_step + rejected_step + timeout_retry_step; + EXPECT_EQ(json_data_message->getInteger("requestCount"), (window_size * total)) + << "window_size=" << window_size << ", total=" << total; + + if (total != 0) { + EXPECT_EQ(json_data_message->getInteger("errorPercentage"), + (static_cast(100 * (static_cast(total - success_step) / + static_cast(total))))) + << "total=" << total << ", success_step=" << success_step; + + } else { + EXPECT_EQ(json_data_message->getInteger("errorPercentage"), 0); + } + } + + std::unordered_map buildClusterMap(absl::string_view data_message) { + std::unordered_map cluster_message_map; + std::vector messages = + absl::StrSplit(data_message, "data: ", absl::SkipWhitespace()); + for (auto message : messages) { + // Arrange message to remove ":" that comes from the keepalive sync. + absl::RemoveExtraAsciiWhitespace(&message); + std::string clear_message(absl::StripSuffix(message, ":")); + Json::ObjectSharedPtr json_message = Json::Factory::loadFromString(clear_message); + if (absl::StrContains(json_message->getString("type"), "HystrixCommand")) { + std::string cluster_name(json_message->getString("name")); + cluster_message_map[cluster_name] = message; + } + } + return cluster_message_map; + } + + TestRandomGenerator rand_; + uint64_t window_size_ = rand_.random() % 10 + 5; // Arbitrary reasonable number. + const std::string cluster1_name_{"test_cluster1"}; + ClusterTestInfo cluster1_{cluster1_name_}; + + // Second cluster for "end and remove cluster" tests. + const std::string cluster2_name_{"test_cluster2"}; + ClusterTestInfo cluster2_{cluster2_name_}; + + NiceMock callbacks_; + NiceMock server_; + Upstream::ClusterManager::ClusterInfoMap cluster_map_; + + std::unique_ptr sink_; + NiceMock source_; + NiceMock cluster_manager_; +}; + +TEST_F(HystrixSinkTest, EmptyFlush) { + InSequence s; + Buffer::OwnedImpl buffer = createClusterAndCallbacks(); + // Register callback to sink. + sink_->registerConnection(&callbacks_); + sink_->flush(source_); + std::unordered_map cluster_message_map = + buildClusterMap(buffer.toString()); + validateResults(cluster_message_map[cluster1_name_], 0, 0, 0, 0, 0, window_size_); +} + +TEST_F(HystrixSinkTest, BasicFlow) { + InSequence s; + Buffer::OwnedImpl buffer = createClusterAndCallbacks(); + // Register callback to sink. + sink_->registerConnection(&callbacks_); + + // Only success traffic, check randomly increasing traffic + // Later in the test we'll "shortcut" by constant traffic + uint64_t traffic_counter = 0; + + sink_->flush(source_); // init window with 0 + for (uint64_t i = 0; i < (window_size_ - 1); i++) { + buffer.drain(buffer.length()); + traffic_counter += rand_.random() % 1000; + ON_CALL(cluster1_.success_counter_, value()).WillByDefault(Return(traffic_counter)); + sink_->flush(source_); + } + + std::unordered_map cluster_message_map = + buildClusterMap(buffer.toString()); + + Json::ObjectSharedPtr json_buffer = + Json::Factory::loadFromString(cluster_message_map[cluster1_name_]); + EXPECT_EQ(json_buffer->getInteger("rollingCountSuccess"), traffic_counter); + EXPECT_EQ(json_buffer->getInteger("requestCount"), traffic_counter); + EXPECT_EQ(json_buffer->getInteger("errorCount"), 0); + EXPECT_EQ(json_buffer->getInteger("errorPercentage"), 0); + + // Check mixed traffic. + // Values are unimportant - they represent traffic statistics, and for the purpose of the test any + // arbitrary number will do. Only restriction is that errors >= timeouts, since in Envoy timeouts + // are counted as errors and therefore the code that prepares the stream for the dashboard deducts + // the number of timeouts from total number of errors. + const uint64_t success_step = 13; + const uint64_t error_4xx_step = 12; + const uint64_t error_4xx_retry_step = 11; + const uint64_t error_5xx_step = 10; + const uint64_t error_5xx_retry_step = 9; + const uint64_t timeout_step = 8; + const uint64_t timeout_retry_step = 7; + const uint64_t rejected_step = 6; + + for (uint64_t i = 0; i < (window_size_ + 1); i++) { + buffer.drain(buffer.length()); + cluster1_.setCounterReturnValues(i, success_step, error_4xx_step, error_4xx_retry_step, + error_5xx_step, error_5xx_retry_step, timeout_step, + timeout_retry_step, rejected_step); + sink_->flush(source_); + } + + std::string rolling_map = sink_->printRollingWindows(); + EXPECT_NE(std::string::npos, rolling_map.find(cluster1_name_ + ".total")) + << "cluster1_name = " << cluster1_name_; + + cluster_message_map = buildClusterMap(buffer.toString()); + + // Check stream format and data. + validateResults(cluster_message_map[cluster1_name_], success_step, + error_4xx_step + error_4xx_retry_step + error_5xx_step + error_5xx_retry_step, + timeout_step, timeout_retry_step, rejected_step, window_size_); + + // Check the values are reset. + buffer.drain(buffer.length()); + sink_->resetRollingWindow(); + sink_->flush(source_); + cluster_message_map = buildClusterMap(buffer.toString()); + validateResults(cluster_message_map[cluster1_name_], 0, 0, 0, 0, 0, window_size_); +} + +// +TEST_F(HystrixSinkTest, Disconnect) { + InSequence s; + Buffer::OwnedImpl buffer = createClusterAndCallbacks(); + + sink_->flush(source_); + EXPECT_EQ(buffer.length(), 0); + + // Register callback to sink. + sink_->registerConnection(&callbacks_); + + // Arbitrary numbers for testing. Make sure error > timeout. + uint64_t success_step = 1; + + for (uint64_t i = 0; i < (window_size_ + 1); i++) { + buffer.drain(buffer.length()); + ON_CALL(cluster1_.success_counter_, value()).WillByDefault(Return((i + 1) * success_step)); + sink_->flush(source_); + } + + EXPECT_NE(buffer.length(), 0); + std::unordered_map cluster_message_map = + buildClusterMap(buffer.toString()); + Json::ObjectSharedPtr json_buffer = + Json::Factory::loadFromString(cluster_message_map[cluster1_name_]); + EXPECT_EQ(json_buffer->getInteger("rollingCountSuccess"), (success_step * window_size_)); + + // Disconnect. + buffer.drain(buffer.length()); + sink_->unregisterConnection(&callbacks_); + sink_->flush(source_); + EXPECT_EQ(buffer.length(), 0); + + // Reconnect. + buffer.drain(buffer.length()); + sink_->registerConnection(&callbacks_); + ON_CALL(cluster1_.success_counter_, value()).WillByDefault(Return(success_step)); + sink_->flush(source_); + EXPECT_NE(buffer.length(), 0); + cluster_message_map = buildClusterMap(buffer.toString()); + json_buffer = Json::Factory::loadFromString(cluster_message_map[cluster1_name_]); + EXPECT_EQ(json_buffer->getInteger("rollingCountSuccess"), 0); +} + +TEST_F(HystrixSinkTest, AddCluster) { + InSequence s; + // Register callback to sink. + sink_->registerConnection(&callbacks_); + + // Arbitrary values for testing. Make sure error > timeout. + const uint64_t success_step = 6; + const uint64_t error_step = 3; + const uint64_t timeout_step = 1; + + const uint64_t success_step2 = 44; + const uint64_t error_step2 = 33; + const uint64_t timeout_step2 = 22; + + Buffer::OwnedImpl buffer = createClusterAndCallbacks(); + + // Add cluster and "run" some traffic. + std::unordered_map cluster_message_map = + addSecondClusterAndSendDataHelper(buffer, success_step, error_step, timeout_step, + success_step2, error_step2, timeout_step2); + + // Expect that add worked. + ASSERT_NE(cluster_message_map.find(cluster1_name_), cluster_message_map.end()) + << "cluster1_name = " << cluster1_name_; + ASSERT_NE(cluster_message_map.find(cluster2_name_), cluster_message_map.end()) + << "cluster2_name = " << cluster2_name_; + + // Check stream format and data. + validateResults(cluster_message_map[cluster1_name_], success_step, error_step, timeout_step, 0, 0, + window_size_); + validateResults(cluster_message_map[cluster2_name_], success_step2, error_step2, timeout_step2, 0, + 0, window_size_); +} + +TEST_F(HystrixSinkTest, AddAndRemoveClusters) { + InSequence s; + // Register callback to sink. + sink_->registerConnection(&callbacks_); + + // Arbitrary values for testing. Make sure error > timeout. + const uint64_t success_step = 436; + const uint64_t error_step = 547; + const uint64_t timeout_step = 156; + + const uint64_t success_step2 = 309; + const uint64_t error_step2 = 934; + const uint64_t timeout_step2 = 212; + + Buffer::OwnedImpl buffer = createClusterAndCallbacks(); + + // Add cluster and "run" some traffic. + addSecondClusterAndSendDataHelper(buffer, success_step, error_step, timeout_step, success_step2, + error_step2, timeout_step2); + + // Remove cluster and flush data to sink. + removeSecondClusterHelper(buffer); + + // Check that removed worked. + std::unordered_map cluster_message_map = + buildClusterMap(buffer.toString()); + ASSERT_NE(cluster_message_map.find(cluster1_name_), cluster_message_map.end()) + << "cluster1_name = " << cluster1_name_; + ASSERT_EQ(cluster_message_map.find(cluster2_name_), cluster_message_map.end()) + << "cluster2_name = " << cluster2_name_; + + // Add cluster again and flush data to sink. + addSecondClusterHelper(buffer); + + sink_->flush(source_); + + // Check that add worked. + cluster_message_map = buildClusterMap(buffer.toString()); + ASSERT_NE(cluster_message_map.find(cluster1_name_), cluster_message_map.end()) + << "cluster1_name = " << cluster1_name_; + ASSERT_NE(cluster_message_map.find(cluster2_name_), cluster_message_map.end()) + << "cluster2_name = " << cluster2_name_; + + // Check that old values of test_cluster2 were deleted. + validateResults(cluster_message_map[cluster2_name_], 0, 0, 0, 0, 0, window_size_); +} +} // namespace Hystrix +} // namespace StatSinks +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/stats_sinks/metrics_service/BUILD b/test/extensions/stats_sinks/metrics_service/BUILD index 3dba545540d45..e9a79fae08128 100644 --- a/test/extensions/stats_sinks/metrics_service/BUILD +++ b/test/extensions/stats_sinks/metrics_service/BUILD @@ -39,6 +39,7 @@ envoy_extension_cc_test( "//source/extensions/stat_sinks/metrics_service:config", "//test/common/grpc:grpc_client_integration_lib", "//test/integration:http_integration_lib", + "//test/test_common:utility_lib", "@envoy_api//envoy/service/metrics/v2:metrics_service_cc", ], ) diff --git a/test/extensions/stats_sinks/metrics_service/metrics_service_integration_test.cc b/test/extensions/stats_sinks/metrics_service/metrics_service_integration_test.cc index 6a01afc1fb6a1..fb27999178082 100644 --- a/test/extensions/stats_sinks/metrics_service/metrics_service_integration_test.cc +++ b/test/extensions/stats_sinks/metrics_service/metrics_service_integration_test.cc @@ -7,9 +7,12 @@ #include "test/common/grpc/grpc_client_integration.h" #include "test/integration/http_integration.h" +#include "test/test_common/utility.h" #include "gtest/gtest.h" +using testing::AssertionResult; + namespace Envoy { namespace { @@ -47,15 +50,20 @@ class MetricsServiceIntegrationTest : public HttpIntegrationTest, HttpIntegrationTest::initialize(); } - void waitForMetricsServiceConnection() { - fake_metrics_service_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); + ABSL_MUST_USE_RESULT + AssertionResult waitForMetricsServiceConnection() { + return fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, + fake_metrics_service_connection_); } - void waitForMetricsStream() { - metrics_service_request_ = fake_metrics_service_connection_->waitForNewStream(*dispatcher_); + ABSL_MUST_USE_RESULT + AssertionResult waitForMetricsStream() { + return fake_metrics_service_connection_->waitForNewStream(*dispatcher_, + metrics_service_request_); } - void waitForMetricsRequest() { + ABSL_MUST_USE_RESULT + AssertionResult waitForMetricsRequest() { bool known_histogram_exists = false; bool known_counter_exists = false; bool known_gauge_exists = false; @@ -65,7 +73,7 @@ class MetricsServiceIntegrationTest : public HttpIntegrationTest, // flushed. while (!(known_counter_exists && known_gauge_exists && known_histogram_exists)) { envoy::service::metrics::v2::StreamMetricsMessage request_msg; - metrics_service_request_->waitForGrpcMessage(*dispatcher_, request_msg); + VERIFY_ASSERTION(metrics_service_request_->waitForGrpcMessage(*dispatcher_, request_msg)); EXPECT_STREQ("POST", metrics_service_request_->headers().Method()->value().c_str()); EXPECT_STREQ("/envoy.service.metrics.v2.MetricsService/StreamMetrics", metrics_service_request_->headers().Path()->value().c_str()); @@ -102,12 +110,16 @@ class MetricsServiceIntegrationTest : public HttpIntegrationTest, EXPECT_TRUE(known_counter_exists); EXPECT_TRUE(known_gauge_exists); EXPECT_TRUE(known_histogram_exists); + + return AssertionSuccess(); } void cleanup() { if (fake_metrics_service_connection_ != nullptr) { - fake_metrics_service_connection_->close(); - fake_metrics_service_connection_->waitForDisconnect(); + AssertionResult result = fake_metrics_service_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_metrics_service_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } @@ -130,9 +142,9 @@ TEST_P(MetricsServiceIntegrationTest, BasicFlow) { {"x-lyft-user-id", "123"}}; sendRequestAndWaitForResponse(request_headers, 0, default_response_headers_, 0); - waitForMetricsServiceConnection(); - waitForMetricsStream(); - waitForMetricsRequest(); + ASSERT_TRUE(waitForMetricsServiceConnection()); + ASSERT_TRUE(waitForMetricsStream()); + ASSERT_TRUE(waitForMetricsRequest()); // Send an empty response and end the stream. This should never happen but make sure nothing // breaks and we make a new stream on a follow up request. @@ -149,7 +161,7 @@ TEST_P(MetricsServiceIntegrationTest, BasicFlow) { test_server_->waitForCounterGe("grpc.metrics_service.streams_closed_0", 1); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } cleanup(); } diff --git a/test/extensions/stats_sinks/statsd/config_test.cc b/test/extensions/stats_sinks/statsd/config_test.cc index 68b31829242dc..fc4eeee983c46 100644 --- a/test/extensions/stats_sinks/statsd/config_test.cc +++ b/test/extensions/stats_sinks/statsd/config_test.cc @@ -28,7 +28,7 @@ namespace StatSinks { namespace Statsd { TEST(StatsConfigTest, ValidTcpStatsd) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; envoy::config::metrics::v2::StatsdSink sink_config; sink_config.set_tcp_cluster_name("fake_cluster"); @@ -53,7 +53,7 @@ INSTANTIATE_TEST_CASE_P(IpVersions, StatsConfigParameterizedTest, TestUtility::ipTestParamsToString); TEST_P(StatsConfigParameterizedTest, UdpSinkDefaultPrefix) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; auto defaultPrefix = Common::Statsd::getDefaultPrefix(); envoy::config::metrics::v2::StatsdSink sink_config; @@ -84,7 +84,7 @@ TEST_P(StatsConfigParameterizedTest, UdpSinkDefaultPrefix) { } TEST_P(StatsConfigParameterizedTest, UdpSinkCustomPrefix) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; const std::string customPrefix = "prefix.test"; envoy::config::metrics::v2::StatsdSink sink_config; @@ -116,7 +116,7 @@ TEST_P(StatsConfigParameterizedTest, UdpSinkCustomPrefix) { } TEST(StatsConfigTest, TcpSinkDefaultPrefix) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; envoy::config::metrics::v2::StatsdSink sink_config; auto defaultPrefix = Common::Statsd::getDefaultPrefix(); @@ -139,7 +139,7 @@ TEST(StatsConfigTest, TcpSinkDefaultPrefix) { } TEST(StatsConfigTest, TcpSinkCustomPrefix) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; envoy::config::metrics::v2::StatsdSink sink_config; ProtobufTypes::String prefix = "prefixTest"; @@ -169,7 +169,7 @@ INSTANTIATE_TEST_CASE_P(IpVersions, StatsConfigLoopbackTest, TestUtility::ipTestParamsToString); TEST_P(StatsConfigLoopbackTest, ValidUdpIpStatsd) { - const std::string name = StatsSinkNames::get().STATSD; + const std::string name = StatsSinkNames::get().Statsd; envoy::config::metrics::v2::StatsdSink sink_config; envoy::api::v2::core::Address& address = *sink_config.mutable_address(); diff --git a/test/extensions/transport_sockets/alts/BUILD b/test/extensions/transport_sockets/alts/BUILD index 87a6a480ff3ff..171a97fd28e43 100644 --- a/test/extensions/transport_sockets/alts/BUILD +++ b/test/extensions/transport_sockets/alts/BUILD @@ -11,6 +11,16 @@ load( envoy_package() +envoy_extension_cc_test( + name = "tsi_frame_protector_test", + srcs = ["tsi_frame_protector_test.cc"], + extension_name = "envoy.transport_sockets.alts", + deps = [ + "//source/extensions/transport_sockets/alts:tsi_frame_protector", + "//test/mocks/buffer:buffer_mocks", + ], +) + envoy_extension_cc_test( name = "tsi_handshaker_test", srcs = ["tsi_handshaker_test.cc"], diff --git a/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc b/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc new file mode 100644 index 0000000000000..2604e837c8cd7 --- /dev/null +++ b/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc @@ -0,0 +1,150 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "src/core/tsi/fake_transport_security.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::SaveArg; +using testing::Test; +using testing::_; +using namespace std::string_literals; + +/** + * Test with fake frame protector. The protected frame header is 4 byte length (little endian, + * include header itself) and following the body. + */ +class TsiFrameProtectorTest : public Test { +public: + TsiFrameProtectorTest() + : raw_frame_protector_(tsi_create_fake_frame_protector(nullptr)), + frame_protector_(CFrameProtectorPtr{raw_frame_protector_}) {} + +protected: + tsi_frame_protector* raw_frame_protector_; + TsiFrameProtector frame_protector_; +}; + +TEST_F(TsiFrameProtectorTest, Protect) { + { + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString()); + } + + { + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString()); + + input.add("bar"); + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo\x07\0\0\0bar"s, encrypted.toString()); + } + + { + Buffer::OwnedImpl input, encrypted; + input.add(std::string(20000, 'a')); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + + // fake frame protector will split long buffer to 2 "encrypted" frames with length 16K. + std::string expected = + "\0\x40\0\0"s + std::string(16380, 'a') + "\x28\x0e\0\0"s + std::string(3620, 'a'); + EXPECT_EQ(expected, encrypted.toString()); + } +} + +TEST_F(TsiFrameProtectorTest, ProtectError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.protect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*, + size_t*) { return TSI_INTERNAL_ERROR; }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted)); + + raw_frame_protector_->vtable = vtable; +} + +TEST_F(TsiFrameProtectorTest, ProtectFlushError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.protect_flush = [](tsi_frame_protector*, unsigned char*, size_t*, size_t*) { + return TSI_INTERNAL_ERROR; + }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted)); + + raw_frame_protector_->vtable = vtable; +} + +TEST_F(TsiFrameProtectorTest, Unprotect) { + { + Buffer::OwnedImpl input, decrypted; + input.add("\x07\0\0\0bar"s); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("bar", decrypted.toString()); + } + + { + Buffer::OwnedImpl input, decrypted; + input.add("\x0a\0\0\0foo"s); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("", decrypted.toString()); + + input.add("bar"); + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("foobar", decrypted.toString()); + } + + { + Buffer::OwnedImpl input, decrypted; + input.add("\0\x40\0\0"s + std::string(16380, 'a')); + input.add("\x28\x0e\0\0"s + std::string(3620, 'a')); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ(std::string(20000, 'a'), decrypted.toString()); + } +} +TEST_F(TsiFrameProtectorTest, UnprotectError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.unprotect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*, + size_t*) { return TSI_INTERNAL_ERROR; }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, decrypted; + input.add("\x0a\0\0\0foo"s); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.unprotect(input, decrypted)); + + raw_frame_protector_->vtable = vtable; +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/transport_sockets/alts/tsi_handshaker_test.cc b/test/extensions/transport_sockets/alts/tsi_handshaker_test.cc index 8ae87bba8ed93..883462f77c042 100644 --- a/test/extensions/transport_sockets/alts/tsi_handshaker_test.cc +++ b/test/extensions/transport_sockets/alts/tsi_handshaker_test.cc @@ -62,27 +62,27 @@ TEST_F(TsiHandshakerTest, DoHandshake) { client_callbacks_.expectDone(TSI_OK, client_sent, client_result); client_handshaker_.next(server_sent); // Initially server_sent is empty. EXPECT_EQ(nullptr, client_result); - EXPECT_EQ("CLIENT_INIT", TestUtility::bufferToString(client_sent).substr(4)); + EXPECT_EQ("CLIENT_INIT", client_sent.toString().substr(4)); server_callbacks_.expectDone(TSI_OK, server_sent, server_result); server_handshaker_.next(client_sent); EXPECT_EQ(nullptr, client_result); - EXPECT_EQ("SERVER_INIT", TestUtility::bufferToString(server_sent).substr(4)); + EXPECT_EQ("SERVER_INIT", server_sent.toString().substr(4)); client_callbacks_.expectDone(TSI_OK, client_sent, client_result); client_handshaker_.next(server_sent); EXPECT_EQ(nullptr, client_result); - EXPECT_EQ("CLIENT_FINISHED", TestUtility::bufferToString(client_sent).substr(4)); + EXPECT_EQ("CLIENT_FINISHED", client_sent.toString().substr(4)); server_callbacks_.expectDone(TSI_OK, server_sent, server_result); server_handshaker_.next(client_sent); EXPECT_NE(nullptr, server_result); - EXPECT_EQ("SERVER_FINISHED", TestUtility::bufferToString(server_sent).substr(4)); + EXPECT_EQ("SERVER_FINISHED", server_sent.toString().substr(4)); client_callbacks_.expectDone(TSI_OK, client_sent, client_result); client_handshaker_.next(server_sent); EXPECT_NE(nullptr, client_result); - EXPECT_EQ("", TestUtility::bufferToString(client_sent)); + EXPECT_EQ("", client_sent.toString()); tsi_peer client_peer; EXPECT_EQ(TSI_OK, tsi_handshaker_result_extract_peer(client_result.get(), &client_peer)); @@ -116,13 +116,13 @@ TEST_F(TsiHandshakerTest, IncompleteData) { client_callbacks_.expectDone(TSI_OK, client_sent, client_result); client_handshaker_.next(server_sent); // Initially server_sent is empty. EXPECT_EQ(nullptr, client_result); - EXPECT_EQ("CLIENT_INIT", TestUtility::bufferToString(client_sent).substr(4)); + EXPECT_EQ("CLIENT_INIT", client_sent.toString().substr(4)); client_sent.drain(3); // make data incomplete server_callbacks_.expectDone(TSI_INCOMPLETE_DATA, server_sent, server_result); server_handshaker_.next(client_sent); EXPECT_EQ(nullptr, client_result); - EXPECT_EQ("", TestUtility::bufferToString(server_sent)); + EXPECT_EQ("", server_sent.toString()); } TEST_F(TsiHandshakerTest, DeferredDelete) { diff --git a/test/fuzz/main.cc b/test/fuzz/main.cc index 6b29c46e5e516..f0158be3ecc7f 100644 --- a/test/fuzz/main.cc +++ b/test/fuzz/main.cc @@ -44,11 +44,11 @@ INSTANTIATE_TEST_CASE_P(CorpusExamples, FuzzerCorpusTest, testing::ValuesIn(test int main(int argc, char** argv) { // Expected usage: [other gtest flags] - RELEASE_ASSERT(argc-- >= 2); + RELEASE_ASSERT(argc-- >= 2, ""); const std::string corpus_path = Envoy::TestEnvironment::getCheckedEnvVar("TEST_SRCDIR") + "/" + Envoy::TestEnvironment::getCheckedEnvVar("TEST_WORKSPACE") + "/" + *++argv; - RELEASE_ASSERT(Envoy::Filesystem::directoryExists(corpus_path)); + RELEASE_ASSERT(Envoy::Filesystem::directoryExists(corpus_path), ""); Envoy::test_corpus_ = Envoy::TestUtility::listFiles(corpus_path, true); testing::InitGoogleTest(&argc, argv); Envoy::Fuzz::Runner::setupEnvironment(argc, argv); diff --git a/test/integration/BUILD b/test/integration/BUILD index d750ff73d07a7..7fb6421dd4355 100644 --- a/test/integration/BUILD +++ b/test/integration/BUILD @@ -221,6 +221,12 @@ envoy_cc_test_library( ], ) +envoy_cc_test( + name = "idle_timeout_integration_test", + srcs = ["idle_timeout_integration_test.cc"], + deps = [":http_protocol_integration_lib"], +) + envoy_cc_test_library( name = "integration_lib", srcs = [ @@ -253,6 +259,7 @@ envoy_cc_test_library( "//include/envoy/server:configuration_interface", "//include/envoy/server:hot_restart_interface", "//include/envoy/server:options_interface", + "//include/envoy/stats:stats_interface", "//source/common/api:api_lib", "//source/common/buffer:buffer_lib", "//source/common/buffer:zero_copy_input_stream_lib", @@ -284,6 +291,7 @@ envoy_cc_test_library( "//test/common/upstream:utility_lib", "//test/config:utility_lib", "//test/mocks/buffer:buffer_mocks", + "//test/mocks/server:server_mocks", "//test/mocks/upstream:upstream_mocks", "//test/test_common:environment_lib", "//test/test_common:network_utility_lib", @@ -320,10 +328,7 @@ envoy_cc_test( ":http_integration_lib", "//source/common/http:header_map_lib", "//source/extensions/access_loggers/file:config", - "//source/extensions/filters/http/cors:config", - "//source/extensions/filters/http/dynamo:config", - "//source/extensions/filters/http/grpc_http1_bridge:config", - "//source/extensions/filters/http/health_check:config", + "//source/extensions/filters/http/buffer:config", "//test/test_common:utility_lib", ], ) @@ -373,10 +378,18 @@ envoy_cc_test( srcs = ["hds_integration_test.cc"], deps = [ ":http_integration_lib", + ":integration_lib", + "//include/envoy/upstream:upstream_interface", + "//source/common/config:metadata_lib", "//source/common/config:resources_lib", + "//source/common/json:config_schemas_lib", + "//source/common/json:json_loader_lib", + "//source/common/network:utility_lib", + "//source/common/upstream:health_checker_lib", + "//source/common/upstream:health_discovery_service_lib", + "//test/common/upstream:utility_lib", "//test/config:utility_lib", "//test/test_common:network_utility_lib", - "//test/test_common:utility_lib", "@envoy_api//envoy/api/v2:eds_cc", "@envoy_api//envoy/service/discovery/v2:hds_cc", ], @@ -418,6 +431,32 @@ envoy_cc_test_library( deps = ["//include/envoy/stats:stats_interface"], ) +envoy_cc_test( + name = "sds_static_integration_test", + srcs = [ + "sds_static_integration_test.cc", + ], + data = [ + "//test/config/integration/certs", + ], + deps = [ + ":http_integration_lib", + "//source/common/event:dispatcher_includes", + "//source/common/event:dispatcher_lib", + "//source/common/network:connection_lib", + "//source/common/network:utility_lib", + "//source/common/ssl:context_config_lib", + "//source/common/ssl:context_lib", + "//source/extensions/filters/listener/tls_inspector:config", + "//source/extensions/transport_sockets/ssl:config", + "//test/mocks/runtime:runtime_mocks", + "//test/mocks/secret:secret_mocks", + "//test/test_common:utility_lib", + "@envoy_api//envoy/config/transport_socket/capture/v2alpha:capture_cc", + "@envoy_api//envoy/data/tap/v2alpha:capture_cc", + ], +) + envoy_cc_test( name = "ssl_integration_test", srcs = [ @@ -466,6 +505,21 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "tcp_conn_pool_integration_test", + srcs = [ + "tcp_conn_pool_integration_test.cc", + ], + deps = [ + ":integration_lib", + "//include/envoy/server:filter_config_interface", + "//include/envoy/tcp:conn_pool_interface", + "//test/server:utility_lib", + "//test/test_common:registry_lib", + "//test/test_common:utility_lib", + ], +) + envoy_cc_test( name = "uds_integration_test", srcs = [ diff --git a/test/integration/README.md b/test/integration/README.md index a6a81daebdad4..f87b4aefef1b5 100644 --- a/test/integration/README.md +++ b/test/integration/README.md @@ -106,3 +106,82 @@ if the changes will be needed by one specific test file, or will be likely reused in other integration tests. If it's likely be reused, please add the appropriate functions to existing utilities or add new test utilities. If it's likely a one-off change, it can be scoped to the existing test file. + + +# Deflaking tests + +The instructions below assume the developer is running tests natively with bazel +rather than in docker. For developers using docker the best workaround today is +to replace `//test/...` on the relevant `ci/do_ci.sh`with the command lines +referenced below and remember to back those changes out before sending the fix +upstream! + +## Reproducing test flakes + +The first step of fixing test flakes is reproducing the test flake. In general +if you have written a test which flakes, you can start by running + +`` +bazel test [test_name] --runs_per_test=1000 +`` + +Which runs the full test many times. If this works, great! If not, it's worth +trying to stress your system more by running more tests in parallel, by setting +`--jobs` and `--local_resources.` + +Once you've managed to reproduce a failure it may be beneficial to limit your +test run to the specific failing test(s) with `--gtest_filter`. This may cause +the test to flake less often (i.e. if two tests are interfering with each other, +scoping to your specific test name may harm rather than help reproducibility.) +but if it works it lets you iterate faster. + +Another helpful tip for debugging is turn turn up Envoy trace logs with +`--test_arg="-l trace"`. Again if the test failure is due to a race, this may make +it harder to reproduce, and it may also hide any custom logging you add, but it's a +handy thing to know of to follow the general flow. + +The full command might look something like + +``` +bazel test //test/integration:http2_upstream_integration_test \ +--test_arg=--gtest_filter="IpVersions/Http2UpstreamIntegrationTest.RouterRequestAndResponseWithBodyNoBuffer/IPv6" \ +--jobs 60 --local_resources 100000000000,100000000000,10000000 --test_arg="-l trace" +``` + +## Debugging test flakes + +Once you've managed to reproduce your test flake, you get to figure out what's +going on. If your failure mode isn't documented below, ideally some combination +of cerr << logging and trace logs will help you sort out what is going on (and +please add to this document as you figure it out!) + +## Unexpected disconnects + +As commented in `HttpIntegrationTest::cleanupUpstreamAndDownstream()`, the +tear-down sequence between upstream, Envoy, and client is somewhat sensitive to +ordering. If a given unit test does not use the provided member variables, for +example opens multiple client or upstream connections, the test author should be +aware of test best practices for clean-up which boil down to "Clean up upstream +first". + +Upstream connections run in their own threads, so if the client disconnects with +open streams, there's a race where Envoy detects the disconnect, and kills the +corresponding upstream stream, which is indistinguishable from an unexpected +disconnect and triggers test failure. Because the client is run from the main +thread, if upstream is closed first, the client will not detect the inverse +close, so no test failure will occur. + +## Unparented upstream connections + +The most common failure mode here is if the test adds additional fake +upstreams for *DS connections (ADS, EDS etc) which are not properly shut down +(for a very sensitive test framework) + +The failure mode here is that during test teardown one closes the DS connection +and then shuts down Envoy. Unfortunately as Envoy is running in its own thread, +it will try to re-establish the *DS connection, sometimes creating a connection +which is then "unparented". The solution here is to explicitly allow Envoy +reconnects before closing the connection, using + +`my_ds_upstream_->set_allow_unexpected_disconnects(true);` + diff --git a/test/integration/ads_integration_test.cc b/test/integration/ads_integration_test.cc index d348cdab997d9..1f09256426c63 100644 --- a/test/integration/ads_integration_test.cc +++ b/test/integration/ads_integration_test.cc @@ -59,20 +59,50 @@ const std::string config = R"EOF( port_value: 0 )EOF"; -class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIntegrationParamTest { +class AdsIntegrationBaseTest : public HttpIntegrationTest { public: - AdsIntegrationTest() : HttpIntegrationTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} + AdsIntegrationBaseTest(Http::CodecClient::Type downstream_protocol, + Network::Address::IpVersion version, + const std::string& config = ConfigHelper::HTTP_PROXY_CONFIG) + : HttpIntegrationTest(downstream_protocol, version, config) {} + + void createAdsConnection(FakeUpstream& upstream) { + ads_upstream_ = &upstream; + AssertionResult result = ads_upstream_->waitForHttpConnection(*dispatcher_, ads_connection_); + RELEASE_ASSERT(result, result.message()); + } - void TearDown() override { - ads_connection_->close(); - ads_connection_->waitForDisconnect(); + void cleanUpAdsConnection() { + ASSERT(ads_upstream_ != nullptr); + + // Don't ASSERT fail if an ADS reconnect ends up unparented. + ads_upstream_->set_allow_unexpected_disconnects(true); + AssertionResult result = ads_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = ads_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); ads_connection_.reset(); + } + +protected: + FakeHttpConnectionPtr ads_connection_; + FakeUpstream* ads_upstream_{}; +}; + +class AdsIntegrationTest : public AdsIntegrationBaseTest, + public Grpc::GrpcClientIntegrationParamTest { +public: + AdsIntegrationTest() + : AdsIntegrationBaseTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} + + void TearDown() override { + cleanUpAdsConnection(); test_server_.reset(); fake_upstreams_.clear(); } void createUpstreams() override { - HttpIntegrationTest::createUpstreams(); + AdsIntegrationBaseTest::createUpstreams(); fake_upstreams_.emplace_back( new FakeUpstream(createUpstreamSslContext(), 0, FakeHttpConnection::Type::HTTP2, version_)); } @@ -99,7 +129,7 @@ class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIn const Protobuf::int32 expected_error_code = Grpc::Status::GrpcStatus::Ok, const std::string& expected_error_message = "") { envoy::api::v2::DiscoveryRequest discovery_request; - ads_stream_->waitForGrpcMessage(*dispatcher_, discovery_request); + VERIFY_ASSERTION(ads_stream_->waitForGrpcMessage(*dispatcher_, discovery_request)); // TODO(PiotrSikora): Remove this hack once fixed internally. if (!(expected_type_url == discovery_request.type_url())) { @@ -171,9 +201,10 @@ class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIn fake_upstreams_[0]->localAddress()->ip()->port())); } - envoy::api::v2::Listener buildListener(const std::string& name, const std::string& route_config) { - return TestUtility::parseYaml( - fmt::format(R"EOF( + envoy::api::v2::Listener buildListener(const std::string& name, const std::string& route_config, + const std::string& stat_prefix = "ads_test") { + return TestUtility::parseYaml(fmt::format( + R"EOF( name: {} address: socket_address: @@ -183,14 +214,14 @@ class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIn filters: - name: envoy.http_connection_manager config: - stat_prefix: ads_test + stat_prefix: {} codec_type: HTTP2 rds: route_config_name: {} config_source: {{ ads: {{}} }} http_filters: [{{ name: envoy.router }}] )EOF", - name, Network::Test::getLoopbackAddressString(ipVersion()), route_config)); + name, Network::Test::getLoopbackAddressString(ipVersion()), stat_prefix, route_config)); } envoy::api::v2::RouteConfiguration buildRouteConfig(const std::string& name, @@ -236,10 +267,11 @@ class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIn } }); setUpstreamProtocol(FakeHttpConnection::Type::HTTP2); - HttpIntegrationTest::initialize(); + AdsIntegrationBaseTest::initialize(); if (ads_stream_ == nullptr) { - ads_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - ads_stream_ = ads_connection_->waitForNewStream(*dispatcher_); + createAdsConnection(*(fake_upstreams_[1])); + AssertionResult result = ads_connection_->waitForNewStream(*dispatcher_, ads_stream_); + RELEASE_ASSERT(result, result.message()); ads_stream_->startGrpcStream(); } } @@ -265,7 +297,6 @@ class AdsIntegrationTest : public HttpIntegrationTest, public Grpc::GrpcClientIn Secret::MockSecretManager secret_manager_; Runtime::MockLoader runtime_; Ssl::ContextManagerImpl context_manager_{runtime_}; - FakeHttpConnectionPtr ads_connection_; FakeStreamPtr ads_stream_; }; @@ -445,22 +476,93 @@ TEST_P(AdsIntegrationTest, Failure) { makeSingleRequest(); } -class AdsFailIntegrationTest : public HttpIntegrationTest, +// Regression test for the use-after-free crash when processing RDS update (#3953). +TEST_P(AdsIntegrationTest, RdsAfterLdsWithNoRdsChanges) { + initialize(); + + // Send initial configuration. + sendDiscoveryResponse(Config::TypeUrl::get().Cluster, + {buildCluster("cluster_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().ClusterLoadAssignment, {buildClusterLoadAssignment("cluster_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().Listener, {buildListener("listener_0", "route_config_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().RouteConfiguration, {buildRouteConfig("route_config_0", "cluster_0")}, + "1"); + test_server_->waitForCounterGe("listener_manager.listener_create_success", 1); + + // Validate that we can process a request. + makeSingleRequest(); + + // Update existing LDS (change stat_prefix). + sendDiscoveryResponse( + Config::TypeUrl::get().Listener, {buildListener("listener_0", "route_config_0", "rds_crash")}, + "2"); + test_server_->waitForCounterGe("listener_manager.listener_create_success", 2); + + // Update existing RDS (no changes). + sendDiscoveryResponse( + Config::TypeUrl::get().RouteConfiguration, {buildRouteConfig("route_config_0", "cluster_0")}, + "2"); + + // Validate that we can process a request again + makeSingleRequest(); +} + +// Regression test for the use-after-free crash when processing RDS update (#3953). +TEST_P(AdsIntegrationTest, RdsAfterLdsWithRdsChange) { + initialize(); + + // Send initial configuration. + sendDiscoveryResponse(Config::TypeUrl::get().Cluster, + {buildCluster("cluster_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().ClusterLoadAssignment, {buildClusterLoadAssignment("cluster_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().Listener, {buildListener("listener_0", "route_config_0")}, "1"); + sendDiscoveryResponse( + Config::TypeUrl::get().RouteConfiguration, {buildRouteConfig("route_config_0", "cluster_0")}, + "1"); + test_server_->waitForCounterGe("listener_manager.listener_create_success", 1); + + // Validate that we can process a request. + makeSingleRequest(); + + // Update existing LDS (change stat_prefix). + sendDiscoveryResponse(Config::TypeUrl::get().Cluster, + {buildCluster("cluster_1")}, "2"); + sendDiscoveryResponse( + Config::TypeUrl::get().ClusterLoadAssignment, {buildClusterLoadAssignment("cluster_1")}, "2"); + sendDiscoveryResponse( + Config::TypeUrl::get().Listener, {buildListener("listener_0", "route_config_0", "rds_crash")}, + "2"); + test_server_->waitForCounterGe("listener_manager.listener_create_success", 2); + + // Update existing RDS (migrate traffic to cluster_1). + sendDiscoveryResponse( + Config::TypeUrl::get().RouteConfiguration, {buildRouteConfig("route_config_0", "cluster_1")}, + "2"); + + // Validate that we can process a request after RDS update + test_server_->waitForCounterGe("http.ads_test.rds.route_config_0.config_reload", 2); + makeSingleRequest(); +} + +class AdsFailIntegrationTest : public AdsIntegrationBaseTest, public Grpc::GrpcClientIntegrationParamTest { public: AdsFailIntegrationTest() - : HttpIntegrationTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} + : AdsIntegrationBaseTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} void TearDown() override { - ads_connection_->close(); - ads_connection_->waitForDisconnect(); - ads_connection_.reset(); + cleanUpAdsConnection(); test_server_.reset(); fake_upstreams_.clear(); } void createUpstreams() override { - HttpIntegrationTest::createUpstreams(); + AdsIntegrationBaseTest::createUpstreams(); fake_upstreams_.emplace_back(new FakeUpstream(0, FakeHttpConnection::Type::HTTP2, version_)); } @@ -474,10 +576,9 @@ class AdsFailIntegrationTest : public HttpIntegrationTest, ads_cluster->set_name("ads_cluster"); }); setUpstreamProtocol(FakeHttpConnection::Type::HTTP2); - HttpIntegrationTest::initialize(); + AdsIntegrationBaseTest::initialize(); } - FakeHttpConnectionPtr ads_connection_; FakeStreamPtr ads_stream_; }; @@ -487,28 +588,26 @@ INSTANTIATE_TEST_CASE_P(IpVersionsClientType, AdsFailIntegrationTest, // Validate that we don't crash on failed ADS stream. TEST_P(AdsFailIntegrationTest, ConnectDisconnect) { initialize(); - ads_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - ads_stream_ = ads_connection_->waitForNewStream(*dispatcher_); + createAdsConnection(*fake_upstreams_[1]); + ASSERT_TRUE(ads_connection_->waitForNewStream(*dispatcher_, ads_stream_)); ads_stream_->startGrpcStream(); ads_stream_->finishGrpcStream(Grpc::Status::Internal); } -class AdsConfigIntegrationTest : public HttpIntegrationTest, +class AdsConfigIntegrationTest : public AdsIntegrationBaseTest, public Grpc::GrpcClientIntegrationParamTest { public: AdsConfigIntegrationTest() - : HttpIntegrationTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} + : AdsIntegrationBaseTest(Http::CodecClient::Type::HTTP2, ipVersion(), config) {} void TearDown() override { - ads_connection_->close(); - ads_connection_->waitForDisconnect(); - ads_connection_.reset(); + cleanUpAdsConnection(); test_server_.reset(); fake_upstreams_.clear(); } void createUpstreams() override { - HttpIntegrationTest::createUpstreams(); + AdsIntegrationBaseTest::createUpstreams(); fake_upstreams_.emplace_back(new FakeUpstream(0, FakeHttpConnection::Type::HTTP2, version_)); } @@ -530,10 +629,9 @@ class AdsConfigIntegrationTest : public HttpIntegrationTest, eds_config->mutable_ads(); }); setUpstreamProtocol(FakeHttpConnection::Type::HTTP2); - HttpIntegrationTest::initialize(); + AdsIntegrationBaseTest::initialize(); } - FakeHttpConnectionPtr ads_connection_; FakeStreamPtr ads_stream_; }; @@ -543,8 +641,8 @@ INSTANTIATE_TEST_CASE_P(IpVersionsClientType, AdsConfigIntegrationTest, // This is s regression validating that we don't crash on EDS static Cluster that uses ADS. TEST_P(AdsConfigIntegrationTest, EdsClusterWithAdsConfigSource) { initialize(); - ads_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - ads_stream_ = ads_connection_->waitForNewStream(*dispatcher_); + createAdsConnection(*fake_upstreams_[1]); + ASSERT_TRUE(ads_connection_->waitForNewStream(*dispatcher_, ads_stream_)); ads_stream_->startGrpcStream(); ads_stream_->finishGrpcStream(Grpc::Status::Ok); } @@ -564,8 +662,8 @@ TEST_P(AdsIntegrationTest, XdsBatching) { }); pre_worker_start_test_steps_ = [this]() { - ads_connection_ = fake_upstreams_.back()->waitForHttpConnection(*dispatcher_); - ads_stream_ = ads_connection_->waitForNewStream(*dispatcher_); + createAdsConnection(*fake_upstreams_.back()); + ASSERT_TRUE(ads_connection_->waitForNewStream(*dispatcher_, ads_stream_)); ads_stream_->startGrpcStream(); EXPECT_TRUE(compareDiscoveryRequest(Config::TypeUrl::get().ClusterLoadAssignment, "", diff --git a/test/integration/autonomous_upstream.cc b/test/integration/autonomous_upstream.cc index 79fd62ed7fc29..c3c0a5f94f914 100644 --- a/test/integration/autonomous_upstream.cc +++ b/test/integration/autonomous_upstream.cc @@ -9,7 +9,8 @@ void HeaderToInt(const char header_name[], int32_t& return_int, Http::TestHeader if (!header_value.empty()) { uint64_t parsed_value; RELEASE_ASSERT(StringUtil::atoul(header_value.c_str(), parsed_value, 10) && - parsed_value < std::numeric_limits::max()); + parsed_value < std::numeric_limits::max(), + ""); return_int = parsed_value; } } @@ -22,7 +23,7 @@ const char AutonomousStream::RESET_AFTER_REQUEST[] = "reset_after_request"; // For now, assert all streams which are started are completed. // Support for incomplete streams can be added when needed. -AutonomousStream::~AutonomousStream() { RELEASE_ASSERT(complete()); } +AutonomousStream::~AutonomousStream() { RELEASE_ASSERT(complete(), ""); } // By default, automatically send a response when the request is complete. void AutonomousStream::setEndStream(bool end_stream) { @@ -72,7 +73,8 @@ bool AutonomousUpstream::createNetworkFilterChain(Network::Connection& connectio shared_connections_.emplace_back(new SharedConnectionWrapper(connection, true)); AutonomousHttpConnectionPtr http_connection( new AutonomousHttpConnection(*shared_connections_.back(), stats_store_, http_type_, *this)); - http_connection->initialize(); + testing::AssertionResult result = http_connection->initialize(); + RELEASE_ASSERT(result, result.message()); http_connections_.push_back(std::move(http_connection)); return true; } diff --git a/test/integration/echo_integration_test.cc b/test/integration/echo_integration_test.cc index 5375a7167b29d..342ba5866f0d0 100644 --- a/test/integration/echo_integration_test.cc +++ b/test/integration/echo_integration_test.cc @@ -52,7 +52,7 @@ TEST_P(EchoIntegrationTest, Hello) { RawConnectionDriver connection( lookupPort("listener_0"), buffer, [&](Network::ClientConnection&, const Buffer::Instance& data) -> void { - response.append(TestUtility::bufferToString(data)); + response.append(data.toString()); connection.close(); }, version_); @@ -101,7 +101,7 @@ TEST_P(EchoIntegrationTest, AddRemoveListener) { RawConnectionDriver connection( new_listener_port, buffer, [&](Network::ClientConnection&, const Buffer::Instance& data) -> void { - response.append(TestUtility::bufferToString(data)); + response.append(data.toString()); connection.close(); }, version_); diff --git a/test/integration/fake_upstream.cc b/test/integration/fake_upstream.cc index 842d70ea667a5..6aa74011e4f1d 100644 --- a/test/integration/fake_upstream.cc +++ b/test/integration/fake_upstream.cc @@ -23,6 +23,15 @@ #include "test/test_common/printers.h" #include "test/test_common/utility.h" +#include "absl/strings/str_cat.h" + +using namespace std::chrono_literals; + +using std::chrono::milliseconds; +using testing::AssertionFailure; +using testing::AssertionResult; +using testing::AssertionSuccess; + namespace Envoy { FakeStream::FakeStream(FakeHttpConnection& parent, Http::StreamEncoder& encoder) : parent_(parent), encoder_(encoder) { @@ -100,40 +109,63 @@ void FakeStream::onResetStream(Http::StreamResetReason) { decoder_event_.notifyOne(); } -void FakeStream::waitForHeadersComplete() { +AssertionResult FakeStream::waitForHeadersComplete(milliseconds timeout) { Thread::LockGuard lock(lock_); + auto end_time = std::chrono::steady_clock::now() + timeout; while (!headers_) { - decoder_event_.wait(lock_); + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionFailure() << "Timed out waiting for headers."; + } + decoder_event_.waitFor(lock_, 5ms); } + return AssertionSuccess(); } -void FakeStream::waitForData(Event::Dispatcher& client_dispatcher, uint64_t body_length) { +AssertionResult FakeStream::waitForData(Event::Dispatcher& client_dispatcher, uint64_t body_length, + milliseconds timeout) { Thread::LockGuard lock(lock_); + auto start_time = std::chrono::steady_clock::now(); while (bodyLength() < body_length) { - decoder_event_.waitFor(lock_, std::chrono::milliseconds(5)); + if (std::chrono::steady_clock::now() >= start_time + timeout) { + return AssertionFailure() << "Timed out waiting for data."; + } + decoder_event_.waitFor(lock_, 5ms); if (bodyLength() < body_length) { // Run the client dispatcher since we may need to process window updates, etc. client_dispatcher.run(Event::Dispatcher::RunType::NonBlock); } } + return AssertionSuccess(); } -void FakeStream::waitForEndStream(Event::Dispatcher& client_dispatcher) { +AssertionResult FakeStream::waitForEndStream(Event::Dispatcher& client_dispatcher, + milliseconds timeout) { Thread::LockGuard lock(lock_); + auto start_time = std::chrono::steady_clock::now(); while (!end_stream_) { - decoder_event_.waitFor(lock_, std::chrono::milliseconds(5)); + if (std::chrono::steady_clock::now() >= start_time + timeout) { + return AssertionFailure() << "Timed out waiting for end of stream."; + } + decoder_event_.waitFor(lock_, 5ms); if (!end_stream_) { // Run the client dispatcher since we may need to process window updates, etc. client_dispatcher.run(Event::Dispatcher::RunType::NonBlock); } } + return AssertionSuccess(); } -void FakeStream::waitForReset() { +AssertionResult FakeStream::waitForReset(milliseconds timeout) { Thread::LockGuard lock(lock_); + auto start_time = std::chrono::steady_clock::now(); while (!saw_reset_) { - decoder_event_.wait(lock_); // Safe since CondVar::wait won't throw. + if (std::chrono::steady_clock::now() >= start_time + timeout) { + return AssertionFailure() << "Timed out waiting for reset."; + } + // Safe since CondVar::waitFor won't throw. + decoder_event_.waitFor(lock_, 5ms); } + return AssertionSuccess(); } void FakeStream::startGrpcStream() { @@ -161,20 +193,23 @@ FakeHttpConnection::FakeHttpConnection(SharedConnectionWrapper& shared_connectio Network::ReadFilterSharedPtr{new ReadFilter(*this)}); } -void FakeConnectionBase::close() { - shared_connection_.executeOnDispatcher([](Network::Connection& connection) { - connection.close(Network::ConnectionCloseType::FlushWrite); - }); +AssertionResult FakeConnectionBase::close(std::chrono::milliseconds timeout) { + return shared_connection_.executeOnDispatcher( + [](Network::Connection& connection) { + connection.close(Network::ConnectionCloseType::FlushWrite); + }, + timeout); } -void FakeConnectionBase::readDisable(bool disable) { - shared_connection_.executeOnDispatcher( - [disable](Network::Connection& connection) { connection.readDisable(disable); }); +AssertionResult FakeConnectionBase::readDisable(bool disable, std::chrono::milliseconds timeout) { + return shared_connection_.executeOnDispatcher( + [disable](Network::Connection& connection) { connection.readDisable(disable); }, timeout); } -void FakeConnectionBase::enableHalfClose(bool enable) { - shared_connection_.executeOnDispatcher( - [enable](Network::Connection& connection) { connection.enableHalfClose(enable); }); +AssertionResult FakeConnectionBase::enableHalfClose(bool enable, + std::chrono::milliseconds timeout) { + return shared_connection_.executeOnDispatcher( + [enable](Network::Connection& connection) { connection.enableHalfClose(enable); }, timeout); } Http::StreamDecoder& FakeHttpConnection::newStream(Http::StreamEncoder& encoder) { @@ -184,27 +219,42 @@ Http::StreamDecoder& FakeHttpConnection::newStream(Http::StreamEncoder& encoder) return *new_streams_.back(); } -void FakeConnectionBase::waitForDisconnect(bool ignore_spurious_events) { +AssertionResult FakeConnectionBase::waitForDisconnect(bool ignore_spurious_events, + milliseconds timeout) { + ENVOY_LOG(trace, "FakeConnectionBase waiting for disconnect"); + auto end_time = std::chrono::steady_clock::now() + timeout; Thread::LockGuard lock(lock_); while (shared_connection_.connected()) { - connection_event_.wait(lock_); // Safe since CondVar::wait won't throw. + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionResult("Timed out waiting for disconnect."); + } + Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // The default behavior of waitForDisconnect is to assume the test cleanly // calls waitForData, waitForNewStream, etc. to handle all events on the // connection. If the caller explicitly notes that other events should be // ignored, continue looping until a disconnect is detected. Otherwise fall // through and hit the assert below. - if (!ignore_spurious_events) { + if ((status == Thread::CondVar::WaitStatus::NoTimeout) && !ignore_spurious_events) { break; } } - ASSERT(!shared_connection_.connected()); + if (shared_connection_.connected()) { + return AssertionFailure() << "Expected disconnect, but got a different event."; + } + ENVOY_LOG(trace, "FakeConnectionBase done waiting for disconnect"); + return AssertionSuccess(); } -void FakeConnectionBase::waitForHalfClose(bool ignore_spurious_events) { +AssertionResult FakeConnectionBase::waitForHalfClose(bool ignore_spurious_events, + milliseconds timeout) { + auto end_time = std::chrono::steady_clock::now() + timeout; Thread::LockGuard lock(lock_); while (!half_closed_) { - connection_event_.wait(lock_); // Safe since CondVar::wait won't throw. + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionFailure() << "Timed out waiting for half close."; + } + connection_event_.waitFor(lock_, 5ms); // Safe since CondVar::waitFor won't throw. // The default behavior of waitForHalfClose is to assume the test cleanly // calls waitForData, waitForNewStream, etc. to handle all events on the // connection. If the caller explicitly notes that other events should be @@ -215,15 +265,22 @@ void FakeConnectionBase::waitForHalfClose(bool ignore_spurious_events) { } } - ASSERT(half_closed_); + return half_closed_ + ? AssertionSuccess() + : (AssertionFailure() << "Expected half close event, but got a different event."); } -FakeStreamPtr FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_dispatcher, - bool ignore_spurious_events) { +AssertionResult FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_dispatcher, + FakeStreamPtr& stream, + bool ignore_spurious_events, + milliseconds timeout) { + auto end_time = std::chrono::steady_clock::now() + timeout; Thread::LockGuard lock(lock_); while (new_streams_.empty()) { - Thread::CondVar::WaitStatus status = - connection_event_.waitFor(lock_, std::chrono::milliseconds(5)); + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionResult("Timed out waiting for new stream."); + } + Thread::CondVar::WaitStatus status = connection_event_.waitFor(lock_, 5ms); // As with waitForDisconnect, by default, waitForNewStream returns after the next event. // If the caller explicitly notes other events should be ignored, it will instead actually // wait for the next new stream, ignoring other events such as onData() @@ -236,10 +293,12 @@ FakeStreamPtr FakeHttpConnection::waitForNewStream(Event::Dispatcher& client_dis } } - ASSERT(!new_streams_.empty()); - FakeStreamPtr stream = std::move(new_streams_.front()); + if (new_streams_.empty()) { + return AssertionFailure() << "Expected new stream event, but got a different event."; + } + stream = std::move(new_streams_.front()); new_streams_.pop_front(); - return stream; + return AssertionSuccess(); } FakeUpstream::FakeUpstream(const std::string& uds_path, FakeHttpConnection::Type type) @@ -277,8 +336,7 @@ FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket FakeUpstream::FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, Network::SocketPtr&& listen_socket, FakeHttpConnection::Type type, bool enable_half_close) - : http_type_(type), socket_(std::move(listen_socket)), - api_(new Api::Impl(std::chrono::milliseconds(10000))), + : http_type_(type), socket_(std::move(listen_socket)), api_(new Api::Impl(milliseconds(10000))), dispatcher_(api_->allocateDispatcher()), handler_(new Server::ConnectionHandlerImpl(ENVOY_LOGGER(), *dispatcher_)), allow_unexpected_disconnects_(false), enable_half_close_(enable_half_close), listener_(*this), @@ -323,71 +381,81 @@ void FakeUpstream::threadRoutine() { } } -FakeHttpConnectionPtr FakeUpstream::waitForHttpConnection(Event::Dispatcher& client_dispatcher) { - FakeHttpConnectionPtr connection; +AssertionResult FakeUpstream::waitForHttpConnection(Event::Dispatcher& client_dispatcher, + FakeHttpConnectionPtr& connection, + milliseconds timeout) { + auto end_time = std::chrono::steady_clock::now() + timeout; { Thread::LockGuard lock(lock_); while (new_connections_.empty()) { - new_connection_event_.waitFor(lock_, std::chrono::milliseconds(5)); + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionFailure() << "Timed out waiting for new connection."; + } + new_connection_event_.waitFor(lock_, 5ms); if (new_connections_.empty()) { // Run the client dispatcher since we may need to process window updates, etc. client_dispatcher.run(Event::Dispatcher::RunType::NonBlock); } } - ASSERT(!new_connections_.empty()); + if (new_connections_.empty()) { + return AssertionFailure() << "Got a new connection event, but didn't create a connection."; + } connection = std::make_unique(consumeConnection(), stats_store_, http_type_); } - connection->initialize(); - connection->readDisable(false); - return connection; + VERIFY_ASSERTION(connection->initialize()); + VERIFY_ASSERTION(connection->readDisable(false)); + return AssertionSuccess(); } -FakeHttpConnectionPtr +AssertionResult FakeUpstream::waitForHttpConnection(Event::Dispatcher& client_dispatcher, - std::vector>& upstreams) { - for (;;) { + std::vector>& upstreams, + FakeHttpConnectionPtr& connection, milliseconds timeout) { + auto end_time = std::chrono::steady_clock::now() + timeout; + while (std::chrono::steady_clock::now() < end_time) { for (auto it = upstreams.begin(); it != upstreams.end(); ++it) { FakeUpstream& upstream = **it; Thread::ReleasableLockGuard lock(upstream.lock_); if (upstream.new_connections_.empty()) { - upstream.new_connection_event_.waitFor(upstream.lock_, std::chrono::milliseconds(5)); + upstream.new_connection_event_.waitFor(upstream.lock_, 5ms); } if (upstream.new_connections_.empty()) { // Run the client dispatcher since we may need to process window updates, etc. client_dispatcher.run(Event::Dispatcher::RunType::NonBlock); } else { - FakeHttpConnectionPtr connection(new FakeHttpConnection( - upstream.consumeConnection(), upstream.stats_store_, upstream.http_type_)); + connection = std::make_unique( + upstream.consumeConnection(), upstream.stats_store_, upstream.http_type_); lock.release(); - connection->initialize(); - connection->readDisable(false); - return connection; + VERIFY_ASSERTION(connection->initialize()); + VERIFY_ASSERTION(connection->readDisable(false)); + return AssertionSuccess(); } } } + return AssertionFailure() << "Timed out waiting for HTTP connection."; } -FakeRawConnectionPtr FakeUpstream::waitForRawConnection(std::chrono::milliseconds wait_for_ms) { - FakeRawConnectionPtr connection; +AssertionResult FakeUpstream::waitForRawConnection(FakeRawConnectionPtr& connection, + milliseconds timeout) { { Thread::LockGuard lock(lock_); if (new_connections_.empty()) { ENVOY_LOG(debug, "waiting for raw connection"); - new_connection_event_.waitFor(lock_, wait_for_ms); // Safe since CondVar::wait won't throw. + new_connection_event_.waitFor(lock_, timeout); // Safe since CondVar::waitFor won't throw. } if (new_connections_.empty()) { - return nullptr; + return AssertionFailure() << "Timed out waiting for raw connection"; } connection = std::make_unique(consumeConnection()); } - connection->initialize(); - connection->readDisable(false); - connection->enableHalfClose(enable_half_close_); - return connection; + VERIFY_ASSERTION(connection->initialize()); + VERIFY_ASSERTION(connection->readDisable(false)); + VERIFY_ASSERTION(connection->enableHalfClose(enable_half_close_)); + return AssertionSuccess(); } SharedConnectionWrapper& FakeUpstream::consumeConnection() { @@ -398,37 +466,56 @@ SharedConnectionWrapper& FakeUpstream::consumeConnection() { return connection_wrapper->shared_connection(); } -std::string FakeRawConnection::waitForData(uint64_t num_bytes) { +AssertionResult FakeRawConnection::waitForData(uint64_t num_bytes, std::string* data, + milliseconds timeout) { Thread::LockGuard lock(lock_); + ENVOY_LOG(debug, "waiting for {} bytes of data", num_bytes); + auto end_time = std::chrono::steady_clock::now() + timeout; while (data_.size() != num_bytes) { - ENVOY_LOG(debug, "waiting for {} bytes of data", num_bytes); - connection_event_.wait(lock_); // Safe since CondVar::wait won't throw. + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionFailure() << "Timed out waiting for data."; + } + connection_event_.waitFor(lock_, 5ms); // Safe since CondVar::waitFor won't throw. + } + if (data != nullptr) { + *data = data_; } - return data_; + return AssertionSuccess(); } -std::string -FakeRawConnection::waitForData(const std::function& data_validator) { +AssertionResult +FakeRawConnection::waitForData(const std::function& data_validator, + std::string* data, milliseconds timeout) { Thread::LockGuard lock(lock_); + ENVOY_LOG(debug, "waiting for data"); + auto end_time = std::chrono::steady_clock::now() + timeout; while (!data_validator(data_)) { - ENVOY_LOG(debug, "waiting for data"); - connection_event_.wait(lock_); // Safe since CondVar::wait won't throw. + if (std::chrono::steady_clock::now() >= end_time) { + return AssertionFailure() << "Timed out waiting for data."; + } + connection_event_.waitFor(lock_, 5ms); // Safe since CondVar::waitFor won't throw. } - return data_; + if (data != nullptr) { + *data = data_; + } + return AssertionSuccess(); } -void FakeRawConnection::write(const std::string& data, bool end_stream) { - shared_connection_.executeOnDispatcher([&data, end_stream](Network::Connection& connection) { - Buffer::OwnedImpl to_write(data); - connection.write(to_write, end_stream); - }); +AssertionResult FakeRawConnection::write(const std::string& data, bool end_stream, + milliseconds timeout) { + return shared_connection_.executeOnDispatcher( + [&data, end_stream](Network::Connection& connection) { + Buffer::OwnedImpl to_write(data); + connection.write(to_write, end_stream); + }, + timeout); } Network::FilterStatus FakeRawConnection::ReadFilter::onData(Buffer::Instance& data, bool end_stream) { Thread::LockGuard lock(parent_.lock_); ENVOY_LOG(debug, "got {} bytes", data.length()); - parent_.data_.append(TestUtility::bufferToString(data)); + parent_.data_.append(data.toString()); parent_.half_closed_ = end_stream; data.drain(data.length()); parent_.connection_event_.notifyOne(); diff --git a/test/integration/fake_upstream.h b/test/integration/fake_upstream.h index ec82d2be25936..435801c4e3abb 100644 --- a/test/integration/fake_upstream.h +++ b/test/integration/fake_upstream.h @@ -43,7 +43,10 @@ class FakeStream : public Http::StreamDecoder, uint64_t bodyLength() { return body_.length(); } Buffer::Instance& body() { return body_; } - bool complete() { return end_stream_; } + bool complete() { + Thread::LockGuard lock(lock_); + return end_stream_; + } void encode100ContinueHeaders(const Http::HeaderMapImpl& headers); void encodeHeaders(const Http::HeaderMapImpl& headers, bool end_stream); void encodeData(uint64_t size, bool end_stream); @@ -53,10 +56,24 @@ class FakeStream : public Http::StreamDecoder, const Http::HeaderMap& headers() { return *headers_; } void setAddServedByHeader(bool add_header) { add_served_by_header_ = add_header; } const Http::HeaderMapPtr& trailers() { return trailers_; } - void waitForHeadersComplete(); - void waitForData(Event::Dispatcher& client_dispatcher, uint64_t body_length); - void waitForEndStream(Event::Dispatcher& client_dispatcher); - void waitForReset(); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForHeadersComplete(std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForData(Event::Dispatcher& client_dispatcher, uint64_t body_length, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForEndStream(Event::Dispatcher& client_dispatcher, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForReset(std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); // gRPC convenience methods. void startGrpcStream(); @@ -78,26 +95,43 @@ class FakeStream : public Http::StreamDecoder, ENVOY_LOG(debug, "Received gRPC message: {}", message.DebugString()); decoded_grpc_frames_.erase(decoded_grpc_frames_.begin()); } - template void waitForGrpcMessage(Event::Dispatcher& client_dispatcher, T& message) { + template + ABSL_MUST_USE_RESULT testing::AssertionResult + waitForGrpcMessage(Event::Dispatcher& client_dispatcher, T& message, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout) { + auto end_time = std::chrono::steady_clock::now() + timeout; ENVOY_LOG(debug, "Waiting for gRPC message..."); if (!decoded_grpc_frames_.empty()) { decodeGrpcFrame(message); - return; + return AssertionSuccess(); + } + if (!waitForData(client_dispatcher, 5, timeout)) { + return testing::AssertionFailure() << "Timed out waiting for start of gRPC message."; } - waitForData(client_dispatcher, 5); { Thread::LockGuard lock(lock_); - EXPECT_TRUE(grpc_decoder_.decode(body(), decoded_grpc_frames_)); + if (!grpc_decoder_.decode(body(), decoded_grpc_frames_)) { + return testing::AssertionFailure() + << "Couldn't decode gRPC data frame: " << body().toString(); + } } if (decoded_grpc_frames_.size() < 1) { - waitForData(client_dispatcher, grpc_decoder_.length()); + timeout = std::chrono::duration_cast( + end_time - std::chrono::steady_clock::now()); + if (!waitForData(client_dispatcher, grpc_decoder_.length(), timeout)) { + return testing::AssertionFailure() << "Timed out waiting for end of gRPC message."; + } { Thread::LockGuard lock(lock_); - EXPECT_TRUE(grpc_decoder_.decode(body(), decoded_grpc_frames_)); + if (!grpc_decoder_.decode(body(), decoded_grpc_frames_)) { + return testing::AssertionFailure() + << "Couldn't decode gRPC data frame: " << body().toString(); + } } } decodeGrpcFrame(message); ENVOY_LOG(debug, "Received gRPC message: {}", message.DebugString()); + return AssertionSuccess(); } // Http::StreamDecoder @@ -193,24 +227,41 @@ class SharedConnectionWrapper : public Network::ConnectionCallbacks { // wait-for-completion. If the connection is disconnected, either prior to post or when the // dispatcher schedules the callback, we silently ignore if allow_unexpected_disconnects_ // is set. - void executeOnDispatcher(std::function f) { + ABSL_MUST_USE_RESULT + testing::AssertionResult + executeOnDispatcher(std::function f, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout) { Thread::LockGuard lock(lock_); if (disconnected_) { - return; + return testing::AssertionSuccess(); } Thread::CondVar callback_ready_event; - connection_.dispatcher().post([this, f, &callback_ready_event]() -> void { - // The use of connected() here, vs. !disconnected_, is because we want to use the lock_ - // acquisition to briefly serialize. This avoids us entering this completion and issuing a - // notifyOne() until the wait() is ready to receive it below. - if (connected()) { - f(connection_); - } else { - RELEASE_ASSERT(allow_unexpected_disconnects_); - } - callback_ready_event.notifyOne(); - }); - callback_ready_event.wait(lock_); + bool unexpected_disconnect = false; + connection_.dispatcher().post( + [this, f, &callback_ready_event, &unexpected_disconnect]() -> void { + // The use of connected() here, vs. !disconnected_, is because we want to use the lock_ + // acquisition to briefly serialize. This avoids us entering this completion and issuing a + // notifyOne() until the wait() is ready to receive it below. + if (connected()) { + f(connection_); + } else { + unexpected_disconnect = true; + } + callback_ready_event.notifyOne(); + }); + Thread::CondVar::WaitStatus status = callback_ready_event.waitFor(lock_, timeout); + if (status == Thread::CondVar::WaitStatus::Timeout) { + return testing::AssertionFailure() << "Timed out while executing on dispatcher."; + } + if (unexpected_disconnect && !allow_unexpected_disconnects_) { + return testing::AssertionFailure() + << "The connection disconnected unexpectedly, and allow_unexpected_disconnects_ is " + "false." + "\n See " + "https://github.com/envoyproxy/envoy/blob/master/test/integration/README.md#" + "unexpected-disconnects"; + } + return testing::AssertionSuccess(); } private: @@ -242,7 +293,13 @@ class QueuedConnectionWrapper : public LinkedObject { allow_unexpected_disconnects_(allow_unexpected_disconnects) { shared_connection_.addDisconnectCallback([this] { Thread::LockGuard lock(lock_); - RELEASE_ASSERT(parented_ || allow_unexpected_disconnects_); + RELEASE_ASSERT(parented_ || allow_unexpected_disconnects_, + "An queued upstream connection was torn down without being associated " + "with a fake connection. Either manage the connection via " + "waitForRawConnection() or waitForHttpConnection(), or " + "set_allow_unexpected_disconnects(true).\n See " + "https://github.com/envoyproxy/envoy/blob/master/test/integration/README.md#" + "unparented-upstream-connections"); }); } @@ -263,27 +320,45 @@ class QueuedConnectionWrapper : public LinkedObject { /** * Base class for both fake raw connections and fake HTTP connections. */ -class FakeConnectionBase { +class FakeConnectionBase : public Logger::Loggable { public: virtual ~FakeConnectionBase() { ASSERT(initialized_); ASSERT(disconnect_callback_handle_ != nullptr); shared_connection_.removeDisconnectCallback(disconnect_callback_handle_); } - void close(); - void readDisable(bool disable); - // By default waitForDisconnect and waitForHalfClose assume the next event is a disconnect and - // fails an assert if an unexpected event occurs. If a caller truly wishes to wait until - // disconnect, set ignore_spurious_events = true. - void waitForDisconnect(bool ignore_spurious_events = false); - void waitForHalfClose(bool ignore_spurious_events = false); - - virtual void initialize() { + + ABSL_MUST_USE_RESULT + testing::AssertionResult close(std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + readDisable(bool disable, std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + // By default waitForDisconnect and waitForHalfClose assume the next event is + // a disconnect and return an AssertionFailure if an unexpected event occurs. + // If a caller truly wishes to wait until disconnect, set + // ignore_spurious_events = true. + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForDisconnect(bool ignore_spurious_events = false, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForHalfClose(bool ignore_spurious_events = false, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + virtual testing::AssertionResult initialize() { initialized_ = true; disconnect_callback_handle_ = shared_connection_.addDisconnectCallback([this] { connection_event_.notifyOne(); }); + return testing::AssertionSuccess(); } - void enableHalfClose(bool enabled); + ABSL_MUST_USE_RESULT + testing::AssertionResult + enableHalfClose(bool enabled, std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); SharedConnectionWrapper& shared_connection() { return shared_connection_; } // The same caveats apply here as in SharedConnectionWrapper::connection(). Network::Connection& connection() const { return shared_connection_.connection(); } @@ -309,15 +384,20 @@ class FakeHttpConnection : public Http::ServerConnectionCallbacks, public FakeCo enum class Type { HTTP1, HTTP2 }; FakeHttpConnection(SharedConnectionWrapper& shared_connection, Stats::Store& store, Type type); + // By default waitForNewStream assumes the next event is a new stream and - // fails an assert if an unexpected event occurs. If a caller truly wishes to - // wait for a new stream, set ignore_spurious_events = true. - FakeStreamPtr waitForNewStream(Event::Dispatcher& client_dispatcher, - bool ignore_spurious_events = false); + // returns AssertionFaliure if an unexpected event occurs. If a caller truly + // wishes to wait for a new stream, set ignore_spurious_events = true. Returns + // the new stream via the stream argument. + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForNewStream(Event::Dispatcher& client_dispatcher, FakeStreamPtr& stream, + bool ignore_spurious_events = false, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); // Http::ServerConnectionCallbacks Http::StreamDecoder& newStream(Http::StreamEncoder& response_encoder) override; - void onGoAway() override { NOT_IMPLEMENTED; } + void onGoAway() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } private: struct ReadFilter : public Network::ReadFilterBaseImpl { @@ -341,23 +421,42 @@ typedef std::unique_ptr FakeHttpConnectionPtr; /** * Fake raw connection for integration testing. */ -class FakeRawConnection : Logger::Loggable, public FakeConnectionBase { +class FakeRawConnection : public FakeConnectionBase { public: FakeRawConnection(SharedConnectionWrapper& shared_connection) : FakeConnectionBase(shared_connection) {} typedef const std::function ValidatorFunction; - std::string waitForData(uint64_t num_bytes); - // Wait until data_validator returns true. - // example usage: waitForData(FakeRawConnection::waitForInexactMatch("foo")); - std::string waitForData(const ValidatorFunction& data_validator); - void write(const std::string& data, bool end_stream = false); + // Writes to data. If data is nullptr, discards the received data. + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForData(uint64_t num_bytes, std::string* data = nullptr, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); - void initialize() override { - shared_connection_.executeOnDispatcher([this](Network::Connection& connection) { - connection.addReadFilter(Network::ReadFilterSharedPtr{new ReadFilter(*this)}); - }); - FakeConnectionBase::initialize(); + // Wait until data_validator returns true. + // example usage: + // std::string data; + // ASSERT_TRUE(waitForData(FakeRawConnection::waitForInexactMatch("foo"), &data)); + // EXPECT_EQ(data, "foobar"); + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForData(const ValidatorFunction& data_validator, std::string* data = nullptr, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult write(const std::string& data, bool end_stream = false, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult initialize() override { + testing::AssertionResult result = + shared_connection_.executeOnDispatcher([this](Network::Connection& connection) { + connection.addReadFilter(Network::ReadFilterSharedPtr{new ReadFilter(*this)}); + }); + if (!result) { + return result; + } + return FakeConnectionBase::initialize(); } // Creates a ValidatorFunction which returns true when data_to_wait_for is @@ -399,15 +498,26 @@ class FakeUpstream : Logger::Loggable, ~FakeUpstream(); FakeHttpConnection::Type httpType() { return http_type_; } - FakeHttpConnectionPtr waitForHttpConnection(Event::Dispatcher& client_dispatcher); - FakeRawConnectionPtr - waitForRawConnection(std::chrono::milliseconds wait_for_ms = std::chrono::milliseconds{10000}); + + // Returns the new connection via the connection argument. + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForHttpConnection(Event::Dispatcher& client_dispatcher, FakeHttpConnectionPtr& connection, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); + + ABSL_MUST_USE_RESULT + testing::AssertionResult + waitForRawConnection(FakeRawConnectionPtr& connection, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); Network::Address::InstanceConstSharedPtr localAddress() const { return socket_->localAddress(); } // Wait for one of the upstreams to receive a connection - static FakeHttpConnectionPtr + ABSL_MUST_USE_RESULT + static testing::AssertionResult waitForHttpConnection(Event::Dispatcher& client_dispatcher, - std::vector>& upstreams); + std::vector>& upstreams, + FakeHttpConnectionPtr& connection, + std::chrono::milliseconds timeout = TestUtility::DefaultTimeout); // Network::FilterChainManager const Network::FilterChain* findFilterChain(const Network::ConnectionSocket&) const override { @@ -421,10 +531,12 @@ class FakeUpstream : Logger::Loggable, bool createListenerFilterChain(Network::ListenerFilterManager& listener) override; void set_allow_unexpected_disconnects(bool value) { allow_unexpected_disconnects_ = value; } + // Stops the dispatcher loop and joins the listening thread. + void cleanUp(); + protected: Stats::IsolatedStoreImpl stats_store_; const FakeHttpConnection::Type http_type_; - void cleanUp(); private: FakeUpstream(Network::TransportSocketFactoryPtr&& transport_socket_factory, @@ -474,4 +586,7 @@ class FakeUpstream : Logger::Loggable, FakeListener listener_; const Network::FilterChainSharedPtr filter_chain_; }; + +typedef std::unique_ptr FakeUpstreamPtr; + } // namespace Envoy diff --git a/test/integration/h1_capture_corpus/upstream_extra_crlf.pb_text b/test/integration/h1_capture_corpus/upstream_extra_crlf.pb_text new file mode 100644 index 0000000000000..6407b0be219e5 --- /dev/null +++ b/test/integration/h1_capture_corpus/upstream_extra_crlf.pb_text @@ -0,0 +1,38 @@ +events { + downstream_send_bytes: "POST /test/long/url HTTP/1.1\r\nhost: host\r\nx-lyft-user-id: 123\r\nx-forwarded-for: 10.0.0.1\r\ntransfer-encoding: chunked\r\n\r\n400\r\naaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa;aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\r\n0\r\n\r\n" +} +events { + upstream_recv_bytes { + } +} +events { + upstream_recv_bytes { + } +} +events { + upstream_send_bytes: "\r" +} +events { +} +events { + upstream_send_bytes: "\nStack trace:\n" +} +events { + upstream_recv_bytes { + } +} +events { + upstream_send_bytes: "\nStack trace:\n" +} +events { + upstream_send_bytes: "" +} +events { + upstream_recv_bytes { + } +} +events { + upstream_send_bytes: "1" +} +events { +} diff --git a/test/integration/h1_capture_fuzz_test.cc b/test/integration/h1_capture_fuzz_test.cc index 380cd22902dac..fe7d41bf0c51c 100644 --- a/test/integration/h1_capture_fuzz_test.cc +++ b/test/integration/h1_capture_fuzz_test.cc @@ -37,9 +37,8 @@ class H1FuzzIntegrationTest : public HttpIntegrationTest { break; case test::integration::Event::kUpstreamSendBytes: if (fake_upstream_connection == nullptr) { - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(max_wait_ms_); - // If we timed out, we fail out. - if (fake_upstream_connection == nullptr) { + if (!fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection, max_wait_ms_)) { + // If we timed out, we fail out. tcp_client->close(); return; } @@ -49,7 +48,10 @@ class H1FuzzIntegrationTest : public HttpIntegrationTest { tcp_client->close(); return; } - fake_upstream_connection->write(event.upstream_send_bytes()); + { + AssertionResult result = fake_upstream_connection->write(event.upstream_send_bytes()); + RELEASE_ASSERT(result, result.message()); + } break; case test::integration::Event::kUpstreamRecvBytes: // TODO(htuch): Should we wait for some data? @@ -61,9 +63,11 @@ class H1FuzzIntegrationTest : public HttpIntegrationTest { } if (fake_upstream_connection != nullptr) { if (fake_upstream_connection->connected()) { - fake_upstream_connection->close(); + AssertionResult result = fake_upstream_connection->close(); + RELEASE_ASSERT(result, result.message()); } - fake_upstream_connection->waitForDisconnect(true); + AssertionResult result = fake_upstream_connection->waitForDisconnect(true); + RELEASE_ASSERT(result, result.message()); } tcp_client->close(); } @@ -76,7 +80,7 @@ class H1FuzzIntegrationTest : public HttpIntegrationTest { // Fuzz the H1 processing pipeline. DEFINE_PROTO_FUZZER(const test::integration::CaptureFuzzTestCase& input) { // Pick an IP version to use for loopback, it doesn't matter which. - RELEASE_ASSERT(TestEnvironment::getIpVersionsForTest().size() > 0); + RELEASE_ASSERT(TestEnvironment::getIpVersionsForTest().size() > 0, ""); const auto ip_version = TestEnvironment::getIpVersionsForTest()[0]; H1FuzzIntegrationTest h1_fuzz_integration_test(ip_version); h1_fuzz_integration_test.replay(input); diff --git a/test/integration/hds_integration_test.cc b/test/integration/hds_integration_test.cc index 6e8235785048f..af3053cce249e 100644 --- a/test/integration/hds_integration_test.cc +++ b/test/integration/hds_integration_test.cc @@ -1,14 +1,21 @@ #include "envoy/api/v2/eds.pb.h" #include "envoy/api/v2/endpoint/endpoint.pb.h" #include "envoy/service/discovery/v2/hds.pb.h" +#include "envoy/upstream/upstream.h" +#include "common/config/metadata.h" #include "common/config/resources.h" +#include "common/network/utility.h" +#include "common/protobuf/utility.h" +#include "common/upstream/health_checker_impl.h" +#include "common/upstream/health_discovery_service.h" +#include "test/common/upstream/utility.h" #include "test/config/utility.h" #include "test/integration/http_integration.h" #include "test/test_common/network_utility.h" -#include "test/test_common/utility.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" namespace Envoy { @@ -27,103 +34,435 @@ class HdsIntegrationTest : public HttpIntegrationTest, void initialize() override { setUpstreamCount(upstream_endpoints_); - config_helper_.addConfigModifier([this](envoy::config::bootstrap::v2::Bootstrap& bootstrap) { + config_helper_.addConfigModifier([](envoy::config::bootstrap::v2::Bootstrap& bootstrap) { // Setup hds and corresponding gRPC cluster. - auto* hds_confid = bootstrap.mutable_hds_config(); - hds_confid->set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); - hds_confid->add_grpc_services()->mutable_envoy_grpc()->set_cluster_name("hds_delegate"); + auto* hds_config = bootstrap.mutable_hds_config(); + hds_config->set_api_type(envoy::api::v2::core::ApiConfigSource::GRPC); + hds_config->add_grpc_services()->mutable_envoy_grpc()->set_cluster_name("hds_cluster"); auto* hds_cluster = bootstrap.mutable_static_resources()->add_clusters(); hds_cluster->MergeFrom(bootstrap.static_resources().clusters()[0]); hds_cluster->mutable_circuit_breakers()->Clear(); - hds_cluster->set_name("hds_delegate"); + hds_cluster->set_name("hds_cluster"); hds_cluster->mutable_http2_protocol_options(); - // Switch predefined cluster_0 to EDS filesystem sourcing. - // TODO(lilika): Remove eds dependency auto* cluster_0 = bootstrap.mutable_static_resources()->mutable_clusters(0); cluster_0->mutable_hosts()->Clear(); - cluster_0->set_type(envoy::api::v2::Cluster::EDS); - auto* eds_cluster_config = cluster_0->mutable_eds_cluster_config(); - eds_cluster_config->mutable_eds_config()->set_path(eds_helper_.eds_path()); }); + HttpIntegrationTest::initialize(); - hds_upstream_ = fake_upstreams_[0].get(); - for (uint32_t i = 0; i < upstream_endpoints_; ++i) { - service_upstream_[i] = fake_upstreams_[i + 1].get(); - } + + // Endpoint connections + host_upstream_.reset(new FakeUpstream(0, FakeHttpConnection::Type::HTTP1, version_)); + host2_upstream_.reset(new FakeUpstream(0, FakeHttpConnection::Type::HTTP1, version_)); } + // Sets up a connection between Envoy and the management server. void waitForHdsStream() { - fake_hds_connection_ = hds_upstream_->waitForHttpConnection(*dispatcher_); - hds_stream_ = fake_hds_connection_->waitForNewStream(*dispatcher_); + AssertionResult result = + hds_upstream_->waitForHttpConnection(*dispatcher_, hds_fake_connection_); + RELEASE_ASSERT(result, result.message()); + result = hds_fake_connection_->waitForNewStream(*dispatcher_, hds_stream_); + RELEASE_ASSERT(result, result.message()); } - void requestHealthCheckSpecifier() { - envoy::service::discovery::v2::HealthCheckSpecifier server_health_check_specifier; - server_health_check_specifier.mutable_interval()->set_nanos(500000000); // 500ms + // Envoy sends healthcheck messages to the endpoints + void healthcheckEndpoints(std::string cluster2 = "") { + ASSERT_TRUE(host_upstream_->waitForHttpConnection(*dispatcher_, host_fake_connection_)); + ASSERT_TRUE(host_fake_connection_->waitForNewStream(*dispatcher_, host_stream_)); + ASSERT_TRUE(host_stream_->waitForEndStream(*dispatcher_)); - hds_stream_->sendGrpcMessage(server_health_check_specifier); - // Wait until the request has been received by Envoy. - test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); - } + EXPECT_STREQ(host_stream_->headers().Path()->value().c_str(), "/healthcheck"); + EXPECT_STREQ(host_stream_->headers().Method()->value().c_str(), "GET"); + EXPECT_STREQ(host_stream_->headers().Host()->value().c_str(), "anna"); - void cleanupUpstreamConnection() { - codec_client_->close(); - if (fake_upstream_connection_ != nullptr) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + if (cluster2 != "") { + ASSERT_TRUE(host2_upstream_->waitForHttpConnection(*dispatcher_, host2_fake_connection_)); + ASSERT_TRUE(host2_fake_connection_->waitForNewStream(*dispatcher_, host2_stream_)); + ASSERT_TRUE(host2_stream_->waitForEndStream(*dispatcher_)); + + EXPECT_STREQ(host2_stream_->headers().Path()->value().c_str(), "/healthcheck"); + EXPECT_STREQ(host2_stream_->headers().Method()->value().c_str(), "GET"); + EXPECT_STREQ(host2_stream_->headers().Host()->value().c_str(), cluster2.c_str()); } } + // Clean up the connection between Envoy and the management server void cleanupHdsConnection() { - if (fake_hds_connection_ != nullptr) { - fake_hds_connection_->close(); - fake_hds_connection_->waitForDisconnect(); + if (hds_fake_connection_ != nullptr) { + AssertionResult result = hds_fake_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = hds_fake_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } - static constexpr uint32_t upstream_endpoints_ = 5; + // Clean up connections between Envoy and endpoints + void cleanupHostConnections() { + if (host_fake_connection_ != nullptr) { + AssertionResult result = host_fake_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = host_fake_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); + } + if (host2_fake_connection_ != nullptr) { + AssertionResult result = host2_fake_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = host2_fake_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); + } + } + + // Creates a basic HealthCheckSpecifier message containing one endpoint and + // one health_check + envoy::service::discovery::v2::HealthCheckSpecifier makeHealthCheckSpecifier() { + envoy::service::discovery::v2::HealthCheckSpecifier server_health_check_specifier_; + server_health_check_specifier_.mutable_interval()->set_seconds(1); + + auto* health_check = server_health_check_specifier_.add_health_check(); - IntegrationStreamDecoderPtr response_; - std::string sub_zone_{"winter"}; - FakeHttpConnectionPtr fake_hds_connection_; + health_check->set_cluster_name("anna"); + health_check->add_endpoints() + ->add_endpoints() + ->mutable_address() + ->mutable_socket_address() + ->set_address(host_upstream_->localAddress()->ip()->addressAsString()); + health_check->mutable_endpoints(0) + ->mutable_endpoints(0) + ->mutable_address() + ->mutable_socket_address() + ->set_port_value(host_upstream_->localAddress()->ip()->port()); + health_check->mutable_endpoints(0)->mutable_locality()->set_region("some_region"); + health_check->mutable_endpoints(0)->mutable_locality()->set_zone("some_zone"); + health_check->mutable_endpoints(0)->mutable_locality()->set_sub_zone("crete"); + + health_check->add_health_checks()->mutable_timeout()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_interval()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_unhealthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_healthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_grpc_health_check(); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_use_http2(false); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_path("/healthcheck"); + + return server_health_check_specifier_; + } + + // Checks if Envoy reported the health status of an endpoint correctly + void checkEndpointHealthResponse(envoy::service::discovery::v2::EndpointHealth endpoint, + envoy::api::v2::core::HealthStatus healthy, + Network::Address::InstanceConstSharedPtr address) { + + EXPECT_EQ(healthy, endpoint.health_status()); + EXPECT_EQ(address->ip()->port(), endpoint.endpoint().address().socket_address().port_value()); + EXPECT_EQ(address->ip()->addressAsString(), + endpoint.endpoint().address().socket_address().address()); + } + + // Checks if the cluster counters are correct + void checkCounters(int requests, int response_s, int successes, int failures) { + EXPECT_EQ(requests, test_server_->counter("hds_delegate.requests")->value()); + EXPECT_EQ(response_s, test_server_->counter("hds_delegate.responses")->value()); + EXPECT_EQ(successes, test_server_->counter("cluster.anna.health_check.success")->value()); + EXPECT_EQ(failures, test_server_->counter("cluster.anna.health_check.failure")->value()); + } + + static constexpr uint32_t upstream_endpoints_ = 0; + + FakeHttpConnectionPtr hds_fake_connection_; FakeStreamPtr hds_stream_; FakeUpstream* hds_upstream_{}; - FakeUpstream* service_upstream_[upstream_endpoints_]{}; uint32_t hds_requests_{}; - EdsHelper eds_helper_; + FakeUpstreamPtr host_upstream_{}; + FakeUpstreamPtr host2_upstream_{}; + FakeStreamPtr host_stream_; + FakeStreamPtr host2_stream_; + FakeHttpConnectionPtr host_fake_connection_; + FakeHttpConnectionPtr host2_fake_connection_; + + envoy::service::discovery::v2::HealthCheckRequest envoy_msg_; + envoy::service::discovery::v2::HealthCheckRequestOrEndpointHealthResponse response_; + envoy::service::discovery::v2::HealthCheckSpecifier server_health_check_specifier_; }; INSTANTIATE_TEST_CASE_P(IpVersions, HdsIntegrationTest, testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), TestUtility::ipTestParamsToString); -// Test connectivity of Envoy and the Server -TEST_P(HdsIntegrationTest, Simple) { +// Tests Envoy healthchecking a single healthy endpoint and reporting that it is +// indeed healthy to the server. +TEST_P(HdsIntegrationTest, SingleEndpointHealthy) { + initialize(); + + // Server <--> Envoy + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); + + // Server asks for healthchecking + server_health_check_specifier_ = makeHealthCheckSpecifier(); + hds_stream_->startGrpcStream(); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); + test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); + + // Envoy sends a healthcheck message to an endpoint + healthcheckEndpoints(); + + // Endpoint responds to the healthcheck + host_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + host_stream_->encodeData(1024, true); + + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); + + // Check that the response_ is correct + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::HEALTHY, + host_upstream_->localAddress()); + checkCounters(1, 2, 1, 0); + + // Clean up connections + cleanupHostConnections(); + cleanupHdsConnection(); +} + +// Tests Envoy healthchecking a single endpoint that times out and reporting +// that it is unhealthy to the server. +TEST_P(HdsIntegrationTest, SingleEndpointTimeout) { + initialize(); + server_health_check_specifier_ = makeHealthCheckSpecifier(); + + // Server <--> Envoy + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); + + // Server asks for healthchecking + hds_stream_->startGrpcStream(); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); + test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); + + // Envoy sends a healthcheck message to an endpoint + healthcheckEndpoints(); + + // Endpoint doesn't repond to the healthcheck + + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); + + // Check that the response_ is correct + // TODO(lilika): Ideally this would be envoy::api::v2::core::HealthStatus::TIMEOUT + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::UNHEALTHY, + host_upstream_->localAddress()); + checkCounters(1, 2, 0, 1); + + // Clean up connections + cleanupHostConnections(); + cleanupHdsConnection(); +} + +// Tests Envoy healthchecking a single unhealthy endpoint and reporting that it is +// indeed unhealthy to the server. +TEST_P(HdsIntegrationTest, SingleEndpointUnhealthy) { initialize(); - envoy::service::discovery::v2::HealthCheckRequest envoy_msg; - envoy::service::discovery::v2::HealthCheckRequest envoy_msg_2; - envoy::service::discovery::v2::HealthCheckSpecifier server_health_check_specifier; - server_health_check_specifier.mutable_interval()->set_nanos(500000000); // 500ms + server_health_check_specifier_ = makeHealthCheckSpecifier(); // Server <--> Envoy - fake_hds_connection_ = hds_upstream_->waitForHttpConnection(*dispatcher_); - hds_stream_ = fake_hds_connection_->waitForNewStream(*dispatcher_); - hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg); + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); - EXPECT_EQ(0, test_server_->counter("hds_delegate.requests")->value()); - EXPECT_EQ(1, test_server_->counter("hds_delegate.responses")->value()); + // Server asks for healthchecking + hds_stream_->startGrpcStream(); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); + test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); + + // Envoy sends a healthcheck message to an endpoint + healthcheckEndpoints(); + + // Endpoint responds to the healthcheck + host_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "404"}}, false); + host_stream_->encodeData(1024, true); - // Send a message to Envoy, and wait until it's received + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); + + // Check that the response_ is correct + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::UNHEALTHY, + host_upstream_->localAddress()); + checkCounters(1, 2, 0, 1); + + // Clean up connections + cleanupHostConnections(); + cleanupHdsConnection(); +} + +// Tests that Envoy can healthcheck two hosts that are in the same cluster, and +// the same locality and report back the correct health statuses. +TEST_P(HdsIntegrationTest, TwoEndpointsSameLocality) { + initialize(); + + server_health_check_specifier_ = makeHealthCheckSpecifier(); + auto* endpoint = server_health_check_specifier_.mutable_health_check(0)->mutable_endpoints(0); + endpoint->add_endpoints()->mutable_address()->mutable_socket_address()->set_address( + host2_upstream_->localAddress()->ip()->addressAsString()); + endpoint->mutable_endpoints(1)->mutable_address()->mutable_socket_address()->set_port_value( + host2_upstream_->localAddress()->ip()->port()); + + // Server <--> Envoy + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); + + // Server asks for healthchecking hds_stream_->startGrpcStream(); - hds_stream_->sendGrpcMessage(server_health_check_specifier); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); - // Wait for Envoy to reply - hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_2); + healthcheckEndpoints("anna"); + + // Endpoints repond to the healthcheck + host_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "404"}}, false); + host_stream_->encodeData(1024, true); + host2_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + host2_stream_->encodeData(1024, true); + + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); + + // Check that the response_ is correct + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::UNHEALTHY, + host_upstream_->localAddress()); + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(1), + envoy::api::v2::core::HealthStatus::HEALTHY, + host2_upstream_->localAddress()); + checkCounters(1, 2, 1, 1); + + // Clean up connections + cleanupHostConnections(); + cleanupHdsConnection(); +} + +// Tests that Envoy can healthcheck two hosts that are in the same cluster, and +// different localities and report back the correct health statuses. +TEST_P(HdsIntegrationTest, TwoEndpointsDifferentLocality) { + initialize(); + server_health_check_specifier_ = makeHealthCheckSpecifier(); + + // Add endpoint + auto* health_check = server_health_check_specifier_.mutable_health_check(0); + + health_check->add_endpoints() + ->add_endpoints() + ->mutable_address() + ->mutable_socket_address() + ->set_address(host2_upstream_->localAddress()->ip()->addressAsString()); + health_check->mutable_endpoints(1) + ->mutable_endpoints(0) + ->mutable_address() + ->mutable_socket_address() + ->set_port_value(host2_upstream_->localAddress()->ip()->port()); + health_check->mutable_endpoints(1)->mutable_locality()->set_region("different_region"); + health_check->mutable_endpoints(1)->mutable_locality()->set_zone("different_zone"); + health_check->mutable_endpoints(1)->mutable_locality()->set_sub_zone("emplisi"); + + // Server <--> Envoy + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); + + // Server asks for healthchecking + hds_stream_->startGrpcStream(); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); + test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); + + // Envoy sends healthcheck messages to two endpoints + healthcheckEndpoints("anna"); + + // Endpoint responds to the healthcheck + host_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "404"}}, false); + host_stream_->encodeData(1024, true); + host2_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + host2_stream_->encodeData(1024, true); + + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); + + // Check that the response_ is correct + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::UNHEALTHY, + host_upstream_->localAddress()); + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(1), + envoy::api::v2::core::HealthStatus::HEALTHY, + host2_upstream_->localAddress()); + checkCounters(1, 2, 1, 1); + + // Clean up connections + cleanupHostConnections(); + cleanupHdsConnection(); +} + +// Tests that Envoy can healthcheck two hosts that are in different clusters, and +// report back the correct health statuses. +TEST_P(HdsIntegrationTest, TwoEndpointsDifferentClusters) { + initialize(); + server_health_check_specifier_ = makeHealthCheckSpecifier(); + + // Add endpoint + auto* health_check = server_health_check_specifier_.add_health_check(); + + health_check->set_cluster_name("cat"); + health_check->add_endpoints() + ->add_endpoints() + ->mutable_address() + ->mutable_socket_address() + ->set_address(host2_upstream_->localAddress()->ip()->addressAsString()); + health_check->mutable_endpoints(0) + ->mutable_endpoints(0) + ->mutable_address() + ->mutable_socket_address() + ->set_port_value(host2_upstream_->localAddress()->ip()->port()); + health_check->mutable_endpoints(0)->mutable_locality()->set_region("peculiar_region"); + health_check->mutable_endpoints(0)->mutable_locality()->set_zone("peculiar_zone"); + health_check->mutable_endpoints(0)->mutable_locality()->set_sub_zone("paris"); + + health_check->add_health_checks()->mutable_timeout()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_interval()->set_seconds(1); + health_check->mutable_health_checks(0)->mutable_unhealthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_healthy_threshold()->set_value(2); + health_check->mutable_health_checks(0)->mutable_grpc_health_check(); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_use_http2(false); + health_check->mutable_health_checks(0)->mutable_http_health_check()->set_path("/healthcheck"); + + // Server <--> Envoy + waitForHdsStream(); + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, envoy_msg_)); + + // Server asks for healthchecking + hds_stream_->startGrpcStream(); + hds_stream_->sendGrpcMessage(server_health_check_specifier_); + test_server_->waitForCounterGe("hds_delegate.requests", ++hds_requests_); + + // Envoy sends healthcheck messages to two endpoints + healthcheckEndpoints("cat"); + + // Endpoint responds to the healthcheck + host_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "404"}}, false); + host_stream_->encodeData(1024, true); + host2_stream_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + host2_stream_->encodeData(1024, true); + + // Envoy reports back to server + ASSERT_TRUE(hds_stream_->waitForGrpcMessage(*dispatcher_, response_)); - EXPECT_EQ(1, test_server_->counter("hds_delegate.requests")->value()); - EXPECT_EQ(2, test_server_->counter("hds_delegate.responses")->value()); + // Check that the response_ is correct + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(0), + envoy::api::v2::core::HealthStatus::UNHEALTHY, + host_upstream_->localAddress()); + checkEndpointHealthResponse(response_.endpoint_health_response().endpoints_health(1), + envoy::api::v2::core::HealthStatus::HEALTHY, + host2_upstream_->localAddress()); + checkCounters(1, 2, 0, 1); + EXPECT_EQ(1, test_server_->counter("cluster.cat.health_check.success")->value()); + EXPECT_EQ(0, test_server_->counter("cluster.cat.health_check.failure")->value()); + // Clean up connections + cleanupHostConnections(); cleanupHdsConnection(); } diff --git a/test/integration/header_integration_test.cc b/test/integration/header_integration_test.cc index 606a8cc287efc..e25e7b79a4f62 100644 --- a/test/integration/header_integration_test.cc +++ b/test/integration/header_integration_test.cc @@ -140,10 +140,15 @@ class HeaderIntegrationTest void TearDown() override { if (eds_connection_ != nullptr) { - eds_connection_->close(); - eds_connection_->waitForDisconnect(); + // Don't ASSERT fail if an EDS reconnect ends up unparented. + fake_upstreams_[1]->set_allow_unexpected_disconnects(true); + AssertionResult result = eds_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = eds_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); eds_connection_.reset(); } + cleanupUpstreamAndDownstream(); test_server_.reset(); fake_upstream_connection_.reset(); fake_upstreams_.clear(); @@ -306,12 +311,16 @@ class HeaderIntegrationTest void initialize() override { if (use_eds_) { pre_worker_start_test_steps_ = [this]() { - eds_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - eds_stream_ = eds_connection_->waitForNewStream(*dispatcher_); + AssertionResult result = + fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, eds_connection_); + RELEASE_ASSERT(result, result.message()); + result = eds_connection_->waitForNewStream(*dispatcher_, eds_stream_); + RELEASE_ASSERT(result, result.message()); eds_stream_->startGrpcStream(); envoy::api::v2::DiscoveryRequest discovery_request; - eds_stream_->waitForGrpcMessage(*dispatcher_, discovery_request); + result = eds_stream_->waitForGrpcMessage(*dispatcher_, discovery_request); + RELEASE_ASSERT(result, result.message()); envoy::api::v2::DiscoveryResponse discovery_response; discovery_response.set_version_info("1"); @@ -340,7 +349,8 @@ class HeaderIntegrationTest eds_stream_->sendGrpcMessage(discovery_response); // Wait for the next request to make sure the first response was consumed. - eds_stream_->waitForGrpcMessage(*dispatcher_, discovery_request); + result = eds_stream_->waitForGrpcMessage(*dispatcher_, discovery_request); + RELEASE_ASSERT(result, result.message()); }; } @@ -353,13 +363,8 @@ class HeaderIntegrationTest Http::TestHeaderMapImpl&& response_headers, Http::TestHeaderMapImpl&& expected_response_headers) { registerTestServerPorts({"http"}); - codec_client_ = makeHttpConnection(makeClientConnection(lookupPort("http"))); - auto response = codec_client_->makeHeaderOnlyRequest(request_headers); - waitForNextUpstreamRequest(); - - upstream_request_->encodeHeaders(response_headers, true); - response->waitForEndStream(); + auto response = sendRequestAndWaitForResponse(request_headers, 0, response_headers, 0); compareHeaders(upstream_request_->headers(), expected_request_headers); compareHeaders(response->headers(), expected_response_headers); diff --git a/test/integration/hotrestart_test.sh b/test/integration/hotrestart_test.sh index 09cbe85caec4d..09430a869657d 100755 --- a/test/integration/hotrestart_test.sh +++ b/test/integration/hotrestart_test.sh @@ -90,24 +90,26 @@ JSON_TEST_ARRAY=() # Parameterize IPv4 and IPv6 testing. if [[ -z "${ENVOY_IP_TEST_VERSIONS}" ]] || [[ "${ENVOY_IP_TEST_VERSIONS}" == "all" ]] \ || [[ "${ENVOY_IP_TEST_VERSIONS}" == "v4only" ]]; then - HOT_RESTART_JSON_V4="${TEST_TMPDIR}"/hot_restart_v4.json + HOT_RESTART_JSON_V4="${TEST_TMPDIR}"/hot_restart_v4.yaml echo building ${HOT_RESTART_JSON_V4} ... - cat "${TEST_RUNDIR}"/test/config/integration/server.json | + cat "${TEST_RUNDIR}"/test/config/integration/server.yaml | sed -e "s#{{ upstream_. }}#0#g" | \ sed -e "s#{{ test_rundir }}#$TEST_RUNDIR#" | \ + sed -e "s#{{ test_tmpdir }}#$TEST_TMPDIR#" | \ sed -e "s#{{ ip_loopback_address }}#127.0.0.1#" | \ - sed -e "s#{{ dns_lookup_family }}#v4_only#" | \ + sed -e "s#{{ dns_lookup_family }}#V4_ONLY#" | \ cat > "${HOT_RESTART_JSON_V4}" JSON_TEST_ARRAY+=("${HOT_RESTART_JSON_V4}") fi if [[ -z "${ENVOY_IP_TEST_VERSIONS}" ]] || [[ "${ENVOY_IP_TEST_VERSIONS}" == "all" ]] \ || [[ "${ENVOY_IP_TEST_VERSIONS}" == "v6only" ]]; then - HOT_RESTART_JSON_V6="${TEST_TMPDIR}"/hot_restart_v6.json - cat "${TEST_RUNDIR}"/test/config/integration/server.json | + HOT_RESTART_JSON_V6="${TEST_TMPDIR}"/hot_restart_v6.yaml + cat "${TEST_RUNDIR}"/test/config/integration/server.yaml | sed -e "s#{{ upstream_. }}#0#g" | \ sed -e "s#{{ test_rundir }}#$TEST_RUNDIR#" | \ - sed -e "s#{{ ip_loopback_address }}#[::1]#" | \ + sed -e "s#{{ test_tmpdir }}#$TEST_TMPDIR#" | \ + sed -e "s#{{ ip_loopback_address }}#::1#" | \ sed -e "s#{{ dns_lookup_family }}#v6_only#" | \ cat > "${HOT_RESTART_JSON_V6}" JSON_TEST_ARRAY+=("${HOT_RESTART_JSON_V6}") @@ -115,9 +117,9 @@ fi # Also test for listening on UNIX domain sockets. We use IPv4 for the # upstreams to avoid too much wild sedding. -HOT_RESTART_JSON_UDS="${TEST_TMPDIR}"/hot_restart_uds.json +HOT_RESTART_JSON_UDS="${TEST_TMPDIR}"/hot_restart_uds.yaml SOCKET_DIR="$(mktemp -d /tmp/envoy_test_hotrestart.XXXXXX)" -cat "${TEST_RUNDIR}"/test/config/integration/server_unix_listener.json | +cat "${TEST_RUNDIR}"/test/config/integration/server_unix_listener.yaml | sed -e "s#{{ socket_dir }}#${SOCKET_DIR}#" | \ sed -e "s#{{ ip_loopback_address }}#127.0.0.1#" | \ cat > "${HOT_RESTART_JSON_UDS}" @@ -150,9 +152,9 @@ do FIRST_SERVER_PID=$BACKGROUND_PID - start_test Updating original config json listener addresses + start_test Updating original config listener addresses sleep 3 - UPDATED_HOT_RESTART_JSON="${TEST_TMPDIR}"/hot_restart_updated."${TEST_INDEX}".json + UPDATED_HOT_RESTART_JSON="${TEST_TMPDIR}"/hot_restart_updated."${TEST_INDEX}".yaml "${TEST_RUNDIR}"/tools/socket_passing "-o" "${HOT_RESTART_JSON}" "-a" "${ADMIN_ADDRESS_PATH_0}" \ "-u" "${UPDATED_HOT_RESTART_JSON}" @@ -165,6 +167,19 @@ do disableHeapCheck + # To ensure that we don't accidentally change the /hot_restart_version + # string, compare it against a hard-coded string. + start_test Checking for consistency of /hot_restart_version + CLI_HOT_RESTART_VERSION=$("${ENVOY_BIN}" --hot-restart-version --base-id "${BASE_ID}" 2>&1) + EXPECTED_CLI_HOT_RESTART_VERSION="10.200.16384.127.options=capacity=16384, num_slots=8209 hash=228984379728933363 size=2654312" + check [ "${CLI_HOT_RESTART_VERSION}" = "${EXPECTED_CLI_HOT_RESTART_VERSION}" ] + + start_test Checking for consistency of /hot_restart_version with --max-obj-name-len 500 + CLI_HOT_RESTART_VERSION=$("${ENVOY_BIN}" --hot-restart-version --base-id "${BASE_ID}" \ + --max-obj-name-len 500 2>&1) + EXPECTED_CLI_HOT_RESTART_VERSION="10.200.16384.567.options=capacity=16384, num_slots=8209 hash=228984379728933363 size=9863272" + check [ "${CLI_HOT_RESTART_VERSION}" = "${EXPECTED_CLI_HOT_RESTART_VERSION}" ] + start_test Checking for match of --hot-restart-version and admin /hot_restart_version ADMIN_ADDRESS_0=$(cat "${ADMIN_ADDRESS_PATH_0}") echo fetching hot restart version from http://${ADMIN_ADDRESS_0}/hot_restart_version ... @@ -197,7 +212,7 @@ do sleep 7 start_test Checking that listener addresses have not changed - HOT_RESTART_JSON_1="${TEST_TMPDIR}"/hot_restart.1."${TEST_INDEX}".json + HOT_RESTART_JSON_1="${TEST_TMPDIR}"/hot_restart.1."${TEST_INDEX}".yaml "${TEST_RUNDIR}"/tools/socket_passing "-o" "${UPDATED_HOT_RESTART_JSON}" "-a" "${ADMIN_ADDRESS_PATH_1}" \ "-u" "${HOT_RESTART_JSON_1}" CONFIG_DIFF=$(diff "${UPDATED_HOT_RESTART_JSON}" "${HOT_RESTART_JSON_1}") @@ -213,7 +228,7 @@ do sleep 3 start_test Checking that listener addresses have not changed - HOT_RESTART_JSON_2="${TEST_TMPDIR}"/hot_restart.2."${TEST_INDEX}".json + HOT_RESTART_JSON_2="${TEST_TMPDIR}"/hot_restart.2."${TEST_INDEX}".yaml "${TEST_RUNDIR}"/tools/socket_passing "-o" "${UPDATED_HOT_RESTART_JSON}" "-a" "${ADMIN_ADDRESS_PATH_2}" \ "-u" "${HOT_RESTART_JSON_2}" CONFIG_DIFF=$(diff "${UPDATED_HOT_RESTART_JSON}" "${HOT_RESTART_JSON_2}") diff --git a/test/integration/http2_integration_test.cc b/test/integration/http2_integration_test.cc index 2770fb12af805..35359c8b6ab73 100644 --- a/test/integration/http2_integration_test.cc +++ b/test/integration/http2_integration_test.cc @@ -131,7 +131,7 @@ TEST_P(Http2IntegrationTest, BadMagic) { RawConnectionDriver connection( lookupPort("http"), buffer, [&](Network::ClientConnection&, const Buffer::Instance& data) -> void { - response.append(TestUtility::bufferToString(data)); + response.append(data.toString()); }, version_); @@ -146,7 +146,7 @@ TEST_P(Http2IntegrationTest, BadFrame) { RawConnectionDriver connection( lookupPort("http"), buffer, [&](Network::ClientConnection&, const Buffer::Instance& data) -> void { - response.append(TestUtility::bufferToString(data)); + response.append(data.toString()); }, version_); @@ -253,8 +253,8 @@ TEST_P(Http2IntegrationTest, IdleTimeoutWithSimultaneousRequests) { encoder1 = &encoder_decoder.first; auto response1 = std::move(encoder_decoder.second); - fake_upstream_connection1 = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request1 = fake_upstream_connection1->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection1)); + ASSERT_TRUE(fake_upstream_connection1->waitForNewStream(*dispatcher_, upstream_request1)); // Start request 2 auto encoder_decoder2 = @@ -264,16 +264,16 @@ TEST_P(Http2IntegrationTest, IdleTimeoutWithSimultaneousRequests) { {":authority", "host"}}); encoder2 = &encoder_decoder2.first; auto response2 = std::move(encoder_decoder2.second); - fake_upstream_connection2 = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request2 = fake_upstream_connection2->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection2)); + ASSERT_TRUE(fake_upstream_connection2->waitForNewStream(*dispatcher_, upstream_request2)); // Finish request 1 codec_client_->sendData(*encoder1, request1_bytes, true); - upstream_request1->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request1->waitForEndStream(*dispatcher_)); // Finish request i2 codec_client_->sendData(*encoder2, request2_bytes, true); - upstream_request2->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request2->waitForEndStream(*dispatcher_)); // Respond to request 2 upstream_request2->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -300,8 +300,8 @@ TEST_P(Http2IntegrationTest, IdleTimeoutWithSimultaneousRequests) { EXPECT_EQ(request1_bytes, response1->body().size()); // Do not send any requests and validate idle timeout kicks in after both the requests are done. - fake_upstream_connection1->waitForDisconnect(); - fake_upstream_connection2->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection1->waitForDisconnect()); + ASSERT_TRUE(fake_upstream_connection2->waitForDisconnect()); test_server_->waitForCounterGe("cluster.cluster_0.upstream_cx_idle_timeout", 2); } @@ -325,8 +325,8 @@ void Http2IntegrationTest::simultaneousRequest(int32_t request1_bytes, int32_t r encoder1 = &encoder_decoder.first; auto response1 = std::move(encoder_decoder.second); - fake_upstream_connection1 = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request1 = fake_upstream_connection1->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection1)); + ASSERT_TRUE(fake_upstream_connection1->waitForNewStream(*dispatcher_, upstream_request1)); // Start request 2 auto encoder_decoder2 = @@ -336,16 +336,16 @@ void Http2IntegrationTest::simultaneousRequest(int32_t request1_bytes, int32_t r {":authority", "host"}}); encoder2 = &encoder_decoder2.first; auto response2 = std::move(encoder_decoder2.second); - fake_upstream_connection2 = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request2 = fake_upstream_connection2->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection2)); + ASSERT_TRUE(fake_upstream_connection2->waitForNewStream(*dispatcher_, upstream_request2)); // Finish request 1 codec_client_->sendData(*encoder1, request1_bytes, true); - upstream_request1->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request1->waitForEndStream(*dispatcher_)); // Finish request 2 codec_client_->sendData(*encoder2, request2_bytes, true); - upstream_request2->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request2->waitForEndStream(*dispatcher_)); // Respond to request 2 upstream_request2->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -368,11 +368,11 @@ void Http2IntegrationTest::simultaneousRequest(int32_t request1_bytes, int32_t r EXPECT_EQ(request2_bytes, response1->body().size()); // Cleanup both downstream and upstream + ASSERT_TRUE(fake_upstream_connection1->close()); + ASSERT_TRUE(fake_upstream_connection1->waitForDisconnect()); + ASSERT_TRUE(fake_upstream_connection2->close()); + ASSERT_TRUE(fake_upstream_connection2->waitForDisconnect()); codec_client_->close(); - fake_upstream_connection1->close(); - fake_upstream_connection1->waitForDisconnect(); - fake_upstream_connection2->close(); - fake_upstream_connection2->waitForDisconnect(); } TEST_P(Http2IntegrationTest, SimultaneousRequest) { simultaneousRequest(1024, 512); } @@ -400,8 +400,10 @@ Http2RingHashIntegrationTest::~Http2RingHashIntegrationTest() { codec_client_ = nullptr; } for (auto it = fake_upstream_connections_.begin(); it != fake_upstream_connections_.end(); ++it) { - (*it)->close(); - (*it)->waitForDisconnect(); + AssertionResult result = (*it)->close(); + RELEASE_ASSERT(result, result.message()); + result = (*it)->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } @@ -435,17 +437,20 @@ void Http2RingHashIntegrationTest::sendMultipleRequests( } for (uint32_t i = 0; i < num_requests; ++i) { - auto fake_upstream_connection = - FakeUpstream::waitForHttpConnection(*dispatcher_, fake_upstreams_); + FakeHttpConnectionPtr fake_upstream_connection; + ASSERT_TRUE(FakeUpstream::waitForHttpConnection(*dispatcher_, fake_upstreams_, + fake_upstream_connection)); // As data and streams are interwoven, make sure waitForNewStream() // ignores incoming data and waits for actual stream establishment. - upstream_requests.push_back(fake_upstream_connection->waitForNewStream(*dispatcher_, true)); + upstream_requests.emplace_back(); + ASSERT_TRUE( + fake_upstream_connection->waitForNewStream(*dispatcher_, upstream_requests.back(), true)); upstream_requests.back()->setAddServedByHeader(true); fake_upstream_connections_.push_back(std::move(fake_upstream_connection)); } for (uint32_t i = 0; i < num_requests; ++i) { - upstream_requests[i]->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_requests[i]->waitForEndStream(*dispatcher_)); upstream_requests[i]->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); upstream_requests[i]->encodeData(rand.random() % (1024 * 2), true); } diff --git a/test/integration/http2_upstream_integration_test.cc b/test/integration/http2_upstream_integration_test.cc index 2d4b3be30a52e..d1401a5f0baac 100644 --- a/test/integration/http2_upstream_integration_test.cc +++ b/test/integration/http2_upstream_integration_test.cc @@ -97,12 +97,12 @@ void Http2UpstreamIntegrationTest::bidirectionalStreaming(uint32_t bytes) { {":authority", "host"}}); auto response = std::move(encoder_decoder.second); request_encoder_ = &encoder_decoder.first; - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); // Send part of the request body and ensure it is received upstream. codec_client_->sendData(*request_encoder_, bytes, false); - upstream_request_->waitForData(*dispatcher_, bytes); + ASSERT_TRUE(upstream_request_->waitForData(*dispatcher_, bytes)); // Start sending the response and ensure it is received downstream. upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -111,7 +111,7 @@ void Http2UpstreamIntegrationTest::bidirectionalStreaming(uint32_t bytes) { // Finish the request. codec_client_->sendTrailers(*request_encoder_, Http::TestHeaderMapImpl{{"trailer", "foo"}}); - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); // Finish the response. upstream_request_->encodeTrailers(Http::TestHeaderMapImpl{{"trailer", "bar"}}); @@ -138,12 +138,12 @@ TEST_P(Http2UpstreamIntegrationTest, BidirectionalStreamingReset) { {":authority", "host"}}); auto response = std::move(encoder_decoder.second); request_encoder_ = &encoder_decoder.first; - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); // Send some request data. codec_client_->sendData(*request_encoder_, 1024, false); - upstream_request_->waitForData(*dispatcher_, 1024); + ASSERT_TRUE(upstream_request_->waitForData(*dispatcher_, 1024)); // Start sending the response. upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -152,7 +152,7 @@ TEST_P(Http2UpstreamIntegrationTest, BidirectionalStreamingReset) { // Finish sending therequest. codec_client_->sendTrailers(*request_encoder_, Http::TestHeaderMapImpl{{"trailer", "foo"}}); - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); // Reset the stream. upstream_request_->encodeResetStream(); @@ -177,8 +177,8 @@ void Http2UpstreamIntegrationTest::simultaneousRequest(uint32_t request1_bytes, {":authority", "host"}}); Http::StreamEncoder* encoder1 = &encoder_decoder1.first; auto response1 = std::move(encoder_decoder1.second); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request1 = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request1)); // Start request 2 auto encoder_decoder2 = @@ -188,15 +188,17 @@ void Http2UpstreamIntegrationTest::simultaneousRequest(uint32_t request1_bytes, {":authority", "host"}}); Http::StreamEncoder* encoder2 = &encoder_decoder2.first; auto response2 = std::move(encoder_decoder2.second); - upstream_request2 = fake_upstream_connection_->waitForNewStream(*dispatcher_); + // DO NOT SUBMIT replace other ASSERT_TRUE with ASSERT (?) + // (in places like this, ASSERT_TRUE is probably fine) + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request2)); // Finish request 1 codec_client_->sendData(*encoder1, request1_bytes, true); - upstream_request1->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request1->waitForEndStream(*dispatcher_)); // Finish request 2 codec_client_->sendData(*encoder2, request2_bytes, true); - upstream_request2->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request2->waitForEndStream(*dispatcher_)); // Respond to request 2 upstream_request2->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -302,20 +304,23 @@ TEST_P(Http2UpstreamIntegrationTest, UpstreamConnectionCloseWithManyStreams) { codec_client_->sendData(*encoders[i], 0, true); } } - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); for (uint32_t i = 0; i < num_requests; ++i) { - upstream_requests.push_back(fake_upstream_connection_->waitForNewStream(*dispatcher_)); + FakeStreamPtr stream; + upstream_requests.emplace_back(); + ASSERT_TRUE( + fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_requests.back())); } for (uint32_t i = 0; i < num_requests; ++i) { if (i % 15 != 0) { - upstream_requests[i]->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_requests[i]->waitForEndStream(*dispatcher_)); upstream_requests[i]->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); upstream_requests[i]->encodeData(100, false); } } // Close the connection. - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); // Ensure the streams are all reset successfully. for (uint32_t i = 0; i < num_requests; ++i) { if (i % 15 != 0) { diff --git a/test/integration/http_integration.cc b/test/integration/http_integration.cc index 46dc6d6e64b96..3f93b90aed7e4 100644 --- a/test/integration/http_integration.cc +++ b/test/integration/http_integration.cc @@ -65,7 +65,7 @@ typeToCodecType(Http::CodecClient::Type type) { return envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager:: HTTP2; default: - RELEASE_ASSERT(0); + RELEASE_ASSERT(0, ""); } } @@ -192,8 +192,9 @@ void HttpIntegrationTest::setDownstreamProtocol(Http::CodecClient::Type downstre } IntegrationStreamDecoderPtr HttpIntegrationTest::sendRequestAndWaitForResponse( - Http::TestHeaderMapImpl& request_headers, uint32_t request_body_size, - Http::TestHeaderMapImpl& response_headers, uint32_t response_size) { + const Http::TestHeaderMapImpl& request_headers, uint32_t request_body_size, + const Http::TestHeaderMapImpl& response_headers, uint32_t response_size) { + ASSERT(codec_client_ != nullptr); // Send the request to Envoy. IntegrationStreamDecoderPtr response; if (request_body_size) { @@ -214,25 +215,35 @@ IntegrationStreamDecoderPtr HttpIntegrationTest::sendRequestAndWaitForResponse( } void HttpIntegrationTest::cleanupUpstreamAndDownstream() { + // Close the upstream connection first. If there's an outstanding request, + // closing the client may result in a FIN being sent upstream, and FakeConnectionBase::close + // will interpret that as an unexpected disconnect. The codec client is not + // subject to the same failure mode. + if (fake_upstream_connection_) { + AssertionResult result = fake_upstream_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); + } if (codec_client_) { codec_client_->close(); } - if (fake_upstream_connection_) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); - } } void HttpIntegrationTest::waitForNextUpstreamRequest(uint64_t upstream_index) { // If there is no upstream connection, wait for it to be established. if (!fake_upstream_connection_) { - fake_upstream_connection_ = - fake_upstreams_[upstream_index]->waitForHttpConnection(*dispatcher_); + AssertionResult result = fake_upstreams_[upstream_index]->waitForHttpConnection( + *dispatcher_, fake_upstream_connection_); + RELEASE_ASSERT(result, result.message()); } // Wait for the next stream on the upstream connection. - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + AssertionResult result = + fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_); + RELEASE_ASSERT(result, result.message()); // Wait for the stream to be completely received. - upstream_request_->waitForEndStream(*dispatcher_); + result = upstream_request_->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); } void HttpIntegrationTest::testRouterRequestAndResponseWithBody( @@ -445,12 +456,12 @@ void HttpIntegrationTest::testRouterUpstreamDisconnectBeforeRequestComplete() { {":authority", "host"}}); auto response = std::move(encoder_decoder.second); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForHeadersComplete(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); + ASSERT_TRUE(upstream_request_->waitForHeadersComplete()); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); response->waitForEndStream(); if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { @@ -479,8 +490,8 @@ void HttpIntegrationTest::testRouterUpstreamDisconnectBeforeResponseComplete( {":authority", "host"}}); waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { codec_client_->waitForDisconnect(); @@ -509,17 +520,17 @@ void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeRequestComplete( {":scheme", "http"}, {":authority", "host"}}); auto response = std::move(encoder_decoder.second); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForHeadersComplete(); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); + ASSERT_TRUE(upstream_request_->waitForHeadersComplete()); codec_client_->close(); if (upstreamProtocol() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } else { - upstream_request_->waitForReset(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(upstream_request_->waitForReset()); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } EXPECT_FALSE(upstream_request_->complete()); @@ -545,11 +556,11 @@ void HttpIntegrationTest::testRouterDownstreamDisconnectBeforeResponseComplete( codec_client_->close(); if (upstreamProtocol() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } else { - upstream_request_->waitForReset(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(upstream_request_->waitForReset()); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } EXPECT_TRUE(upstream_request_->complete()); @@ -569,19 +580,19 @@ void HttpIntegrationTest::testRouterUpstreamResponseBeforeRequestComplete() { {":scheme", "http"}, {":authority", "host"}}); auto response = std::move(encoder_decoder.second); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForHeadersComplete(); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); + ASSERT_TRUE(upstream_request_->waitForHeadersComplete()); upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); upstream_request_->encodeData(512, true); response->waitForEndStream(); if (upstreamProtocol() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } else { - upstream_request_->waitForReset(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(upstream_request_->waitForReset()); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { @@ -613,10 +624,10 @@ void HttpIntegrationTest::testRetry() { upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "503"}}, false); if (fake_upstreams_[0]->httpType() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); } else { - upstream_request_->waitForReset(); + ASSERT_TRUE(upstream_request_->waitForReset()); } waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -664,10 +675,10 @@ void HttpIntegrationTest::testGrpcRetry() { upstream_request_->encodeHeaders( Http::TestHeaderMapImpl{{":status", "200"}, {"grpc-status", "1"}}, false); if (fake_upstreams_[0]->httpType() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); } else { - upstream_request_->waitForReset(); + ASSERT_TRUE(upstream_request_->waitForReset()); } waitForNextUpstreamRequest(); @@ -739,8 +750,16 @@ void HttpIntegrationTest::testHittingDecoderFilterLimit() { 1024 * 65); response->waitForEndStream(); - ASSERT_TRUE(response->complete()); - EXPECT_STREQ("413", response->headers().Status()->value().c_str()); + // With HTTP/1 there's a possible race where if the connection backs up early, + // the 413-and-connection-close may be sent while the body is still being + // sent, resulting in a write error and the connection being closed before the + // response is read. + if (downstream_protocol_ == Http::CodecClient::Type::HTTP2) { + ASSERT_TRUE(response->complete()); + } + if (response->complete()) { + EXPECT_STREQ("413", response->headers().Status()->value().c_str()); + } } // Test hitting the dynamo filter with too many response bytes to buffer. Given the request headers @@ -763,8 +782,11 @@ void HttpIntegrationTest::testHittingEncoderFilterLimit() { // Send the respone headers. upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); - // Now send an overly large response body. + // Now send an overly large response body. At some point, too much data will + // be buffered, the stream will be reset, and the connection will disconnect. + fake_upstreams_[0]->set_allow_unexpected_disconnects(true); upstream_request_->encodeData(1024 * 65, false); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); response->waitForEndStream(); EXPECT_TRUE(response->complete()); @@ -784,14 +806,14 @@ void HttpIntegrationTest::testEnvoyHandling100Continue(bool additional_continue_ {"expect", "100-continue"}}); request_encoder_ = &encoder_decoder.first; auto response = std::move(encoder_decoder.second); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); // The continue headers should arrive immediately. response->waitForContinueHeaders(); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); // Send the rest of the request. codec_client_->sendData(*request_encoder_, 10, true); - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); // Verify the Expect header is stripped. EXPECT_EQ(nullptr, upstream_request_->headers().get(Http::Headers::get().Expect)); if (via.empty()) { @@ -856,8 +878,8 @@ void HttpIntegrationTest::testEnvoyProxying100Continue(bool continue_before_upst auto response = std::move(encoder_decoder.second); // Wait for the request headers to be received upstream. - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); if (continue_before_upstream_complete) { // This case tests sending on 100-Continue headers before the client has sent all the @@ -867,7 +889,7 @@ void HttpIntegrationTest::testEnvoyProxying100Continue(bool continue_before_upst } // Send all of the request data and wait for it to be received upstream. codec_client_->sendData(*request_encoder_, 10, true); - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); if (!continue_before_upstream_complete) { // This case tests forwarding 100-Continue after the client has sent all data. @@ -915,7 +937,7 @@ void HttpIntegrationTest::testIdleTimeoutBasic() { test_server_->waitForCounterGe("cluster.cluster_0.upstream_rq_200", 1); // Do not send any requests and validate if idle time out kicks in. - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); test_server_->waitForCounterGe("cluster.cluster_0.upstream_cx_idle_timeout", 1); } @@ -969,7 +991,7 @@ void HttpIntegrationTest::testIdleTimeoutWithTwoRequests() { test_server_->waitForCounterGe("cluster.cluster_0.upstream_rq_200", 2); // Do not send any requests and validate if idle time out kicks in. - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); test_server_->waitForCounterGe("cluster.cluster_0.upstream_cx_idle_timeout", 1); } @@ -1000,7 +1022,7 @@ void HttpIntegrationTest::testUpstreamDisconnectWithTwoRequests() { // Response 1. upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); upstream_request_->encodeData(512, true); - fake_upstream_connection_->close(); + ASSERT_TRUE(fake_upstream_connection_->close()); response->waitForEndStream(); EXPECT_TRUE(upstream_request_->complete()); @@ -1010,7 +1032,7 @@ void HttpIntegrationTest::testUpstreamDisconnectWithTwoRequests() { test_server_->waitForCounterGe("cluster.cluster_0.upstream_rq_200", 1); // Response 2. - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); fake_upstream_connection_.reset(); waitForNextUpstreamRequest(); upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); @@ -1120,6 +1142,11 @@ void HttpIntegrationTest::testHttp10Enabled() { reinterpret_cast(fake_upstreams_.front().get())->lastRequestHeaders(); ASSERT_TRUE(upstream_headers.get() != nullptr); EXPECT_EQ(upstream_headers->Host()->value(), "default.com"); + + sendRawHttpAndWaitForResponse(lookupPort("http"), "HEAD / HTTP/1.0\r\n\r\n", &response, false); + EXPECT_THAT(response, HasSubstr("HTTP/1.0 200 OK\r\n")); + EXPECT_THAT(response, HasSubstr("connection: close")); + EXPECT_THAT(response, Not(HasSubstr("transfer-encoding: chunked\r\n"))); } // Verify for HTTP/1.0 a keep-alive header results in no connection: close. @@ -1384,12 +1411,14 @@ void HttpIntegrationTest::testUpstreamProtocolError() { {":method", "GET"}, {":path", "/test/long/url"}, {":authority", "host"}}); auto response = std::move(encoder_decoder.second); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); // TODO(mattklein123): Waiting for exact amount of data is a hack. This needs to // be fixed. - fake_upstream_connection->waitForData(187); - fake_upstream_connection->write("bad protocol data!"); - fake_upstream_connection->waitForDisconnect(); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(187, &data)); + ASSERT_TRUE(fake_upstream_connection->write("bad protocol data!")); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); codec_client_->waitForDisconnect(); EXPECT_TRUE(response->complete()); @@ -1421,11 +1450,11 @@ void HttpIntegrationTest::testDownstreamResetBeforeResponseComplete() { codec_client_->sendReset(*request_encoder_); if (upstreamProtocol() == FakeHttpConnection::Type::HTTP1) { - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } else { - upstream_request_->waitForReset(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(upstream_request_->waitForReset()); + ASSERT_TRUE(fake_upstream_connection_->close()); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); } codec_client_->close(); @@ -1439,7 +1468,6 @@ void HttpIntegrationTest::testDownstreamResetBeforeResponseComplete() { } void HttpIntegrationTest::testTrailers(uint64_t request_size, uint64_t response_size) { - config_helper_.addFilter(ConfigHelper::DEFAULT_BUFFER_FILTER); Http::TestHeaderMapImpl request_trailers{{"request1", "trailer1"}, {"request2", "trailer2"}}; Http::TestHeaderMapImpl response_trailers{{"response1", "trailer1"}, {"response2", "trailer2"}}; diff --git a/test/integration/http_integration.h b/test/integration/http_integration.h index f07c032b825cd..1b6c3b760157c 100644 --- a/test/integration/http_integration.h +++ b/test/integration/http_integration.h @@ -93,8 +93,8 @@ class HttpIntegrationTest : public BaseIntegrationTest { // Waits for the complete downstream response before returning. // Requires |codec_client_| to be initialized. IntegrationStreamDecoderPtr sendRequestAndWaitForResponse( - Http::TestHeaderMapImpl& request_headers, uint32_t request_body_size, - Http::TestHeaderMapImpl& response_headers, uint32_t response_body_size); + const Http::TestHeaderMapImpl& request_headers, uint32_t request_body_size, + const Http::TestHeaderMapImpl& response_headers, uint32_t response_body_size); // Wait for the end of stream on the next upstream stream on fake_upstreams_ // Sets fake_upstream_connection_ to the connection and upstream_request_ to stream. diff --git a/test/integration/idle_timeout_integration_test.cc b/test/integration/idle_timeout_integration_test.cc new file mode 100644 index 0000000000000..3006fcf25a28d --- /dev/null +++ b/test/integration/idle_timeout_integration_test.cc @@ -0,0 +1,170 @@ +#include "test/integration/http_protocol_integration.h" + +namespace Envoy { +namespace { + +class IdleTimeoutIntegrationTest : public HttpProtocolIntegrationTest { +public: + void initialize() override { + config_helper_.addConfigModifier( + [&](envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) + -> void { + if (enable_global_idle_timeout_) { + hcm.mutable_stream_idle_timeout()->set_seconds(0); + hcm.mutable_stream_idle_timeout()->set_nanos(TimeoutMs * 1000 * 1000); + } + if (enable_per_stream_idle_timeout_) { + auto* route_config = hcm.mutable_route_config(); + auto* virtual_host = route_config->mutable_virtual_hosts(0); + auto* route = virtual_host->mutable_routes(0)->mutable_route(); + route->mutable_idle_timeout()->set_seconds(0); + route->mutable_idle_timeout()->set_nanos(TimeoutMs * 1000 * 1000); + } + // For validating encode100ContinueHeaders() timer kick. + hcm.set_proxy_100_continue(true); + }); + HttpProtocolIntegrationTest::initialize(); + } + + IntegrationStreamDecoderPtr setupPerStreamIdleTimeoutTest() { + initialize(); + fake_upstreams_[0]->set_allow_unexpected_disconnects(true); + codec_client_ = makeHttpConnection(makeClientConnection((lookupPort("http")))); + auto encoder_decoder = + codec_client_->startRequest(Http::TestHeaderMapImpl{{":method", "GET"}, + {":path", "/test/long/url"}, + {":scheme", "http"}, + {":authority", "host"}}); + request_encoder_ = &encoder_decoder.first; + auto response = std::move(encoder_decoder.second); + AssertionResult result = + fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_); + RELEASE_ASSERT(result, result.message()); + result = upstream_request_->waitForHeadersComplete(); + RELEASE_ASSERT(result, result.message()); + return response; + } + + void sleep() { std::this_thread::sleep_for(std::chrono::milliseconds(TimeoutMs / 2)); } + + void waitForTimeout(IntegrationStreamDecoder& response) { + if (downstream_protocol_ == Http::CodecClient::Type::HTTP1) { + codec_client_->waitForDisconnect(); + } else { + response.waitForReset(); + codec_client_->close(); + } + EXPECT_EQ(1, test_server_->counter("http.config_test.downstream_rq_idle_timeout")->value()); + } + + // TODO(htuch): This might require scaling for TSAN/ASAN/Valgrind/etc. Bump if + // this is the cause of flakes. + static constexpr uint64_t TimeoutMs = 200; + bool enable_global_idle_timeout_{}; + bool enable_per_stream_idle_timeout_{true}; +}; + +INSTANTIATE_TEST_CASE_P(Protocols, IdleTimeoutIntegrationTest, + testing::ValuesIn(HttpProtocolIntegrationTest::getProtocolTestParams()), + HttpProtocolIntegrationTest::protocolTestParamsToString); + +// Per-stream idle timeout after having sent downstream headers. +TEST_P(IdleTimeoutIntegrationTest, PerStreamIdleTimeoutAfterDownstreamHeaders) { + auto response = setupPerStreamIdleTimeoutTest(); + + waitForTimeout(*response); + + EXPECT_FALSE(upstream_request_->complete()); + EXPECT_EQ(0U, upstream_request_->bodyLength()); + EXPECT_TRUE(response->complete()); + EXPECT_STREQ("408", response->headers().Status()->value().c_str()); + EXPECT_EQ("stream timeout", response->body()); +} + +// Global per-stream idle timeout applies if there is no per-stream idle timeout. +TEST_P(IdleTimeoutIntegrationTest, GlobalPerStreamIdleTimeoutAfterDownstreamHeaders) { + enable_global_idle_timeout_ = true; + enable_per_stream_idle_timeout_ = false; + auto response = setupPerStreamIdleTimeoutTest(); + + waitForTimeout(*response); + + EXPECT_FALSE(upstream_request_->complete()); + EXPECT_EQ(0U, upstream_request_->bodyLength()); + EXPECT_TRUE(response->complete()); + EXPECT_STREQ("408", response->headers().Status()->value().c_str()); + EXPECT_EQ("stream timeout", response->body()); +} + +// Per-stream idle timeout after having sent downstream headers+body. +TEST_P(IdleTimeoutIntegrationTest, PerStreamIdleTimeoutAfterDownstreamHeadersAndBody) { + auto response = setupPerStreamIdleTimeoutTest(); + + sleep(); + codec_client_->sendData(*request_encoder_, 1, false); + + waitForTimeout(*response); + + EXPECT_FALSE(upstream_request_->complete()); + EXPECT_EQ(1U, upstream_request_->bodyLength()); + EXPECT_TRUE(response->complete()); + EXPECT_STREQ("408", response->headers().Status()->value().c_str()); + EXPECT_EQ("stream timeout", response->body()); +} + +// Per-stream idle timeout after upstream headers have been sent. +TEST_P(IdleTimeoutIntegrationTest, PerStreamIdleTimeoutAfterUpstreamHeaders) { + auto response = setupPerStreamIdleTimeoutTest(); + + upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + + waitForTimeout(*response); + + EXPECT_FALSE(upstream_request_->complete()); + EXPECT_EQ(0U, upstream_request_->bodyLength()); + EXPECT_FALSE(response->complete()); + EXPECT_STREQ("200", response->headers().Status()->value().c_str()); + EXPECT_EQ("", response->body()); +} + +// Per-stream idle timeout after a sequence of header/data events. +TEST_P(IdleTimeoutIntegrationTest, PerStreamIdleTimeoutAfterBidiData) { + auto response = setupPerStreamIdleTimeoutTest(); + + sleep(); + upstream_request_->encode100ContinueHeaders(Http::TestHeaderMapImpl{{":status", "100"}}); + + sleep(); + upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + + sleep(); + upstream_request_->encodeData(1, false); + + sleep(); + codec_client_->sendData(*request_encoder_, 1, false); + + sleep(); + Http::TestHeaderMapImpl request_trailers{{"request1", "trailer1"}, {"request2", "trailer2"}}; + codec_client_->sendTrailers(*request_encoder_, request_trailers); + + sleep(); + upstream_request_->encodeData(1, false); + + waitForTimeout(*response); + + EXPECT_TRUE(upstream_request_->complete()); + EXPECT_EQ(1U, upstream_request_->bodyLength()); + EXPECT_FALSE(response->complete()); + EXPECT_STREQ("200", response->headers().Status()->value().c_str()); + EXPECT_EQ("aa", response->body()); +} + +// Successful request/response when per-stream idle timeout is configured. +TEST_P(IdleTimeoutIntegrationTest, PerStreamIdleTimeoutRequestAndResponse) { + testRouterRequestAndResponseWithBody(1024, 1024, false, nullptr); +} + +} // namespace +} // namespace Envoy diff --git a/test/integration/integration.cc b/test/integration/integration.cc index 3f8803914e98f..688406c625274 100644 --- a/test/integration/integration.cc +++ b/test/integration/integration.cc @@ -242,8 +242,8 @@ Network::ClientConnectionPtr BaseIntegrationTest::makeClientConnection(uint32_t } void BaseIntegrationTest::initialize() { - RELEASE_ASSERT(!initialized_); - RELEASE_ASSERT(Event::Libevent::Global::initialized()); + RELEASE_ASSERT(!initialized_, ""); + RELEASE_ASSERT(Event::Libevent::Global::initialized(), ""); initialized_ = true; createUpstreams(); @@ -289,12 +289,12 @@ void BaseIntegrationTest::setUpstreamProtocol(FakeHttpConnection::Type protocol) if (upstream_protocol_ == FakeHttpConnection::Type::HTTP2) { config_helper_.addConfigModifier( [&](envoy::config::bootstrap::v2::Bootstrap& bootstrap) -> void { - RELEASE_ASSERT(bootstrap.mutable_static_resources()->clusters_size() >= 1); + RELEASE_ASSERT(bootstrap.mutable_static_resources()->clusters_size() >= 1, ""); auto* cluster = bootstrap.mutable_static_resources()->mutable_clusters(0); cluster->mutable_http2_protocol_options(); }); } else { - RELEASE_ASSERT(protocol == FakeHttpConnection::Type::HTTP1); + RELEASE_ASSERT(protocol == FakeHttpConnection::Type::HTTP1, ""); } } @@ -312,7 +312,7 @@ uint32_t BaseIntegrationTest::lookupPort(const std::string& key) { if (it != port_map_.end()) { return it->second; } - RELEASE_ASSERT(false); + RELEASE_ASSERT(false, ""); } void BaseIntegrationTest::setUpstreamAddress(uint32_t upstream_index, @@ -382,7 +382,7 @@ void BaseIntegrationTest::sendRawHttpAndWaitForResponse(int port, const char* ra RawConnectionDriver connection( port, buffer, [&](Network::ClientConnection& client, const Buffer::Instance& data) -> void { - response->append(TestUtility::bufferToString(data)); + response->append(data.toString()); if (disconnect_after_headers_complete && response->find("\r\n\r\n") != std::string::npos) { client.close(Network::ConnectionCloseType::NoFlush); } diff --git a/test/integration/integration_admin_test.cc b/test/integration/integration_admin_test.cc index 9ebd2b9e53877..0376aa9190acf 100644 --- a/test/integration/integration_admin_test.cc +++ b/test/integration/integration_admin_test.cc @@ -21,12 +21,12 @@ TEST_P(IntegrationAdminTest, HealthCheck) { initialize(); BufferingStreamDecoderPtr response = IntegrationUtil::makeSingleRequest( - lookupPort("http"), "GET", "/healthcheck", "", downstreamProtocol(), version_); + lookupPort("http"), "POST", "/healthcheck", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/healthcheck/fail", "", - downstreamProtocol(), version_); + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/healthcheck/fail", + "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); @@ -35,7 +35,7 @@ TEST_P(IntegrationAdminTest, HealthCheck) { EXPECT_TRUE(response->complete()); EXPECT_STREQ("503", response->headers().Status()->value().c_str()); - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/healthcheck/ok", "", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/healthcheck/ok", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); @@ -60,39 +60,39 @@ TEST_P(IntegrationAdminTest, AdminLogging) { initialize(); BufferingStreamDecoderPtr response = IntegrationUtil::makeSingleRequest( - lookupPort("admin"), "GET", "/logging", "", downstreamProtocol(), version_); + lookupPort("admin"), "POST", "/logging", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("404", response->headers().Status()->value().c_str()); // Bad level - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/logging?level=blah", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/logging?level=blah", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("404", response->headers().Status()->value().c_str()); // Bad logger - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/logging?blah=info", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/logging?blah=info", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("404", response->headers().Status()->value().c_str()); // This is going to stomp over custom log levels that are set on the command line. response = IntegrationUtil::makeSingleRequest( - lookupPort("admin"), "GET", "/logging?level=warning", "", downstreamProtocol(), version_); + lookupPort("admin"), "POST", "/logging?level=warning", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); for (const Logger::Logger& logger : Logger::Registry::loggers()) { EXPECT_EQ("warning", logger.levelString()); } - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/logging?assert=trace", - "", downstreamProtocol(), version_); + response = IntegrationUtil::makeSingleRequest( + lookupPort("admin"), "POST", "/logging?assert=trace", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); EXPECT_EQ(spdlog::level::trace, Logger::Registry::getLog(Logger::Id::assert).level()); const char* level_name = spdlog::level::level_names[default_log_level_]; - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", fmt::format("/logging?level={}", level_name), "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); @@ -224,7 +224,7 @@ TEST_P(IntegrationAdminTest, Admin) { EXPECT_THAT(response->body(), testing::HasSubstr("added_via_api")); EXPECT_STREQ("text/plain; charset=UTF-8", ContentType(response)); - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/cpuprofiler", "", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/cpuprofiler", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("400", response->headers().Status()->value().c_str()); @@ -236,7 +236,7 @@ TEST_P(IntegrationAdminTest, Admin) { EXPECT_STREQ("200", response->headers().Status()->value().c_str()); EXPECT_STREQ("text/plain; charset=UTF-8", ContentType(response)); - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/reset_counters", "", + response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "POST", "/reset_counters", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); @@ -341,12 +341,12 @@ TEST_P(IntegrationAdminTest, AdminCpuProfilerStart) { initialize(); BufferingStreamDecoderPtr response = IntegrationUtil::makeSingleRequest( - lookupPort("admin"), "GET", "/cpuprofiler?enable=y", "", downstreamProtocol(), version_); + lookupPort("admin"), "POST", "/cpuprofiler?enable=y", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); - response = IntegrationUtil::makeSingleRequest(lookupPort("admin"), "GET", "/cpuprofiler?enable=n", - "", downstreamProtocol(), version_); + response = IntegrationUtil::makeSingleRequest( + lookupPort("admin"), "POST", "/cpuprofiler?enable=n", "", downstreamProtocol(), version_); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); } diff --git a/test/integration/integration_test.cc b/test/integration/integration_test.cc index 526978c2773db..f5773c3532646 100644 --- a/test/integration/integration_test.cc +++ b/test/integration/integration_test.cc @@ -15,7 +15,10 @@ #include "gtest/gtest.h" +using testing::EndsWith; +using testing::HasSubstr; using testing::MatchesRegex; +using testing::Not; namespace Envoy { @@ -134,8 +137,8 @@ TEST_P(IntegrationTest, IdleTimoutBasic) { testIdleTimeoutBasic(); } TEST_P(IntegrationTest, IdleTimeoutWithTwoRequests) { testIdleTimeoutWithTwoRequests(); } // Test hitting the bridge filter with too many response bytes to buffer. Given -// the headers are not proxied, the connection manager will send a 500. -TEST_P(IntegrationTest, HittingEncoderFilterLimitBufferingHeaders) { +// the headers are not proxied, the connection manager will send a local error reply. +TEST_P(IntegrationTest, HittingGrpcFilterLimitBufferingHeaders) { config_helper_.addFilter("{ name: envoy.grpc_http1_bridge, config: {} }"); config_helper_.setBufferLimits(1024, 1024); @@ -152,9 +155,11 @@ TEST_P(IntegrationTest, HittingEncoderFilterLimitBufferingHeaders) { waitForNextUpstreamRequest(); // Send the overly large response. Because the grpc_http1_bridge filter buffers and buffer - // limits are exceeded, this will be translated into a 500 from Envoy. + // limits are exceeded, this will be translated into an unknown gRPC error. upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); + fake_upstreams_[0]->set_allow_unexpected_disconnects(true); upstream_request_->encodeData(1024 * 65, false); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); response->waitForEndStream(); EXPECT_TRUE(response->complete()); @@ -205,6 +210,68 @@ TEST_P(IntegrationTest, OverlyLongHeaders) { testOverlyLongHeaders(); } TEST_P(IntegrationTest, UpstreamProtocolError) { testUpstreamProtocolError(); } +TEST_P(IntegrationTest, TestHead) { + initialize(); + + codec_client_ = makeHttpConnection(lookupPort("http")); + + Http::TestHeaderMapImpl head_request{{":method", "HEAD"}, + {":path", "/test/long/url"}, + {":scheme", "http"}, + {":authority", "host"}}; + + // Without an explicit content length, assume we chunk for HTTP/1.1 + auto response = sendRequestAndWaitForResponse(head_request, 0, default_response_headers_, 0); + ASSERT_TRUE(response->complete()); + EXPECT_STREQ("200", response->headers().Status()->value().c_str()); + EXPECT_TRUE(response->headers().ContentLength() == nullptr); + ASSERT_TRUE(response->headers().TransferEncoding() != nullptr); + EXPECT_EQ(Http::Headers::get().TransferEncodingValues.Chunked, + response->headers().TransferEncoding()->value().c_str()); + EXPECT_EQ(0, response->body().size()); + + // Preserve explicit content length. + Http::TestHeaderMapImpl content_length_response{{":status", "200"}, {"content-length", "12"}}; + response = sendRequestAndWaitForResponse(head_request, 0, content_length_response, 0); + ASSERT_TRUE(response->complete()); + EXPECT_STREQ("200", response->headers().Status()->value().c_str()); + ASSERT_TRUE(response->headers().ContentLength() != nullptr); + EXPECT_STREQ(response->headers().ContentLength()->value().c_str(), "12"); + EXPECT_TRUE(response->headers().TransferEncoding() == nullptr); + EXPECT_EQ(0, response->body().size()); + + cleanupUpstreamAndDownstream(); +} + +// The Envoy HTTP/1.1 codec ASSERTs that T-E headers are cleared in +// encodeHeaders, so to test upstreams explicitly sending T-E: chunked we have +// to send raw HTTP. +TEST_P(IntegrationTest, TestHeadWithExplicitTE) { + initialize(); + + auto tcp_client = makeTcpConnection(lookupPort("http")); + tcp_client->write("HEAD / HTTP/1.1\r\nHost: host\r\n\r\n"); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData( + FakeRawConnection::waitForInexactMatch("\r\n\r\n"), &data)); + + ASSERT_TRUE( + fake_upstream_connection->write("HTTP/1.1 200 OK\r\nTransfer-encoding: chunked\r\n\r\n")); + tcp_client->waitForData("\r\n\r\n", false); + std::string response = tcp_client->data(); + + EXPECT_THAT(response, HasSubstr("HTTP/1.1 200 OK\r\n")); + EXPECT_THAT(response, Not(HasSubstr("content-length"))); + EXPECT_THAT(response, HasSubstr("transfer-encoding: chunked\r\n")); + EXPECT_THAT(response, EndsWith("\r\n\r\n")); + + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); + tcp_client->close(); +} + TEST_P(IntegrationTest, TestBind) { std::string address_string; if (GetParam() == Network::Address::IpVersion::v4) { @@ -216,7 +283,6 @@ TEST_P(IntegrationTest, TestBind) { initialize(); codec_client_ = makeHttpConnection(lookupPort("http")); - // Request 1. auto response = codec_client_->makeRequestWithBody(Http::TestHeaderMapImpl{{":method", "GET"}, @@ -224,17 +290,16 @@ TEST_P(IntegrationTest, TestBind) { {":scheme", "http"}, {":authority", "host"}}, 1024); - - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_NE(fake_upstream_connection_, nullptr); std::string address = fake_upstream_connection_->connection().remoteAddress()->ip()->addressAsString(); EXPECT_EQ(address, address_string); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForEndStream(*dispatcher_); - // Cleanup both downstream and upstream - codec_client_->close(); - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); + ASSERT_NE(upstream_request_, nullptr); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); + + cleanupUpstreamAndDownstream(); } TEST_P(IntegrationTest, TestFailedBind) { diff --git a/test/integration/load_stats_integration_test.cc b/test/integration/load_stats_integration_test.cc index e0011f0c4aad5..a7c16e034c316 100644 --- a/test/integration/load_stats_integration_test.cc +++ b/test/integration/load_stats_integration_test.cc @@ -147,8 +147,11 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, } void waitForLoadStatsStream() { - fake_loadstats_connection_ = load_report_upstream_->waitForHttpConnection(*dispatcher_); - loadstats_stream_ = fake_loadstats_connection_->waitForNewStream(*dispatcher_); + AssertionResult result = + load_report_upstream_->waitForHttpConnection(*dispatcher_, fake_loadstats_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_loadstats_connection_->waitForNewStream(*dispatcher_, loadstats_stream_); + RELEASE_ASSERT(result, result.message()); } void @@ -218,9 +221,24 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, // merge until all the expected load has been reported. do { envoy::service::load_stats::v2::LoadStatsRequest local_loadstats_request; - loadstats_stream_->waitForGrpcMessage(*dispatcher_, local_loadstats_request); - + AssertionResult result = + loadstats_stream_->waitForGrpcMessage(*dispatcher_, local_loadstats_request); + RELEASE_ASSERT(result, result.message()); + // Sanity check and clear the measured load report interval. + for (auto& cluster_stats : *local_loadstats_request.mutable_cluster_stats()) { + const uint32_t actual_load_report_interval_ms = + Protobuf::util::TimeUtil::DurationToMilliseconds(cluster_stats.load_report_interval()); + // Turns out libevent timers aren't that accurate; without this adjustment we see things + // like "expected 500, actual 497". Tweak as needed if races are observed. + EXPECT_GE(actual_load_report_interval_ms, load_report_interval_ms_ - 100); + // Allow for some skew in test environment. + EXPECT_LT(actual_load_report_interval_ms, load_report_interval_ms_ + 1000); + cluster_stats.mutable_load_report_interval()->Clear(); + } mergeLoadStats(loadstats_request, local_loadstats_request); + if (!loadstats_request.cluster_stats().empty()) { + ENVOY_LOG_MISC(debug, "HTD {}", loadstats_request.cluster_stats()[0].DebugString()); + } EXPECT_STREQ("POST", loadstats_stream_->headers().Method()->value().c_str()); EXPECT_STREQ("/envoy.service.load_stats.v2.LoadReportingService/StreamLoadStats", @@ -231,10 +249,13 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, } void waitForUpstreamResponse(uint32_t endpoint_index, uint32_t response_code = 200) { - fake_upstream_connection_ = - service_upstream_[endpoint_index]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForEndStream(*dispatcher_); + AssertionResult result = service_upstream_[endpoint_index]->waitForHttpConnection( + *dispatcher_, fake_upstream_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_); + RELEASE_ASSERT(result, result.message()); + result = upstream_request_->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); upstream_request_->encodeHeaders( Http::TestHeaderMapImpl{{":status", std::to_string(response_code)}}, false); @@ -252,7 +273,8 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, void requestLoadStatsResponse(const std::vector& clusters) { envoy::service::load_stats::v2::LoadStatsResponse loadstats_response; - loadstats_response.mutable_load_reporting_interval()->set_nanos(500000000); // 500ms + loadstats_response.mutable_load_reporting_interval()->MergeFrom( + Protobuf::util::TimeUtil::MillisecondsToDuration(load_report_interval_ms_)); for (const auto& cluster : clusters) { loadstats_response.add_clusters(cluster); } @@ -277,25 +299,19 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, return locality_stats; } - void cleanupUpstreamConnection() { - codec_client_->close(); - if (fake_upstream_connection_ != nullptr) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); - } - } - void cleanupLoadStatsConnection() { if (fake_loadstats_connection_ != nullptr) { - fake_loadstats_connection_->close(); - fake_loadstats_connection_->waitForDisconnect(); + AssertionResult result = fake_loadstats_connection_->close(); + RELEASE_ASSERT(result, result.message()); + result = fake_loadstats_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } } void sendAndReceiveUpstream(uint32_t endpoint_index, uint32_t response_code = 200) { initiateClientConnection(); waitForUpstreamResponse(endpoint_index, response_code); - cleanupUpstreamConnection(); + cleanupUpstreamAndDownstream(); } static constexpr uint32_t upstream_endpoints_ = 5; @@ -312,6 +328,7 @@ class LoadStatsIntegrationTest : public HttpIntegrationTest, const uint64_t request_size_ = 1024; const uint64_t response_size_ = 512; + const uint32_t load_report_interval_ms_ = 500; }; INSTANTIATE_TEST_CASE_P(IpVersions, LoadStatsIntegrationTest, @@ -390,19 +407,34 @@ TEST_P(LoadStatsIntegrationTest, Success) { EXPECT_LE(5, test_server_->counter("load_reporter.responses")->value()); EXPECT_EQ(0, test_server_->counter("load_reporter.errors")->value()); - // A LoadStatsResponse arrives before the expiration of the reporting interval. + // A LoadStatsResponse arrives before the expiration of the reporting + // interval. Since we are keep tracking cluster_0, stats rollover. requestLoadStatsResponse({"cluster_0"}); sendAndReceiveUpstream(1); requestLoadStatsResponse({"cluster_0"}); sendAndReceiveUpstream(1); sendAndReceiveUpstream(1); - waitForLoadStatsRequest({localityStats("winter", 2, 0, 0)}); + waitForLoadStatsRequest({localityStats("winter", 3, 0, 0)}); EXPECT_EQ(6, test_server_->counter("load_reporter.requests")->value()); EXPECT_LE(6, test_server_->counter("load_reporter.responses")->value()); EXPECT_EQ(0, test_server_->counter("load_reporter.errors")->value()); + // As above, but stop tracking cluster_0 and only get the requests since the + // response. + requestLoadStatsResponse({}); + sendAndReceiveUpstream(1); + requestLoadStatsResponse({"cluster_0"}); + sendAndReceiveUpstream(1); + sendAndReceiveUpstream(1); + + waitForLoadStatsRequest({localityStats("winter", 2, 0, 0)}); + + EXPECT_EQ(8, test_server_->counter("load_reporter.requests")->value()); + EXPECT_LE(7, test_server_->counter("load_reporter.responses")->value()); + EXPECT_EQ(0, test_server_->counter("load_reporter.errors")->value()); + cleanupLoadStatsConnection(); } @@ -515,7 +547,7 @@ TEST_P(LoadStatsIntegrationTest, InProgress) { waitForLoadStatsRequest({localityStats("winter", 0, 0, 1)}); waitForUpstreamResponse(0, 503); - cleanupUpstreamConnection(); + cleanupUpstreamAndDownstream(); EXPECT_EQ(1, test_server_->counter("load_reporter.requests")->value()); EXPECT_LE(2, test_server_->counter("load_reporter.responses")->value()); @@ -544,7 +576,7 @@ TEST_P(LoadStatsIntegrationTest, Dropped) { response_->waitForEndStream(); ASSERT_TRUE(response_->complete()); EXPECT_STREQ("503", response_->headers().Status()->value().c_str()); - cleanupUpstreamConnection(); + cleanupUpstreamAndDownstream(); waitForLoadStatsRequest({}, 1); diff --git a/test/integration/proxy_proto_integration_test.cc b/test/integration/proxy_proto_integration_test.cc index c5289aa44b465..20a193ab6c8e3 100644 --- a/test/integration/proxy_proto_integration_test.cc +++ b/test/integration/proxy_proto_integration_test.cc @@ -15,7 +15,7 @@ INSTANTIATE_TEST_CASE_P(IpVersions, ProxyProtoIntegrationTest, testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), TestUtility::ipTestParamsToString); -TEST_P(ProxyProtoIntegrationTest, RouterRequestAndResponseWithBodyNoBuffer) { +TEST_P(ProxyProtoIntegrationTest, v1RouterRequestAndResponseWithBodyNoBuffer) { ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { Network::ClientConnectionPtr conn = makeClientConnection(lookupPort("http")); Buffer::OwnedImpl buf("PROXY TCP4 1.2.3.4 254.254.254.254 65535 1234\r\n"); @@ -26,7 +26,21 @@ TEST_P(ProxyProtoIntegrationTest, RouterRequestAndResponseWithBodyNoBuffer) { testRouterRequestAndResponseWithBody(1024, 512, false, &creator); } -TEST_P(ProxyProtoIntegrationTest, RouterRequestAndResponseWithBodyNoBufferV6) { +TEST_P(ProxyProtoIntegrationTest, v2RouterRequestAndResponseWithBodyNoBuffer) { + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + Network::ClientConnectionPtr conn = makeClientConnection(lookupPort("http")); + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0x00, 0x0c, 0x01, 0x02, 0x03, 0x04, + 0xff, 0xff, 0xfe, 0xfe, 0xfe, 0xfe, 0x04, 0xd2}; + Buffer::OwnedImpl buf(buffer, sizeof(buffer)); + conn->write(buf, false); + return conn; + }; + + testRouterRequestAndResponseWithBody(1024, 512, false, &creator); +} + +TEST_P(ProxyProtoIntegrationTest, v1RouterRequestAndResponseWithBodyNoBufferV6) { ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { auto conn = makeClientConnection(lookupPort("http")); Buffer::OwnedImpl buf("PROXY TCP6 1:2:3::4 FE00:: 65535 1234\r\n"); @@ -37,6 +51,22 @@ TEST_P(ProxyProtoIntegrationTest, RouterRequestAndResponseWithBodyNoBufferV6) { testRouterRequestAndResponseWithBody(1024, 512, false, &creator); } +TEST_P(ProxyProtoIntegrationTest, v2RouterRequestAndResponseWithBodyNoBufferV6) { + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, + 0x0a, 0x21, 0x22, 0x00, 0x24, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, + 0x01, 0x01, 0x00, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x02}; + auto conn = makeClientConnection(lookupPort("http")); + Buffer::OwnedImpl buf(buffer, sizeof(buffer)); + conn->write(buf, false); + return conn; + }; + + testRouterRequestAndResponseWithBody(1024, 512, false, &creator); +} + TEST_P(ProxyProtoIntegrationTest, RouterProxyUnknownRequestAndResponseWithBodyNoBuffer) { ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { auto conn = makeClientConnection(lookupPort("http")); diff --git a/test/integration/ratelimit_integration_test.cc b/test/integration/ratelimit_integration_test.cc index 1e77ff2b74b6b..c45084e716fb2 100644 --- a/test/integration/ratelimit_integration_test.cc +++ b/test/integration/ratelimit_integration_test.cc @@ -74,11 +74,16 @@ class RatelimitIntegrationTest : public HttpIntegrationTest, } void waitForRatelimitRequest() { - fake_ratelimit_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - ratelimit_request_ = fake_ratelimit_connection_->waitForNewStream(*dispatcher_); + AssertionResult result = + fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_ratelimit_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_ratelimit_connection_->waitForNewStream(*dispatcher_, ratelimit_request_); + RELEASE_ASSERT(result, result.message()); envoy::service::ratelimit::v2::RateLimitRequest request_msg; - ratelimit_request_->waitForGrpcMessage(*dispatcher_, request_msg); - ratelimit_request_->waitForEndStream(*dispatcher_); + result = ratelimit_request_->waitForGrpcMessage(*dispatcher_, request_msg); + RELEASE_ASSERT(result, result.message()); + result = ratelimit_request_->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); EXPECT_STREQ("POST", ratelimit_request_->headers().Method()->value().c_str()); if (useDataPlaneProto()) { EXPECT_STREQ("/envoy.service.ratelimit.v2.RateLimitService/ShouldRateLimit", @@ -98,9 +103,13 @@ class RatelimitIntegrationTest : public HttpIntegrationTest, } void waitForSuccessfulUpstreamResponse() { - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForEndStream(*dispatcher_); + AssertionResult result = + fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_); + RELEASE_ASSERT(result, result.message()); + result = fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_); + RELEASE_ASSERT(result, result.message()); + result = upstream_request_->waitForEndStream(*dispatcher_); + RELEASE_ASSERT(result, result.message()); upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, false); upstream_request_->encodeData(response_size_, true); @@ -130,15 +139,16 @@ class RatelimitIntegrationTest : public HttpIntegrationTest, } void cleanup() { - codec_client_->close(); if (fake_ratelimit_connection_ != nullptr) { - fake_ratelimit_connection_->close(); - fake_ratelimit_connection_->waitForDisconnect(); - } - if (fake_upstream_connection_ != nullptr) { - fake_upstream_connection_->close(); - fake_upstream_connection_->waitForDisconnect(); + if (clientType() != Grpc::ClientType::GoogleGrpc) { + // TODO(htuch) we should document the underlying cause of this difference and/or fix it. + AssertionResult result = fake_ratelimit_connection_->close(); + RELEASE_ASSERT(result, result.message()); + } + AssertionResult result = fake_ratelimit_connection_->waitForDisconnect(); + RELEASE_ASSERT(result, result.message()); } + cleanupUpstreamAndDownstream(); } FakeHttpConnectionPtr fake_ratelimit_connection_; @@ -205,7 +215,7 @@ TEST_P(RatelimitIntegrationTest, Timeout) { EXPECT_EQ(1, test_server_->counter("grpc.ratelimit.streams_closed_4")->value()); break; default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } // Rate limiter fails open waitForSuccessfulUpstreamResponse(); @@ -214,9 +224,9 @@ TEST_P(RatelimitIntegrationTest, Timeout) { TEST_P(RatelimitIntegrationTest, ConnectImmediateDisconnect) { initiateClientConnection(); - fake_ratelimit_connection_ = fake_upstreams_[1]->waitForHttpConnection(*dispatcher_); - fake_ratelimit_connection_->close(); - fake_ratelimit_connection_->waitForDisconnect(true); + ASSERT_TRUE(fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, fake_ratelimit_connection_)); + ASSERT_TRUE(fake_ratelimit_connection_->close()); + ASSERT_TRUE(fake_ratelimit_connection_->waitForDisconnect(true)); fake_ratelimit_connection_ = nullptr; // Rate limiter fails open waitForSuccessfulUpstreamResponse(); @@ -224,7 +234,11 @@ TEST_P(RatelimitIntegrationTest, ConnectImmediateDisconnect) { } TEST_P(RatelimitIntegrationTest, FailedConnect) { - fake_upstreams_[1].reset(); + // Do not reset the fake upstream for the ratelimiter, but have it stop listening. + // If we reset, the Envoy will continue to send H2 to the original rate limiter port, which may + // be used by another test, and data sent to that port "unexpectedly" will cause problems for + // that test. + fake_upstreams_[1]->cleanUp(); initiateClientConnection(); // Rate limiter fails open waitForSuccessfulUpstreamResponse(); diff --git a/test/integration/sds_static_integration_test.cc b/test/integration/sds_static_integration_test.cc new file mode 100644 index 0000000000000..28d017d242649 --- /dev/null +++ b/test/integration/sds_static_integration_test.cc @@ -0,0 +1,191 @@ +#include +#include + +#include "common/event/dispatcher_impl.h" +#include "common/network/connection_impl.h" +#include "common/network/utility.h" +#include "common/ssl/context_config_impl.h" +#include "common/ssl/context_manager_impl.h" + +#include "test/integration/http_integration.h" +#include "test/integration/server.h" +#include "test/integration/ssl_utility.h" +#include "test/mocks/init/mocks.h" +#include "test/mocks/runtime/mocks.h" +#include "test/mocks/secret/mocks.h" +#include "test/test_common/network_utility.h" +#include "test/test_common/utility.h" + +#include "absl/strings/match.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "integration.h" +#include "utility.h" + +using testing::NiceMock; +using testing::Return; + +namespace Envoy { +namespace Ssl { + +class SdsStaticDownstreamIntegrationTest + : public HttpIntegrationTest, + public testing::TestWithParam { +public: + SdsStaticDownstreamIntegrationTest() + : HttpIntegrationTest(Http::CodecClient::Type::HTTP1, GetParam()) {} + + void initialize() override { + config_helper_.addConfigModifier([](envoy::config::bootstrap::v2::Bootstrap& bootstrap) { + auto* common_tls_context = bootstrap.mutable_static_resources() + ->mutable_listeners(0) + ->mutable_filter_chains(0) + ->mutable_tls_context() + ->mutable_common_tls_context(); + common_tls_context->add_alpn_protocols("http/1.1"); + + auto* validation_context = common_tls_context->mutable_validation_context(); + validation_context->mutable_trusted_ca()->set_filename( + TestEnvironment::runfilesPath("test/config/integration/certs/cacert.pem")); + validation_context->add_verify_certificate_hash( + "E0:F3:C8:CE:5E:2E:A3:05:F0:70:1F:F5:12:E3:6E:2E:" + "97:92:82:84:A2:28:BC:F7:73:32:D3:39:30:A1:B6:FD"); + + common_tls_context->add_tls_certificate_sds_secret_configs()->set_name("server_cert"); + + auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret->set_name("server_cert"); + auto* tls_certificate = secret->mutable_tls_certificate(); + tls_certificate->mutable_certificate_chain()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/servercert.pem")); + tls_certificate->mutable_private_key()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/serverkey.pem")); + }); + + HttpIntegrationTest::initialize(); + + registerTestServerPorts({"http"}); + + client_ssl_ctx_ = + createClientSslTransportSocketFactory(false, false, context_manager_, secret_manager_); + } + + void TearDown() override { + client_ssl_ctx_.reset(); + cleanupUpstreamAndDownstream(); + fake_upstream_connection_.reset(); + codec_client_.reset(); + } + + Network::ClientConnectionPtr makeSslClientConnection() { + Network::Address::InstanceConstSharedPtr address = getSslAddress(version_, lookupPort("http")); + return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(), + client_ssl_ctx_->createTransportSocket(), nullptr); + } + +private: + Runtime::MockLoader runtime_; + Ssl::ContextManagerImpl context_manager_{runtime_}; + Secret::MockSecretManager secret_manager_; + + Network::TransportSocketFactoryPtr client_ssl_ctx_; +}; + +INSTANTIATE_TEST_CASE_P(IpVersions, SdsStaticDownstreamIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + TestUtility::ipTestParamsToString); + +TEST_P(SdsStaticDownstreamIntegrationTest, RouterRequestAndResponseWithGiantBodyBuffer) { + ConnectionCreationFunction creator = [&]() -> Network::ClientConnectionPtr { + return makeSslClientConnection(); + }; + testRouterRequestAndResponseWithBody(16 * 1024 * 1024, 16 * 1024 * 1024, false, &creator); +} + +class SdsStaticUpstreamIntegrationTest + : public HttpIntegrationTest, + public testing::TestWithParam { +public: + SdsStaticUpstreamIntegrationTest() + : HttpIntegrationTest(Http::CodecClient::Type::HTTP1, GetParam()) {} + + void initialize() override { + config_helper_.addConfigModifier([](envoy::config::bootstrap::v2::Bootstrap& bootstrap) { + bootstrap.mutable_static_resources() + ->mutable_clusters(0) + ->mutable_tls_context() + ->mutable_common_tls_context() + ->add_tls_certificate_sds_secret_configs() + ->set_name("client_cert"); + + auto* secret = bootstrap.mutable_static_resources()->add_secrets(); + secret->set_name("client_cert"); + auto* tls_certificate = secret->mutable_tls_certificate(); + tls_certificate->mutable_certificate_chain()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/clientcert.pem")); + tls_certificate->mutable_private_key()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/clientkey.pem")); + }); + + HttpIntegrationTest::initialize(); + + registerTestServerPorts({"http"}); + } + + void TearDown() override { + cleanupUpstreamAndDownstream(); + fake_upstream_connection_.reset(); + codec_client_.reset(); + + test_server_.reset(); + fake_upstreams_.clear(); + } + + void createUpstreams() override { + fake_upstreams_.emplace_back( + new FakeUpstream(createUpstreamSslContext(), 0, FakeHttpConnection::Type::HTTP1, version_)); + } + + Network::TransportSocketFactoryPtr createUpstreamSslContext() { + envoy::api::v2::auth::DownstreamTlsContext tls_context; + auto* common_tls_context = tls_context.mutable_common_tls_context(); + common_tls_context->add_alpn_protocols("h2"); + common_tls_context->add_alpn_protocols("http/1.1"); + common_tls_context->mutable_deprecated_v1()->set_alt_alpn_protocols("http/1.1"); + + auto* validation_context = common_tls_context->mutable_validation_context(); + validation_context->mutable_trusted_ca()->set_filename( + TestEnvironment::runfilesPath("test/config/integration/certs/cacert.pem")); + validation_context->add_verify_certificate_hash( + "E0:F3:C8:CE:5E:2E:A3:05:F0:70:1F:F5:12:E3:6E:2E:" + "97:92:82:84:A2:28:BC:F7:73:32:D3:39:30:A1:B6:FD"); + + auto* tls_certificate = common_tls_context->add_tls_certificates(); + tls_certificate->mutable_certificate_chain()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/servercert.pem")); + tls_certificate->mutable_private_key()->set_filename( + TestEnvironment::runfilesPath("/test/config/integration/certs/serverkey.pem")); + + Ssl::ServerContextConfigImpl cfg(tls_context, secret_manager_); + + static Stats::Scope* upstream_stats_store = new Stats::TestIsolatedStoreImpl(); + return std::make_unique( + cfg, context_manager_, *upstream_stats_store, std::vector{}); + } + +private: + Runtime::MockLoader runtime_; + Ssl::ContextManagerImpl context_manager_{runtime_}; + Secret::MockSecretManager secret_manager_; +}; + +INSTANTIATE_TEST_CASE_P(IpVersions, SdsStaticUpstreamIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + TestUtility::ipTestParamsToString); + +TEST_P(SdsStaticUpstreamIntegrationTest, RouterRequestAndResponseWithGiantBodyBuffer) { + testRouterRequestAndResponseWithBody(16 * 1024 * 1024, 16 * 1024 * 1024, false, nullptr); +} + +} // namespace Ssl +} // namespace Envoy diff --git a/test/integration/server.cc b/test/integration/server.cc index ed0f40191beb0..eb84cfe8ad911 100644 --- a/test/integration/server.cc +++ b/test/integration/server.cc @@ -15,6 +15,7 @@ #include "test/integration/integration.h" #include "test/integration/utility.h" #include "test/mocks/runtime/mocks.h" +#include "test/mocks/server/mocks.h" #include "test/test_common/environment.h" #include "gtest/gtest.h" @@ -58,7 +59,7 @@ IntegrationTestServer::~IntegrationTestServer() { ENVOY_LOG(info, "stopping integration test server"); BufferingStreamDecoderPtr response = - IntegrationUtil::makeSingleRequest(server_->admin().socket().localAddress(), "GET", + IntegrationUtil::makeSingleRequest(server_->admin().socket().localAddress(), "POST", "/quitquitquit", "", Http::CodecClient::Type::HTTP1); EXPECT_TRUE(response->complete()); EXPECT_STREQ("200", response->headers().Status()->value().c_str()); @@ -91,8 +92,9 @@ void IntegrationTestServer::threadRoutine(const Network::Address::IpVersion vers Thread::MutexBasicLockable lock; ThreadLocal::InstanceImpl tls; - Stats::HeapRawStatDataAllocator stats_allocator; - Stats::ThreadLocalStoreImpl stats_store(stats_allocator); + Stats::HeapStatDataAllocator stats_allocator; + Stats::StatsOptionsImpl stats_options; + Stats::ThreadLocalStoreImpl stats_store(stats_options, stats_allocator); stat_store_ = &stats_store; Runtime::RandomGeneratorPtr random_generator; if (deterministic) { diff --git a/test/integration/server.h b/test/integration/server.h index 8f0f8fbf94afe..c29abdcd6ae8f 100644 --- a/test/integration/server.h +++ b/test/integration/server.h @@ -7,6 +7,7 @@ #include #include "envoy/server/options.h" +#include "envoy/stats/stats.h" #include "common/common/assert.h" #include "common/common/lock_guard.h" @@ -49,8 +50,8 @@ class TestOptionsImpl : public Options { return local_address_ip_version_; } std::chrono::seconds drainTime() const override { return std::chrono::seconds(1); } - spdlog::level::level_enum logLevel() const override { NOT_IMPLEMENTED; } - const std::string& logFormat() const override { NOT_IMPLEMENTED; } + spdlog::level::level_enum logLevel() const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + const std::string& logFormat() const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } std::chrono::seconds parentShutdownTime() const override { return std::chrono::seconds(2); } const std::string& logPath() const override { return log_path_; } uint64_t restartEpoch() const override { return 0; } @@ -62,7 +63,7 @@ class TestOptionsImpl : public Options { const std::string& serviceNodeName() const override { return service_node_name_; } const std::string& serviceZone() const override { return service_zone_; } uint64_t maxStats() const override { return 16384; } - uint64_t maxObjNameLength() const override { return 60; } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } bool hotRestartDisabled() const override { return false; } // asConfigYaml returns a new config that empties the configPath() and populates configYaml() @@ -76,6 +77,7 @@ class TestOptionsImpl : public Options { const std::string service_cluster_name_; const std::string service_node_name_; const std::string service_zone_; + Stats::StatsOptionsImpl stats_options_; const std::string log_path_; }; @@ -138,9 +140,12 @@ class TestScopeWrapper : public Scope { return wrapped_scope_->histogram(name); } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } + private: Thread::MutexBasicLockable& lock_; ScopePtr wrapped_scope_; + Stats::StatsOptionsImpl stats_options_; }; /** @@ -168,6 +173,7 @@ class TestIsolatedStoreImpl : public StoreRoot { Thread::LockGuard lock(lock_); return store_.histogram(name); } + const Stats::StatsOptions& statsOptions() const override { return stats_options_; } // Stats::Store std::vector counters() const override { @@ -196,6 +202,7 @@ class TestIsolatedStoreImpl : public StoreRoot { mutable Thread::MutexBasicLockable lock_; IsolatedStoreImpl store_; SourceImpl source_; + Stats::StatsOptionsImpl stats_options_; }; } // namespace Stats @@ -219,7 +226,7 @@ class IntegrationTestServer : Logger::Loggable, Server::TestDrainManager& drainManager() { return *drain_manager_; } Server::InstanceImpl& server() { - RELEASE_ASSERT(server_ != nullptr); + RELEASE_ASSERT(server_ != nullptr, ""); return *server_; } void setOnWorkerListenerAddedCb(std::function on_worker_listener_added) { diff --git a/test/integration/ssl_integration_test.cc b/test/integration/ssl_integration_test.cc index 2cbd941cfa9a4..3ed603e9b37c6 100644 --- a/test/integration/ssl_integration_test.cc +++ b/test/integration/ssl_integration_test.cc @@ -50,6 +50,8 @@ void SslIntegrationTest::TearDown() { client_ssl_ctx_alpn_.reset(); client_ssl_ctx_san_.reset(); client_ssl_ctx_alpn_san_.reset(); + HttpIntegrationTest::cleanupUpstreamAndDownstream(); + codec_client_.reset(); context_manager_.reset(); runtime_.reset(); } diff --git a/test/integration/tcp_conn_pool_integration_test.cc b/test/integration/tcp_conn_pool_integration_test.cc new file mode 100644 index 0000000000000..fbe93d877c9a4 --- /dev/null +++ b/test/integration/tcp_conn_pool_integration_test.cc @@ -0,0 +1,198 @@ +#include + +#include "envoy/config/bootstrap/v2/bootstrap.pb.h" +#include "envoy/server/filter_config.h" + +#include "test/integration/integration.h" +#include "test/integration/utility.h" +#include "test/server/utility.h" +#include "test/test_common/registry.h" +#include "test/test_common/utility.h" + +namespace Envoy { +namespace { + +std::string tcp_conn_pool_config; + +// Trivial Filter that obtains connections from a TCP connection pool each time onData is called +// and sends the data to the resulting upstream. The upstream's response is sent directly to +// the downstream. +class TestFilter : public Network::ReadFilter { +public: + TestFilter(Upstream::ClusterManager& cluster_manager) : cluster_manager_(cluster_manager) {} + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override { + UNREFERENCED_PARAMETER(end_stream); + + Tcp::ConnectionPool::Instance* pool = cluster_manager_.tcpConnPoolForCluster( + "cluster_0", Upstream::ResourcePriority::Default, nullptr); + ASSERT(pool != nullptr); + + requests_.emplace_back(*this, data); + pool->newConnection(requests_.back()); + + ASSERT(data.length() == 0); + return Network::FilterStatus::StopIteration; + } + Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } + void initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) override { + read_callbacks_ = &callbacks; + } + +private: + class Request : public Tcp::ConnectionPool::Callbacks, + public Tcp::ConnectionPool::UpstreamCallbacks { + public: + Request(TestFilter& parent, Buffer::Instance& data) : parent_(parent) { data_.move(data); } + + // Tcp::ConnectionPool::Callbacks + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason, + Upstream::HostDescriptionConstSharedPtr) override { + ASSERT(false); + } + + void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn, + Upstream::HostDescriptionConstSharedPtr) override { + upstream_ = std::move(conn); + + upstream_->addUpstreamCallbacks(*this); + upstream_->connection().write(data_, false); + } + + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override { + UNREFERENCED_PARAMETER(end_stream); + + Network::Connection& downstream = parent_.read_callbacks_->connection(); + downstream.write(data, false); + + upstream_.reset(); + } + void onEvent(Network::ConnectionEvent) override {} + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + TestFilter& parent_; + Buffer::OwnedImpl data_; + Tcp::ConnectionPool::ConnectionDataPtr upstream_; + }; + + Upstream::ClusterManager& cluster_manager_; + Network::ReadFilterCallbacks* read_callbacks_{}; + std::list requests_; +}; + +class TestFilterConfigFactory : public Server::Configuration::NamedNetworkFilterConfigFactory { +public: + // NamedNetworkFilterConfigFactory + Network::FilterFactoryCb + createFilterFactory(const Json::Object&, + Server::Configuration::FactoryContext& context) override { + return [&context](Network::FilterManager& filter_manager) -> void { + filter_manager.addReadFilter(std::make_shared(context.clusterManager())); + }; + } + + Network::FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message&, + Server::Configuration::FactoryContext& context) override { + return [&context](Network::FilterManager& filter_manager) -> void { + filter_manager.addReadFilter(std::make_shared(context.clusterManager())); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return ProtobufTypes::MessagePtr{new Envoy::ProtobufWkt::Empty()}; + } + + std::string name() override { CONSTRUCT_ON_FIRST_USE(std::string, "envoy.test.router"); } +}; + +} // namespace + +class TcpConnPoolIntegrationTest : public BaseIntegrationTest, + public testing::TestWithParam { +public: + TcpConnPoolIntegrationTest() + : BaseIntegrationTest(GetParam(), tcp_conn_pool_config), filter_resolver_(config_factory_) {} + + // Called once by the gtest framework before any tests are run. + static void SetUpTestCase() { + tcp_conn_pool_config = ConfigHelper::BASE_CONFIG + R"EOF( + filter_chains: + - filters: + - name: envoy.test.router + config: + )EOF"; + } + + // Initializer for individual tests. + void SetUp() override { BaseIntegrationTest::initialize(); } + + // Destructor for individual tests. + void TearDown() override { + test_server_.reset(); + fake_upstreams_.clear(); + } + +private: + TestFilterConfigFactory config_factory_; + Registry::InjectFactory filter_resolver_; +}; + +INSTANTIATE_TEST_CASE_P(IpVersions, TcpConnPoolIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + TestUtility::ipTestParamsToString); + +TEST_P(TcpConnPoolIntegrationTest, SingleRequest) { + std::string request("request"); + std::string response("response"); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + tcp_client->write(request); + + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData(request.size())); + ASSERT_TRUE(fake_upstream_connection->write(response)); + + tcp_client->waitForData(response); + tcp_client->close(); +} + +TEST_P(TcpConnPoolIntegrationTest, MultipleRequests) { + std::string request1("request1"); + std::string request2("request2"); + std::string response1("response1"); + std::string response2("response2"); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + + // send request 1 + tcp_client->write(request1); + FakeRawConnectionPtr fake_upstream_connection1; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection1)); + std::string data; + ASSERT_TRUE(fake_upstream_connection1->waitForData(request1.size(), &data)); + EXPECT_EQ(request1, data); + + // send request 2 + tcp_client->write(request2); + FakeRawConnectionPtr fake_upstream_connection2; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection2)); + ASSERT_TRUE(fake_upstream_connection2->waitForData(request2.size(), &data)); + EXPECT_EQ(request2, data); + + // send response 2 + ASSERT_TRUE(fake_upstream_connection2->write(response2)); + tcp_client->waitForData(response2); + + // send response 1 + ASSERT_TRUE(fake_upstream_connection1->write(response1)); + tcp_client->waitForData(response1, false); + + tcp_client->close(); +} + +} // namespace Envoy diff --git a/test/integration/tcp_proxy_integration_test.cc b/test/integration/tcp_proxy_integration_test.cc index 64779fb89ac20..ae3d1f132ef8f 100644 --- a/test/integration/tcp_proxy_integration_test.cc +++ b/test/integration/tcp_proxy_integration_test.cc @@ -33,19 +33,20 @@ void TcpProxyIntegrationTest::initialize() { TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamWritesFirst) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); - fake_upstream_connection->write("hello"); + ASSERT_TRUE(fake_upstream_connection->write("hello")); tcp_client->waitForData("hello"); tcp_client->write("hello"); - fake_upstream_connection->waitForData(5); + ASSERT_TRUE(fake_upstream_connection->waitForData(5)); - fake_upstream_connection->write("", true); + ASSERT_TRUE(fake_upstream_connection->write("", true)); tcp_client->waitForHalfClose(); tcp_client->write("", true); - fake_upstream_connection->waitForHalfClose(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); } // Test proxying data in both directions, and that all data is flushed properly @@ -54,11 +55,12 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamDisconnect) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); tcp_client->write("hello"); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->waitForData(5); - fake_upstream_connection->write("world"); - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData(5)); + ASSERT_TRUE(fake_upstream_connection->write("world")); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); tcp_client->waitForHalfClose(); tcp_client->close(); @@ -71,15 +73,16 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyDownstreamDisconnect) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); tcp_client->write("hello"); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->waitForData(5); - fake_upstream_connection->write("world"); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData(5)); + ASSERT_TRUE(fake_upstream_connection->write("world")); tcp_client->waitForData("world"); tcp_client->write("hello", true); - fake_upstream_connection->waitForData(10); - fake_upstream_connection->waitForHalfClose(); - fake_upstream_connection->write("", true); - fake_upstream_connection->waitForDisconnect(true); + ASSERT_TRUE(fake_upstream_connection->waitForData(10)); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection->write("", true)); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect(true)); tcp_client->waitForDisconnect(); } @@ -90,14 +93,15 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyLargeWrite) { std::string data(1024 * 16, 'a'); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); tcp_client->write(data); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->waitForData(data.size()); - fake_upstream_connection->write(data); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData(data.size())); + ASSERT_TRUE(fake_upstream_connection->write(data)); tcp_client->waitForData(data); tcp_client->close(); - fake_upstream_connection->waitForHalfClose(); - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); uint32_t upstream_pauses = test_server_->counter("cluster.cluster_0.upstream_flow_control_paused_reading_total") @@ -123,15 +127,16 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyDownstreamFlush) { std::string data(size, 'a'); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); tcp_client->readDisable(true); tcp_client->write("", true); // This ensures that readDisable(true) has been run on it's thread // before tcp_client starts writing. - fake_upstream_connection->waitForHalfClose(); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); - fake_upstream_connection->write(data, true); + ASSERT_TRUE(fake_upstream_connection->write(data, true)); test_server_->waitForCounterGe("cluster.cluster_0.upstream_flow_control_paused_reading_total", 1); EXPECT_EQ(test_server_->counter("cluster.cluster_0.upstream_flow_control_resumed_reading_total") @@ -140,7 +145,7 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyDownstreamFlush) { tcp_client->readDisable(false); tcp_client->waitForData(data); tcp_client->waitForHalfClose(); - fake_upstream_connection->waitForHalfClose(); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); uint32_t upstream_pauses = test_server_->counter("cluster.cluster_0.upstream_flow_control_paused_reading_total") @@ -161,9 +166,10 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamFlush) { std::string data(size, 'a'); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->readDisable(true); - fake_upstream_connection->write("", true); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->readDisable(true)); + ASSERT_TRUE(fake_upstream_connection->write("", true)); // This ensures that fake_upstream_connection->readDisable has been run on it's thread // before tcp_client starts writing. @@ -172,9 +178,9 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamFlush) { tcp_client->write(data, true); test_server_->waitForGaugeEq("tcp.tcp_stats.upstream_flush_active", 1); - fake_upstream_connection->readDisable(false); - fake_upstream_connection->waitForData(data.size()); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->readDisable(false)); + ASSERT_TRUE(fake_upstream_connection->waitForData(data.size())); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); tcp_client->waitForHalfClose(); EXPECT_EQ(test_server_->counter("tcp.tcp_stats.upstream_flush_total")->value(), 1); @@ -190,9 +196,10 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamFlushEnvoyExit) { std::string data(size, 'a'); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->readDisable(true); - fake_upstream_connection->write("", true); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->readDisable(true)); + ASSERT_TRUE(fake_upstream_connection->write("", true)); // This ensures that fake_upstream_connection->readDisable has been run on it's thread // before tcp_client starts writing. @@ -202,8 +209,8 @@ TEST_P(TcpProxyIntegrationTest, TcpProxyUpstreamFlushEnvoyExit) { test_server_->waitForGaugeEq("tcp.tcp_stats.upstream_flush_active", 1); test_server_.reset(); - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); // Success criteria is that no ASSERTs fire and there are no leaks. } @@ -233,16 +240,17 @@ TEST_P(TcpProxyIntegrationTest, AccessLog) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); - fake_upstream_connection->write("hello"); + ASSERT_TRUE(fake_upstream_connection->write("hello")); tcp_client->waitForData("hello"); - fake_upstream_connection->write("", true); + ASSERT_TRUE(fake_upstream_connection->write("", true)); tcp_client->waitForHalfClose(); tcp_client->write("", true); - fake_upstream_connection->waitForHalfClose(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); std::string log_result; // Access logs only get flushed to disk periodically, so poll until the log is non-empty @@ -277,16 +285,17 @@ TEST_P(TcpProxyIntegrationTest, ShutdownWithOpenConnections) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); tcp_client->write("hello"); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - fake_upstream_connection->waitForData(5); - fake_upstream_connection->write("world"); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData(5)); + ASSERT_TRUE(fake_upstream_connection->write("world")); tcp_client->waitForData("world"); tcp_client->write("hello", false); - fake_upstream_connection->waitForData(10); + ASSERT_TRUE(fake_upstream_connection->waitForData(10)); test_server_.reset(); - fake_upstream_connection->waitForHalfClose(); - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(true); + ASSERT_TRUE(fake_upstream_connection->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect(true)); tcp_client->waitForHalfClose(); tcp_client->close(); @@ -333,14 +342,15 @@ TEST_P(TcpProxyIntegrationTest, TestIdletimeoutWithLargeOutstandingData) { initialize(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("tcp_proxy")); - FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); std::string data(1024 * 16, 'a'); tcp_client->write(data); - fake_upstream_connection->write(data); + ASSERT_TRUE(fake_upstream_connection->write(data)); tcp_client->waitForDisconnect(true); - fake_upstream_connection->waitForDisconnect(true); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect(true)); } void TcpProxySslIntegrationTest::initialize() { @@ -387,7 +397,8 @@ void TcpProxySslIntegrationTest::setupConnections() { dispatcher_->run(Event::Dispatcher::RunType::NonBlock); } - fake_upstream_connection_ = fake_upstreams_[0]->waitForRawConnection(); + AssertionResult result = fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection_); + RELEASE_ASSERT(result, result.message()); } // Test proxying data in both directions with envoy doing TCP and TLS @@ -402,10 +413,10 @@ void TcpProxySslIntegrationTest::sendAndReceiveTlsData(const std::string& data_t } // Make sure the data makes it upstream. - fake_upstream_connection_->waitForData(data_to_send_upstream.size()); + ASSERT_TRUE(fake_upstream_connection_->waitForData(data_to_send_upstream.size())); // Now send data downstream and make sure it arrives. - fake_upstream_connection_->write(data_to_send_downstream); + ASSERT_TRUE(fake_upstream_connection_->write(data_to_send_downstream)); payload_reader_->set_data_to_wait_for(data_to_send_downstream); ssl_client_->dispatcher().run(Event::Dispatcher::RunType::Block); @@ -413,9 +424,9 @@ void TcpProxySslIntegrationTest::sendAndReceiveTlsData(const std::string& data_t Buffer::OwnedImpl empty_buffer; ssl_client_->write(empty_buffer, true); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); - fake_upstream_connection_->waitForHalfClose(); - fake_upstream_connection_->write("", true); - fake_upstream_connection_->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection_->waitForHalfClose()); + ASSERT_TRUE(fake_upstream_connection_->write("", true)); + ASSERT_TRUE(fake_upstream_connection_->waitForDisconnect()); ssl_client_->dispatcher().run(Event::Dispatcher::RunType::Block); EXPECT_TRUE(payload_reader_->readLastByte()); EXPECT_TRUE(connect_callbacks_.closed()); @@ -438,15 +449,15 @@ TEST_P(TcpProxySslIntegrationTest, DownstreamHalfClose) { Buffer::OwnedImpl empty_buffer; ssl_client_->write(empty_buffer, true); - fake_upstream_connection_->waitForHalfClose(); + ASSERT_TRUE(fake_upstream_connection_->waitForHalfClose()); const std::string data("data"); - fake_upstream_connection_->write(data, false); + ASSERT_TRUE(fake_upstream_connection_->write(data, false)); payload_reader_->set_data_to_wait_for(data); ssl_client_->dispatcher().run(Event::Dispatcher::RunType::Block); EXPECT_FALSE(payload_reader_->readLastByte()); - fake_upstream_connection_->write("", true); + ASSERT_TRUE(fake_upstream_connection_->write("", true)); ssl_client_->dispatcher().run(Event::Dispatcher::RunType::Block); EXPECT_TRUE(payload_reader_->readLastByte()); } @@ -455,7 +466,7 @@ TEST_P(TcpProxySslIntegrationTest, DownstreamHalfClose) { TEST_P(TcpProxySslIntegrationTest, UpstreamHalfClose) { setupConnections(); - fake_upstream_connection_->write("", true); + ASSERT_TRUE(fake_upstream_connection_->write("", true)); ssl_client_->dispatcher().run(Event::Dispatcher::RunType::Block); EXPECT_TRUE(payload_reader_->readLastByte()); EXPECT_FALSE(connect_callbacks_.closed()); @@ -466,14 +477,14 @@ TEST_P(TcpProxySslIntegrationTest, UpstreamHalfClose) { while (client_write_buffer_->bytes_drained() != val.size()) { dispatcher_->run(Event::Dispatcher::RunType::NonBlock); } - fake_upstream_connection_->waitForData(val.size()); + ASSERT_TRUE(fake_upstream_connection_->waitForData(val.size())); Buffer::OwnedImpl empty_buffer; ssl_client_->write(empty_buffer, true); while (!connect_callbacks_.closed()) { dispatcher_->run(Event::Dispatcher::RunType::NonBlock); } - fake_upstream_connection_->waitForHalfClose(); + ASSERT_TRUE(fake_upstream_connection_->waitForHalfClose()); } } // namespace diff --git a/test/integration/uds_integration_test.cc b/test/integration/uds_integration_test.cc index fcc20a6178070..7519b5bc1a175 100644 --- a/test/integration/uds_integration_test.cc +++ b/test/integration/uds_integration_test.cc @@ -54,7 +54,7 @@ void UdsListenerIntegrationTest::initialize() { admin_addr->mutable_pipe()->set_path(getAdminSocketName()); auto* listeners = bootstrap.mutable_static_resources()->mutable_listeners(); - RELEASE_ASSERT(listeners->size() > 0); + RELEASE_ASSERT(listeners->size() > 0, ""); auto filter_chains = listeners->Get(0).filter_chains(); listeners->Clear(); auto* listener = listeners->Add(); diff --git a/test/integration/uds_integration_test.h b/test/integration/uds_integration_test.h index 2905ce29e3654..478f2127e3172 100644 --- a/test/integration/uds_integration_test.h +++ b/test/integration/uds_integration_test.h @@ -14,6 +14,7 @@ #include "gtest/gtest.h" namespace Envoy { + class UdsUpstreamIntegrationTest : public HttpIntegrationTest, public testing::TestWithParam> { @@ -23,8 +24,9 @@ class UdsUpstreamIntegrationTest abstract_namespace_(std::get<1>(GetParam())) {} void createUpstreams() override { - fake_upstreams_.emplace_back( - new FakeUpstream(getSocketName(), FakeHttpConnection::Type::HTTP1)); + fake_upstreams_.emplace_back(new FakeUpstream( + TestEnvironment::unixDomainSocketPath("udstest.1.sock", abstract_namespace_), + FakeHttpConnection::Type::HTTP1)); config_helper_.addConfigModifier( [&](envoy::config::bootstrap::v2::Bootstrap& bootstrap) -> void { @@ -33,17 +35,13 @@ class UdsUpstreamIntegrationTest auto* cluster = static_resources->mutable_clusters(i); for (int j = 0; j < cluster->hosts_size(); ++j) { cluster->mutable_hosts(j)->clear_socket_address(); - cluster->mutable_hosts(j)->mutable_pipe()->set_path(getSocketName()); + cluster->mutable_hosts(j)->mutable_pipe()->set_path( + TestEnvironment::unixDomainSocketPath("udstest.1.sock", abstract_namespace_)); } } }); } - std::string getSocketName() { - return abstract_namespace_ ? "@/my/udstest" - : TestEnvironment::unixDomainSocketPath("udstest.1.sock"); - } - protected: const bool abstract_namespace_; }; @@ -58,17 +56,13 @@ class UdsListenerIntegrationTest void initialize() override; - std::string getSocketName(const std::string& path) { - const std::string name = TestEnvironment::unixDomainSocketPath(path); - if (!abstract_namespace_) { - return name; - } - return "@" + name; + std::string getAdminSocketName() { + return TestEnvironment::unixDomainSocketPath("admin.sock", abstract_namespace_); } - std::string getAdminSocketName() { return getSocketName("admin.sock"); } - - std::string getListenerSocketName() { return getSocketName("listener_0.sock"); } + std::string getListenerSocketName() { + return TestEnvironment::unixDomainSocketPath("listener_0.sock", abstract_namespace_); + } protected: HttpIntegrationTest::ConnectionCreationFunction createConnectionFn(); diff --git a/test/integration/utility.cc b/test/integration/utility.cc index 0271c2af6b450..ffae1f6701516 100644 --- a/test/integration/utility.cc +++ b/test/integration/utility.cc @@ -36,13 +36,15 @@ void BufferingStreamDecoder::decodeHeaders(Http::HeaderMapPtr&& headers, bool en void BufferingStreamDecoder::decodeData(Buffer::Instance& data, bool end_stream) { ASSERT(!complete_); complete_ = end_stream; - body_.append(TestUtility::bufferToString(data)); + body_.append(data.toString()); if (complete_) { onComplete(); } } -void BufferingStreamDecoder::decodeTrailers(Http::HeaderMapPtr&&) { NOT_IMPLEMENTED; } +void BufferingStreamDecoder::decodeTrailers(Http::HeaderMapPtr&&) { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; +} void BufferingStreamDecoder::onComplete() { ASSERT(complete_); @@ -125,7 +127,7 @@ WaitForPayloadReader::WaitForPayloadReader(Event::Dispatcher& dispatcher) : dispatcher_(dispatcher) {} Network::FilterStatus WaitForPayloadReader::onData(Buffer::Instance& data, bool end_stream) { - data_.append(TestUtility::bufferToString(data)); + data_.append(data.toString()); data.drain(data.length()); read_end_stream_ = end_stream; if ((!data_to_wait_for_.empty() && data_.find(data_to_wait_for_) == 0) || diff --git a/test/integration/websocket_integration_test.cc b/test/integration/websocket_integration_test.cc index c06440d66e141..c3b4289a641fb 100644 --- a/test/integration/websocket_integration_test.cc +++ b/test/integration/websocket_integration_test.cc @@ -22,14 +22,27 @@ namespace { bool headersRead(const std::string& data) { return data.find("\r\n\r\n") != std::string::npos; } +static std::string websocketTestParamsToString( + const testing::TestParamInfo> params) { + return absl::StrCat(std::get<0>(params.param) == Network::Address::IpVersion::v4 ? "IPv4" + : "IPv6", + "_", std::get<1>(params.param) == true ? "OldStyle" : "NewStyle"); +} + } // namespace INSTANTIATE_TEST_CASE_P(IpVersions, WebsocketIntegrationTest, - testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), - TestUtility::ipTestParamsToString); + testing::Combine(testing::ValuesIn(TestEnvironment::getIpVersionsForTest()), + testing::Bool()), + websocketTestParamsToString); ConfigHelper::HttpModifierFunction -setRouteUsingWebsocket(const envoy::api::v2::route::RouteAction::WebSocketProxyConfig* ws_config) { +setRouteUsingWebsocket(const envoy::api::v2::route::RouteAction::WebSocketProxyConfig* ws_config, + bool old_style) { + if (!old_style) { + return [](envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& + hcm) { hcm.add_upgrade_configs()->set_upgrade_type("websocket"); }; + } return [ws_config]( envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) { @@ -46,15 +59,19 @@ setRouteUsingWebsocket(const envoy::api::v2::route::RouteAction::WebSocketProxyC } void WebsocketIntegrationTest::initialize() { - // Set a less permissive default route so it does not pick up the /websocket query. - config_helper_.setDefaultHostAndRoute("*", "/asd"); + if (old_style_websockets_) { + // Set a less permissive default route so it does not pick up the /websocket query. + config_helper_.setDefaultHostAndRoute("*", "/asd"); + } HttpIntegrationTest::initialize(); } void WebsocketIntegrationTest::validateInitialUpstreamData(const std::string& received_data) { - // The request path gets rewritten from /websocket/test to /websocket. - // The size of headers received by the destination is 228 bytes. - EXPECT_EQ(received_data.size(), 228); + if (old_style_websockets_) { + // The request path gets rewritten from /websocket/test to /websocket. + // The size of headers received by the destination is 228 bytes. + EXPECT_EQ(received_data.size(), 228); + } // In HTTP1, the transfer-length is defined by use of the "chunked" transfer-coding, even if // content-length header is present. No body websocket upgrade request send to upstream has // content-length header and has no transfer-encoding header. @@ -62,17 +79,39 @@ void WebsocketIntegrationTest::validateInitialUpstreamData(const std::string& re EXPECT_EQ(received_data.find("transfer-encoding:"), std::string::npos); } -void WebsocketIntegrationTest::validateInitialDownstreamData(const std::string& received_data) { - ASSERT_EQ(received_data, upgrade_resp_str_); +void WebsocketIntegrationTest::validateInitialDownstreamData(const std::string& received_data, + const std::string& expected_data) { + if (old_style_websockets_) { + ASSERT_EQ(expected_data, received_data); + } else { + // Strip out the date header since we're not going to generate an exact match. + std::regex extra_request_headers("date:.*\r\nserver: envoy\r\n"); + std::string stripped_data = std::regex_replace(received_data, extra_request_headers, ""); + EXPECT_EQ(expected_data, stripped_data); + } } void WebsocketIntegrationTest::validateFinalDownstreamData(const std::string& received_data, const std::string& expected_data) { - EXPECT_EQ(received_data, expected_data); + if (old_style_websockets_) { + EXPECT_EQ(received_data, expected_data); + } else { + // Strip out the date header since we're not going to generate an exact match. + std::regex extra_request_headers("date:.*\r\nserver: envoy\r\n"); + std::string stripped_data = std::regex_replace(received_data, extra_request_headers, ""); + EXPECT_EQ(expected_data, stripped_data); + } +} + +void WebsocketIntegrationTest::validateFinalUpstreamData(const std::string& received_data, + const std::string& expected_data) { + std::regex extra_response_headers("x-request-id:.*\r\n"); + std::string stripped_data = std::regex_replace(received_data, extra_response_headers, ""); + EXPECT_EQ(expected_data, stripped_data); } TEST_P(WebsocketIntegrationTest, WebSocketConnectionDownstreamDisconnect) { - config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr)); + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); initialize(); // WebSocket upgrade, send some data and disconnect downstream @@ -82,34 +121,54 @@ TEST_P(WebsocketIntegrationTest, WebSocketConnectionDownstreamDisconnect) { tcp_client = makeTcpConnection(lookupPort("http")); // Send websocket upgrade request tcp_client->write(upgrade_req_str_); - test_server_->waitForCounterGe("tcp.websocket.downstream_cx_total", 1); - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - const std::string data = fake_upstream_connection->waitForData(&headersRead); + if (old_style_websockets_) { + test_server_->waitForCounterGe("tcp.websocket.downstream_cx_total", 1); + } + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(&headersRead, &data)); validateInitialUpstreamData(data); // Accept websocket upgrade request - fake_upstream_connection->write(upgrade_resp_str_); + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str_)); tcp_client->waitForData("\r\n\r\n", false); - validateInitialDownstreamData(tcp_client->data()); + validateInitialDownstreamData(tcp_client->data(), downstreamRespStr()); // Standard TCP proxy semantics post upgrade tcp_client->write("hello"); - fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("hello")); - fake_upstream_connection->write("world"); + ASSERT_TRUE( + fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("hello"))); + ASSERT_TRUE(fake_upstream_connection->write("world")); tcp_client->waitForData("world", false); tcp_client->write("bye!"); // downstream disconnect tcp_client->close(); - fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("bye")); - fake_upstream_connection->waitForDisconnect(); - - validateFinalDownstreamData(tcp_client->data(), upgrade_resp_str_ + "world"); + std::string final_data; + ASSERT_TRUE(fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("bye"), + &final_data)); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); + + validateFinalDownstreamData(tcp_client->data(), downstreamRespStr() + "world"); + + if (old_style_websockets_) { + return; + } + + const std::string upstream_payload = "GET /websocket/test HTTP/1.1\r\n" + "host: host\r\n" + "connection: keep-alive, Upgrade\r\n" + "upgrade: websocket\r\n" + "content-length: 0\r\n" + "x-forwarded-proto: http\r\n" + "x-envoy-expected-rq-timeout-ms: 15000\r\n\r\n" + "hellobye!"; + validateFinalUpstreamData(final_data, upstream_payload); } TEST_P(WebsocketIntegrationTest, WebSocketConnectionUpstreamDisconnect) { - config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr)); + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); initialize(); // WebSocket upgrade, send some data and disconnect upstream @@ -118,34 +177,35 @@ TEST_P(WebsocketIntegrationTest, WebSocketConnectionUpstreamDisconnect) { tcp_client = makeTcpConnection(lookupPort("http")); // Send websocket upgrade request tcp_client->write(upgrade_req_str_); - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - ASSERT_TRUE(fake_upstream_connection != nullptr); - const std::string data = fake_upstream_connection->waitForData(&headersRead); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(&headersRead, &data)); validateInitialUpstreamData(data); // Accept websocket upgrade request - fake_upstream_connection->write(upgrade_resp_str_); + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str_)); tcp_client->waitForData("\r\n\r\n", false); - validateInitialDownstreamData(tcp_client->data()); + validateInitialDownstreamData(tcp_client->data(), downstreamRespStr()); // Standard TCP proxy semantics post upgrade tcp_client->write("hello"); - fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("hello")); + ASSERT_TRUE( + fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("hello"))); - fake_upstream_connection->write("world"); + ASSERT_TRUE(fake_upstream_connection->write("world")); // upstream disconnect - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); tcp_client->waitForData("world", false); tcp_client->waitForDisconnect(); ASSERT(!fake_upstream_connection->connected()); - validateFinalDownstreamData(tcp_client->data(), upgrade_resp_str_ + "world"); + validateFinalDownstreamData(tcp_client->data(), downstreamRespStr() + "world"); } TEST_P(WebsocketIntegrationTest, EarlyData) { - config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr)); + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); initialize(); // WebSocket upgrade with early data (HTTP body) @@ -160,24 +220,25 @@ TEST_P(WebsocketIntegrationTest, EarlyData) { tcp_client = makeTcpConnection(lookupPort("http")); // Send early data alongside websocket upgrade request tcp_client->write(upgrade_req_str + early_data_req_str); - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); // Wait for both the upgrade, and the early data. - const std::string data = fake_upstream_connection->waitForData( - FakeRawConnection::waitForInexactMatch(early_data_req_str.c_str())); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData( + FakeRawConnection::waitForInexactMatch(early_data_req_str.c_str()), &data)); // We expect to find the early data on the upstream side EXPECT_TRUE(StringUtil::endsWith(data, early_data_req_str)); // Accept websocket upgrade request - fake_upstream_connection->write(upgrade_resp_str_); + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str_)); // Reply also with early data - fake_upstream_connection->write(early_data_resp_str); + ASSERT_TRUE(fake_upstream_connection->write(early_data_resp_str)); // upstream disconnect - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); tcp_client->waitForData(early_data_resp_str, false); tcp_client->waitForDisconnect(); - validateFinalDownstreamData(tcp_client->data(), upgrade_resp_str_ + "world"); + validateFinalDownstreamData(tcp_client->data(), downstreamRespStr() + "world"); } TEST_P(WebsocketIntegrationTest, WebSocketConnectionIdleTimeout) { @@ -185,7 +246,18 @@ TEST_P(WebsocketIntegrationTest, WebSocketConnectionIdleTimeout) { ws_config.mutable_idle_timeout()->set_nanos( std::chrono::duration_cast(std::chrono::milliseconds(100)).count()); *ws_config.mutable_stat_prefix() = "my-stat-prefix"; - config_helper_.addConfigModifier(setRouteUsingWebsocket(&ws_config)); + config_helper_.addConfigModifier(setRouteUsingWebsocket(&ws_config, old_style_websockets_)); + if (!old_style_websockets_) { + config_helper_.addConfigModifier( + [&](envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) + -> void { + auto* route_config = hcm.mutable_route_config(); + auto* virtual_host = route_config->mutable_virtual_hosts(0); + auto* route = virtual_host->mutable_routes(0)->mutable_route(); + route->mutable_idle_timeout()->set_seconds(0); + route->mutable_idle_timeout()->set_nanos(200 * 1000 * 1000); + }); + } initialize(); // WebSocket upgrade, send some data and disconnect downstream @@ -196,33 +268,39 @@ TEST_P(WebsocketIntegrationTest, WebSocketConnectionIdleTimeout) { // The request path gets rewritten from /websocket/test to /websocket. // The size of headers received by the destination is 228 bytes. tcp_client->write(upgrade_req_str_); - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - const std::string data = fake_upstream_connection->waitForData(&headersRead); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(&headersRead, &data)); validateInitialUpstreamData(data); // Accept websocket upgrade request - fake_upstream_connection->write(upgrade_resp_str_); + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str_)); tcp_client->waitForData("\r\n\r\n", false); - validateInitialDownstreamData(tcp_client->data()); + validateInitialDownstreamData(tcp_client->data(), downstreamRespStr()); // Standard TCP proxy semantics post upgrade tcp_client->write("hello"); tcp_client->write("hello"); - fake_upstream_connection->write("world"); + ASSERT_TRUE(fake_upstream_connection->write("world")); tcp_client->waitForData("world", false); - test_server_->waitForCounterGe("tcp.my-stat-prefix.idle_timeout", 1); + if (old_style_websockets_) { + test_server_->waitForCounterGe("tcp.my-stat-prefix.idle_timeout", 1); + } else { + test_server_->waitForCounterGe("http.config_test.downstream_rq_idle_timeout", 1); + } tcp_client->waitForDisconnect(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); } TEST_P(WebsocketIntegrationTest, WebSocketLogging) { + if (!old_style_websockets_) + return; envoy::api::v2::route::RouteAction::WebSocketProxyConfig ws_config; ws_config.mutable_idle_timeout()->set_nanos( std::chrono::duration_cast(std::chrono::milliseconds(100)).count()); *ws_config.mutable_stat_prefix() = "my-stat-prefix"; - config_helper_.addConfigModifier(setRouteUsingWebsocket(&ws_config)); - + config_helper_.addConfigModifier(setRouteUsingWebsocket(&ws_config, old_style_websockets_)); std::string expected_log_template = "bytes_sent={0} " "bytes_received={1} " "downstream_local_address={2} " @@ -263,20 +341,21 @@ TEST_P(WebsocketIntegrationTest, WebSocketLogging) { // The request path gets rewritten from /websocket/test to /websocket. // The size of headers received by the destination is 228 bytes. tcp_client->write(upgrade_req_str_); - fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); - const std::string data = fake_upstream_connection->waitForData(228); + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(228, &data)); // Accept websocket upgrade request - fake_upstream_connection->write(upgrade_resp_str_); + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str_)); tcp_client->waitForData(upgrade_resp_str_); // Standard TCP proxy semantics post upgrade tcp_client->write("hello"); // datalen = 228 + strlen(hello) - fake_upstream_connection->waitForData(233); - fake_upstream_connection->write("world"); + ASSERT_TRUE(fake_upstream_connection->waitForData(233)); + ASSERT_TRUE(fake_upstream_connection->write("world")); tcp_client->waitForData(upgrade_resp_str_ + "world"); - fake_upstream_connection->close(); - fake_upstream_connection->waitForDisconnect(); + ASSERT_TRUE(fake_upstream_connection->close()); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); tcp_client->waitForDisconnect(); tcp_client->close(); @@ -296,4 +375,217 @@ TEST_P(WebsocketIntegrationTest, WebSocketLogging) { ip_port_regex, ip_port_regex, ip_port_regex))); } +// Technically not a websocket tests, but verfies normal upgrades have parity +// with websocket upgrades +TEST_P(WebsocketIntegrationTest, NonWebsocketUpgrade) { + if (old_style_websockets_) { + return; + } + config_helper_.addConfigModifier( + [&](envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) + -> void { + auto* foo_upgrade = hcm.add_upgrade_configs(); + foo_upgrade->set_upgrade_type("foo"); + }); + + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); + initialize(); + + const std::string upgrade_req_str = "GET / HTTP/1.1\r\nHost: host\r\nConnection: " + "keep-alive, Upgrade\r\nUpgrade: foo\r\n\r\n"; + const std::string upgrade_resp_str = + "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: foo\r\n\r\n"; + + // Upgrade, send some data and disconnect downstream + IntegrationTcpClientPtr tcp_client; + FakeRawConnectionPtr fake_upstream_connection; + + tcp_client = makeTcpConnection(lookupPort("http")); + // Send websocket upgrade request + // The size of headers received by the destination is 228 bytes. + tcp_client->write(upgrade_req_str); + if (old_style_websockets_) { + test_server_->waitForCounterGe("tcp.websocket.downstream_cx_total", 1); + } + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + std::string data; + ASSERT_TRUE(fake_upstream_connection->waitForData(&headersRead, &data)); + validateInitialUpstreamData(data); + + // Accept websocket upgrade request + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str)); + tcp_client->waitForData("\r\n\r\n", false); + if (old_style_websockets_) { + ASSERT_EQ(tcp_client->data(), upgrade_resp_str); + } + // Standard TCP proxy semantics post upgrade + tcp_client->write("hello"); + + ASSERT_TRUE( + fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("hello"))); + ASSERT_TRUE(fake_upstream_connection->write("world")); + tcp_client->waitForData("world", false); + tcp_client->write("bye!"); + + // downstream disconnect + tcp_client->close(); + std::string final_data; + ASSERT_TRUE(fake_upstream_connection->waitForData(FakeRawConnection::waitForInexactMatch("bye"), + &final_data)); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); + + const std::string modified_upgrade_resp_str = "HTTP/1.1 101 Switching Protocols\r\nconnection: " + "Upgrade\r\nupgrade: foo\r\ncontent-length: " + "0\r\n\r\n"; + validateFinalDownstreamData(tcp_client->data(), modified_upgrade_resp_str + "world"); + const std::string upstream_payload = "GET / HTTP/1.1\r\n" + "host: host\r\n" + "connection: keep-alive, Upgrade\r\n" + "upgrade: foo\r\n" + "content-length: 0\r\n" + "x-forwarded-proto: http\r\n" + "x-envoy-expected-rq-timeout-ms: 15000\r\n\r\n" + "hellobye!"; + + std::regex extra_response_headers("x-request-id:.*\r\n"); + std::string stripped_data = std::regex_replace(final_data, extra_response_headers, ""); + EXPECT_EQ(upstream_payload, stripped_data); +} + +TEST_P(WebsocketIntegrationTest, WebsocketCustomFilterChain) { + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); + if (old_style_websockets_) { + return; + } + + // Add a small buffer filter to the standard HTTP filter chain. Websocket + // upgrades will use the HTTP filter chain so will also have small buffers. + config_helper_.addFilter(ConfigHelper::SMALL_BUFFER_FILTER); + + // Add a second upgrade type which goes directly to the router filter. + config_helper_.addConfigModifier( + [&](envoy::config::filter::network::http_connection_manager::v2::HttpConnectionManager& hcm) + -> void { + auto* foo_upgrade = hcm.add_upgrade_configs(); + foo_upgrade->set_upgrade_type("foo"); + auto* filter_list_back = foo_upgrade->add_filters(); + const std::string json = + Json::Factory::loadFromYamlString("name: envoy.router")->asJsonString(); + MessageUtil::loadFromJson(json, *filter_list_back); + }); + initialize(); + + // Websocket upgrades are configured to disallow large payload. + const std::string early_data_req_str(2048, 'a'); + { + const std::string upgrade_req_str = + fmt::format("GET /websocket/test HTTP/1.1\r\nHost: host\r\nConnection: " + "keep-alive, Upgrade\r\nUpgrade: websocket\r\nContent-Length: {}\r\n\r\n", + early_data_req_str.length()); + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("http")); + tcp_client->write(upgrade_req_str + early_data_req_str); + tcp_client->waitForData("\r\n\r\n", false); + EXPECT_NE(tcp_client->data().find("413"), std::string::npos); + tcp_client->waitForDisconnect(true); + } + + // HTTP requests are configured to disallow large bodies. + { + const std::string upgrade_req_str = fmt::format("GET / HTTP/1.1\r\nHost: host\r\nConnection: " + "keep-alive\r\nContent-Length: {}\r\n\r\n", + early_data_req_str.length()); + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("http")); + tcp_client->write(upgrade_req_str + early_data_req_str); + tcp_client->waitForData("\r\n\r\n", false); + EXPECT_NE(tcp_client->data().find("413"), std::string::npos); + tcp_client->waitForDisconnect(true); + } + + // Foo upgrades are configured without the buffer filter, so should explicitly + // allow large payload. + { + const std::string upgrade_req_str = + fmt::format("GET /websocket/test HTTP/1.1\r\nHost: host\r\nConnection: " + "keep-alive, Upgrade\r\nUpgrade: foo\r\nContent-Length: {}\r\n\r\n", + early_data_req_str.length()); + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("http")); + tcp_client->write(upgrade_req_str + early_data_req_str); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + // Make sure the full payload arrives. + ASSERT_TRUE(fake_upstream_connection->waitForData( + FakeRawConnection::waitForInexactMatch(early_data_req_str.c_str()))); + // Tear down all the connections cleanly. + tcp_client->close(); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); + } +} + +TEST_P(WebsocketIntegrationTest, BidirectionalChunkedData) { + config_helper_.addConfigModifier(setRouteUsingWebsocket(nullptr, old_style_websockets_)); + initialize(); + const std::string upgrade_req_str = "GET /websocket/test HTTP/1.1\r\nHost: host\r\nconnection: " + "keep-alive, Upgrade\r\nupgrade: Websocket\r\n" + "transfer-encoding: chunked\r\n\r\n" + "3\r\n123\r\n0\r\n\r\n" + "SomeWebSocketPayload"; + + // Upgrade, send initial data and wait for it to be received. + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("http")); + tcp_client->write(upgrade_req_str); + FakeRawConnectionPtr fake_upstream_connection; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection)); + ASSERT_TRUE(fake_upstream_connection->waitForData( + FakeRawConnection::waitForInexactMatch("SomeWebSocketPayload"))); + + // Finish the upgrade. + const std::string upgrade_resp_str = + "HTTP/1.1 101 Switching Protocols\r\nconnection: Upgrade\r\nupgrade: Websocket\r\n" + "transfer-encoding: chunked\r\n\r\n" + "4\r\nabcd\r\n0\r\n\r\n" + "SomeWebsocketResponsePayload"; + ASSERT_TRUE(fake_upstream_connection->write(upgrade_resp_str)); + tcp_client->waitForData("SomeWebsocketResponsePayload", false); + + // Verify bidirectional data still works. + tcp_client->write("FinalClientPayload"); + std::string final_data; + ASSERT_TRUE(fake_upstream_connection->waitForData( + FakeRawConnection::waitForInexactMatch("FinalClientPayload"), &final_data)); + ASSERT_TRUE(fake_upstream_connection->write("FinalServerPayload")); + tcp_client->waitForData("FinalServerPayload", false); + + // Clean up. + tcp_client->close(); + ASSERT_TRUE(fake_upstream_connection->waitForDisconnect()); + + const std::string modified_upstream_payload = + "GET /websocket/test HTTP/1.1\r\n" + "host: host\r\n" + "connection: keep-alive, Upgrade\r\n" + "upgrade: Websocket\r\n" + "x-forwarded-proto: http\r\n" + "x-envoy-expected-rq-timeout-ms: 15000\r\n" + "transfer-encoding: chunked\r\n\r\n" + "3\r\n123\r\n0\r\n\r\nSomeWebSocketPayloadFinalClientPayload"; + const std::string old_style_modified_payload = + "GET /websocket HTTP/1.1\r\n" + "host: host\r\n" + "connection: keep-alive, Upgrade\r\n" + "upgrade: Websocket\r\n" + "x-forwarded-proto: http\r\n" + "x-envoy-original-path: /websocket/test\r\n" + "transfer-encoding: chunked\r\n\r\n" + "3\r\n123\r\n0\r\n\r\nSomeWebSocketPayloadFinalClientPayload"; + validateFinalUpstreamData(final_data, old_style_websockets_ ? old_style_modified_payload + : modified_upstream_payload); + + const std::string modified_downstream_payload = + "HTTP/1.1 101 Switching Protocols\r\nconnection: Upgrade\r\nupgrade: Websocket\r\n" + "transfer-encoding: chunked\r\n\r\n" + "4\r\nabcd\r\n0\r\n\r\n" + "SomeWebsocketResponsePayloadFinalServerPayload"; + validateFinalDownstreamData(tcp_client->data(), modified_downstream_payload); +} + } // namespace Envoy diff --git a/test/integration/websocket_integration_test.h b/test/integration/websocket_integration_test.h index d3ac3cb94f75d..3dd7a34b5039d 100644 --- a/test/integration/websocket_integration_test.h +++ b/test/integration/websocket_integration_test.h @@ -6,22 +6,36 @@ namespace Envoy { -class WebsocketIntegrationTest : public HttpIntegrationTest, - public testing::TestWithParam { +class WebsocketIntegrationTest + : public HttpIntegrationTest, + public testing::TestWithParam> { public: void initialize() override; - WebsocketIntegrationTest() : HttpIntegrationTest(Http::CodecClient::Type::HTTP1, GetParam()) {} + WebsocketIntegrationTest() + : HttpIntegrationTest(Http::CodecClient::Type::HTTP1, std::get<0>(GetParam())) {} + bool old_style_websockets_{std::get<1>(GetParam())}; protected: void validateInitialUpstreamData(const std::string& received_data); - void validateInitialDownstreamData(const std::string& received_data); + void validateInitialDownstreamData(const std::string& received_data, + const std::string& expected_data); void validateFinalDownstreamData(const std::string& received_data, const std::string& expected_data); + void validateFinalUpstreamData(const std::string& received_data, + const std::string& expected_data); + + const std::string& downstreamRespStr() { + return old_style_websockets_ ? upgrade_resp_str_ : modified_upgrade_resp_str_; + } const std::string upgrade_req_str_ = "GET /websocket/test HTTP/1.1\r\nHost: host\r\nConnection: " "keep-alive, Upgrade\r\nUpgrade: websocket\r\n\r\n"; const std::string upgrade_resp_str_ = "HTTP/1.1 101 Switching Protocols\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n"; + + const std::string modified_upgrade_resp_str_ = "HTTP/1.1 101 Switching Protocols\r\nconnection: " + "Upgrade\r\nupgrade: websocket\r\ncontent-length: " + "0\r\n\r\n"; }; } // namespace Envoy diff --git a/test/integration/xfcc_integration_test.cc b/test/integration/xfcc_integration_test.cc index f0d4f70eda79d..a2d55546a25e1 100644 --- a/test/integration/xfcc_integration_test.cc +++ b/test/integration/xfcc_integration_test.cc @@ -31,6 +31,8 @@ void XfccIntegrationTest::TearDown() { client_tls_ssl_ctx_.reset(); fake_upstream_connection_.reset(); fake_upstreams_.clear(); + HttpIntegrationTest::cleanupUpstreamAndDownstream(); + codec_client_.reset(); context_manager_.reset(); runtime_.reset(); } @@ -147,9 +149,9 @@ void XfccIntegrationTest::testRequestAndResponseWithXfccHeader(std::string previ codec_client_ = makeHttpConnection(std::move(conn)); auto response = codec_client_->makeHeaderOnlyRequest(header_map); - fake_upstream_connection_ = fake_upstreams_[0]->waitForHttpConnection(*dispatcher_); - upstream_request_ = fake_upstream_connection_->waitForNewStream(*dispatcher_); - upstream_request_->waitForEndStream(*dispatcher_); + ASSERT_TRUE(fake_upstreams_[0]->waitForHttpConnection(*dispatcher_, fake_upstream_connection_)); + ASSERT_TRUE(fake_upstream_connection_->waitForNewStream(*dispatcher_, upstream_request_)); + ASSERT_TRUE(upstream_request_->waitForEndStream(*dispatcher_)); if (expected_xfcc.empty()) { EXPECT_EQ(nullptr, upstream_request_->headers().ForwardedClientCert()); } else { diff --git a/test/main.cc b/test/main.cc index cc00c32f7332b..34d2254070e8d 100644 --- a/test/main.cc +++ b/test/main.cc @@ -2,6 +2,8 @@ #include "test/test_common/environment.h" #include "test/test_runner.h" +#include "absl/debugging/symbolize.h" + #ifdef ENVOY_HANDLE_SIGNALS #include "exe/signal_action.h" #endif @@ -14,6 +16,9 @@ const char* __asan_default_options() { // The main entry point (and the rest of this file) should have no logic in it, // this allows overriding by site specific versions of main.cc. int main(int argc, char** argv) { +#ifndef __APPLE__ + absl::InitializeSymbolizer(argv[0]); +#endif #ifdef ENVOY_HANDLE_SIGNALS // Enabled by default. Control with "bazel --define=signal_trace=disabled" Envoy::SignalAction handle_sigs; diff --git a/test/mocks/api/mocks.h b/test/mocks/api/mocks.h index 37f76a896f4f3..78bef30632e59 100644 --- a/test/mocks/api/mocks.h +++ b/test/mocks/api/mocks.h @@ -48,6 +48,7 @@ class MockOsSysCalls : public OsSysCallsImpl { int getsockopt(int sockfd, int level, int optname, void* optval, socklen_t* optlen) override; MOCK_METHOD3(bind, int(int sockfd, const sockaddr* addr, socklen_t addrlen)); + MOCK_METHOD3(ioctl, int(int sockfd, unsigned long int request, void* argp)); MOCK_METHOD1(close, int(int)); MOCK_METHOD3(open_, int(const std::string& full_path, int flags, int mode)); MOCK_METHOD3(write_, ssize_t(int, const void*, size_t)); @@ -64,6 +65,7 @@ class MockOsSysCalls : public OsSysCallsImpl { int(int sockfd, int level, int optname, const void* optval, socklen_t optlen)); MOCK_METHOD5(getsockopt_, int(int sockfd, int level, int optname, void* optval, socklen_t* optlen)); + MOCK_METHOD3(socket, int(int domain, int type, int protocol)); size_t num_writes_; size_t num_open_; diff --git a/test/mocks/buffer/mocks.h b/test/mocks/buffer/mocks.h index a7c50c6976cd3..59cdae2902ec9 100644 --- a/test/mocks/buffer/mocks.h +++ b/test/mocks/buffer/mocks.h @@ -16,7 +16,7 @@ template class MockBufferBase : public BaseClass { MockBufferBase(); MockBufferBase(std::function below_low, std::function above_high); - MOCK_METHOD1(write, int(int fd)); + MOCK_METHOD1(write, Api::SysCallResult(int fd)); MOCK_METHOD1(move, void(Buffer::Instance& rhs)); MOCK_METHOD2(move, void(Buffer::Instance& rhs, uint64_t length)); MOCK_METHOD1(drain, void(uint64_t size)); @@ -24,12 +24,12 @@ template class MockBufferBase : public BaseClass { void baseMove(Buffer::Instance& rhs) { BaseClass::move(rhs); } void baseDrain(uint64_t size) { BaseClass::drain(size); } - int trackWrites(int fd) { - int bytes_written = BaseClass::write(fd); - if (bytes_written > 0) { - bytes_written_ += bytes_written; + Api::SysCallResult trackWrites(int fd) { + Api::SysCallResult result = BaseClass::write(fd); + if (result.rc_ > 0) { + bytes_written_ += result.rc_; } - return bytes_written; + return result; } void trackDrains(uint64_t size) { @@ -38,10 +38,7 @@ template class MockBufferBase : public BaseClass { } // A convenience function to invoke on write() which fails the write with EAGAIN. - int failWrite(int) { - errno = EAGAIN; - return -1; - } + Api::SysCallResult failWrite(int) { return {-1, EAGAIN}; } int bytes_written() const { return bytes_written_; } uint64_t bytes_drained() const { return bytes_drained_; } @@ -99,19 +96,20 @@ MATCHER_P(BufferEqual, rhs, testing::PrintToString(*rhs)) { } MATCHER_P(BufferStringEqual, rhs, rhs) { - *result_listener << "\"" << TestUtility::bufferToString(arg) << "\""; + *result_listener << "\"" << arg.toString() << "\""; Buffer::OwnedImpl buffer(rhs); return TestUtility::buffersEqual(arg, buffer); } ACTION_P(AddBufferToString, target_string) { - target_string->append(TestUtility::bufferToString(arg0)); + auto bufferToString = [](const Buffer::OwnedImpl& buf) -> std::string { return buf.toString(); }; + target_string->append(bufferToString(arg0)); arg0.drain(arg0.length()); } ACTION_P(AddBufferToStringWithoutDraining, target_string) { - target_string->append(TestUtility::bufferToString(arg0)); + target_string->append(arg0.toString()); } } // namespace Envoy diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index f482f1f05a5c0..e5e820d4e0368 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -78,6 +78,7 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); + MOCK_CONST_METHOD0(requestedServerName, absl::string_view()); MOCK_CONST_METHOD0(state, State()); MOCK_METHOD2(write, void(Buffer::Instance& data, bool end_stream)); MOCK_METHOD1(setBufferLimits, void(uint32_t limit)); @@ -117,6 +118,7 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); + MOCK_CONST_METHOD0(requestedServerName, absl::string_view()); MOCK_CONST_METHOD0(state, State()); MOCK_METHOD2(write, void(Buffer::Instance& data, bool end_stream)); MOCK_METHOD1(setBufferLimits, void(uint32_t limit)); @@ -407,8 +409,8 @@ class MockResolvedAddress : public Address::Instance { return asString() == other.asString(); } - MOCK_CONST_METHOD1(bind, int(int)); - MOCK_CONST_METHOD1(connect, int(int)); + MOCK_CONST_METHOD1(bind, Api::SysCallResult(int)); + MOCK_CONST_METHOD1(connect, Api::SysCallResult(int)); MOCK_CONST_METHOD0(ip, Address::Ip*()); MOCK_CONST_METHOD1(socket, int(Address::SocketType)); MOCK_CONST_METHOD0(type, Address::Type()); diff --git a/test/mocks/request_info/mocks.h b/test/mocks/request_info/mocks.h index 1df9251000078..cafe6815e5bcc 100644 --- a/test/mocks/request_info/mocks.h +++ b/test/mocks/request_info/mocks.h @@ -16,6 +16,7 @@ class MockRequestInfo : public RequestInfo { // RequestInfo::RequestInfo MOCK_METHOD1(setResponseFlag, void(ResponseFlag response_flag)); + MOCK_CONST_METHOD1(intersectResponseFlags, bool(uint64_t)); MOCK_METHOD1(onUpstreamHostSelected, void(Upstream::HostDescriptionConstSharedPtr host)); MOCK_CONST_METHOD0(startTime, SystemTime()); MOCK_CONST_METHOD0(startTimeMonotonic, MonotonicTime()); @@ -43,7 +44,8 @@ class MockRequestInfo : public RequestInfo { MOCK_CONST_METHOD0(responseCode, absl::optional()); MOCK_METHOD1(addBytesSent, void(uint64_t)); MOCK_CONST_METHOD0(bytesSent, uint64_t()); - MOCK_CONST_METHOD1(getResponseFlag, bool(ResponseFlag)); + MOCK_CONST_METHOD1(hasResponseFlag, bool(ResponseFlag)); + MOCK_CONST_METHOD0(hasAnyResponseFlag, bool()); MOCK_CONST_METHOD0(upstreamHost, Upstream::HostDescriptionConstSharedPtr()); MOCK_METHOD1(setUpstreamLocalAddress, void(const Network::Address::InstanceConstSharedPtr&)); MOCK_CONST_METHOD0(upstreamLocalAddress, const Network::Address::InstanceConstSharedPtr&()); diff --git a/test/mocks/router/mocks.h b/test/mocks/router/mocks.h index 7afb1c654b41e..24b6433706824 100644 --- a/test/mocks/router/mocks.h +++ b/test/mocks/router/mocks.h @@ -46,6 +46,7 @@ class TestCorsPolicy : public CorsPolicy { public: // Router::CorsPolicy const std::list& allowOrigins() const override { return allow_origin_; }; + const std::list& allowOriginRegexes() const override { return allow_origin_regex_; }; const std::string& allowMethods() const override { return allow_methods_; }; const std::string& allowHeaders() const override { return allow_headers_; }; const std::string& exposeHeaders() const override { return expose_headers_; }; @@ -54,6 +55,7 @@ class TestCorsPolicy : public CorsPolicy { bool enabled() const override { return enabled_; }; std::list allow_origin_{}; + std::list allow_origin_regex_{}; std::string allow_methods_{}; std::string allow_headers_{}; std::string expose_headers_{}; @@ -227,6 +229,7 @@ class MockRouteEntry : public RouteEntry { MOCK_CONST_METHOD0(retryPolicy, const RetryPolicy&()); MOCK_CONST_METHOD0(shadowPolicy, const ShadowPolicy&()); MOCK_CONST_METHOD0(timeout, std::chrono::milliseconds()); + MOCK_CONST_METHOD0(idleTimeout, absl::optional()); MOCK_CONST_METHOD0(maxGrpcTimeout, absl::optional()); MOCK_CONST_METHOD1(virtualCluster, const VirtualCluster*(const Http::HeaderMap& headers)); MOCK_CONST_METHOD0(virtualHostName, const std::string&()); @@ -307,17 +310,14 @@ class MockRouteConfigProviderManager : public RouteConfigProviderManager { MockRouteConfigProviderManager(); ~MockRouteConfigProviderManager(); - MOCK_METHOD3(getRdsRouteConfigProvider, - RouteConfigProviderSharedPtr( + MOCK_METHOD3(createRdsRouteConfigProvider, + RouteConfigProviderPtr( const envoy::config::filter::network::http_connection_manager::v2::Rds& rds, Server::Configuration::FactoryContext& factory_context, const std::string& stat_prefix)); - MOCK_METHOD2( - getStaticRouteConfigProvider, - RouteConfigProviderSharedPtr(const envoy::api::v2::RouteConfiguration& route_config, - Server::Configuration::FactoryContext& factory_context)); - MOCK_METHOD0(getRdsRouteConfigProviders, std::vector()); - MOCK_METHOD0(getStaticRouteConfigProviders, std::vector()); + MOCK_METHOD2(createStaticRouteConfigProvider, + RouteConfigProviderPtr(const envoy::api::v2::RouteConfiguration& route_config, + Server::Configuration::FactoryContext& factory_context)); }; } // namespace Router diff --git a/test/mocks/server/BUILD b/test/mocks/server/BUILD index 8d203b45580ec..17bb23d6c12f0 100644 --- a/test/mocks/server/BUILD +++ b/test/mocks/server/BUILD @@ -24,6 +24,7 @@ envoy_cc_mock( "//include/envoy/server:options_interface", "//include/envoy/server:worker_interface", "//include/envoy/ssl:context_manager_interface", + "//include/envoy/upstream:health_checker_interface", "//source/common/secret:secret_manager_impl_lib", "//source/common/singleton:manager_impl_lib", "//source/common/ssl:context_lib", diff --git a/test/mocks/server/mocks.cc b/test/mocks/server/mocks.cc index bc27e777f7e2f..46920ab564294 100644 --- a/test/mocks/server/mocks.cc +++ b/test/mocks/server/mocks.cc @@ -28,7 +28,7 @@ MockOptions::MockOptions(const std::string& config_path) : config_path_(config_p ON_CALL(*this, serviceZone()).WillByDefault(ReturnRef(service_zone_name_)); ON_CALL(*this, logPath()).WillByDefault(ReturnRef(log_path_)); ON_CALL(*this, maxStats()).WillByDefault(Return(1000)); - ON_CALL(*this, maxObjNameLength()).WillByDefault(Return(150)); + ON_CALL(*this, statsOptions()).WillByDefault(ReturnRef(stats_options_)); ON_CALL(*this, hotRestartDisabled()).WillByDefault(ReturnPointee(&hot_restart_disabled_)); } MockOptions::~MockOptions() {} @@ -169,10 +169,12 @@ MockListenerFactoryContext::MockListenerFactoryContext() {} MockListenerFactoryContext::~MockListenerFactoryContext() {} MockHealthCheckerFactoryContext::MockHealthCheckerFactoryContext() { + event_logger_ = new NiceMock(); ON_CALL(*this, cluster()).WillByDefault(ReturnRef(cluster_)); ON_CALL(*this, dispatcher()).WillByDefault(ReturnRef(dispatcher_)); ON_CALL(*this, random()).WillByDefault(ReturnRef(random_)); ON_CALL(*this, runtime()).WillByDefault(ReturnRef(runtime_)); + ON_CALL(*this, eventLogger_()).WillByDefault(Return(event_logger_)); } MockHealthCheckerFactoryContext::~MockHealthCheckerFactoryContext() {} diff --git a/test/mocks/server/mocks.h b/test/mocks/server/mocks.h index 9bfd7e1a45680..83fbea4dcfbb0 100644 --- a/test/mocks/server/mocks.h +++ b/test/mocks/server/mocks.h @@ -68,7 +68,7 @@ class MockOptions : public Options { MOCK_CONST_METHOD0(serviceNodeName, const std::string&()); MOCK_CONST_METHOD0(serviceZone, const std::string&()); MOCK_CONST_METHOD0(maxStats, uint64_t()); - MOCK_CONST_METHOD0(maxObjNameLength, uint64_t()); + MOCK_CONST_METHOD0(statsOptions, const Stats::StatsOptions&()); MOCK_CONST_METHOD0(hotRestartDisabled, bool()); std::string config_path_; @@ -79,6 +79,7 @@ class MockOptions : public Options { std::string service_node_name_; std::string service_zone_name_; std::string log_path_; + Stats::StatsOptionsImpl stats_options_; bool hot_restart_disabled_{}; }; @@ -172,12 +173,12 @@ class MockHotRestart : public HotRestart { MOCK_METHOD0(version, std::string()); MOCK_METHOD0(logLock, Thread::BasicLockable&()); MOCK_METHOD0(accessLogLock, Thread::BasicLockable&()); - MOCK_METHOD0(statsAllocator, Stats::RawStatDataAllocator&()); + MOCK_METHOD0(statsAllocator, Stats::StatDataAllocator&()); private: Thread::MutexBasicLockable log_lock_; Thread::MutexBasicLockable access_log_lock_; - Stats::HeapRawStatDataAllocator stats_allocator_; + Stats::HeapStatDataAllocator stats_allocator_; }; class MockListenerComponentFactory : public ListenerComponentFactory { @@ -438,11 +439,16 @@ class MockHealthCheckerFactoryContext : public virtual HealthCheckerFactoryConte MOCK_METHOD0(dispatcher, Event::Dispatcher&()); MOCK_METHOD0(random, Envoy::Runtime::RandomGenerator&()); MOCK_METHOD0(runtime, Envoy::Runtime::Loader&()); + MOCK_METHOD0(eventLogger_, Upstream::HealthCheckEventLogger*()); + Upstream::HealthCheckEventLoggerPtr eventLogger() override { + return Upstream::HealthCheckEventLoggerPtr(eventLogger_()); + } testing::NiceMock cluster_; testing::NiceMock dispatcher_; testing::NiceMock random_; testing::NiceMock runtime_; + testing::NiceMock* event_logger_{}; }; } // namespace Configuration diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 60691fa397973..033e58afd9d5f 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -21,21 +21,11 @@ class MockContextManager : public ContextManager { MockContextManager(); ~MockContextManager(); - ClientContextPtr createSslClientContext(Stats::Scope& scope, - const ClientContextConfig& config) override { - return ClientContextPtr{createSslClientContext_(scope, config)}; - } - - ServerContextPtr createSslServerContext(Stats::Scope& scope, const ServerContextConfig& config, - const std::vector& server_names) override { - return ServerContextPtr{createSslServerContext_(scope, config, server_names)}; - } - - MOCK_METHOD2(createSslClientContext_, - ClientContext*(Stats::Scope& scope, const ClientContextConfig& config)); - MOCK_METHOD3(createSslServerContext_, - ServerContext*(Stats::Scope& stats, const ServerContextConfig& config, - const std::vector& server_names)); + MOCK_METHOD2(createSslClientContext, + ClientContextSharedPtr(Stats::Scope& scope, const ClientContextConfig& config)); + MOCK_METHOD3(createSslServerContext, + ServerContextSharedPtr(Stats::Scope& stats, const ServerContextConfig& config, + const std::vector& server_names)); MOCK_CONST_METHOD0(daysUntilFirstCertExpires, size_t()); MOCK_METHOD1(iterateContexts, void(std::function callback)); }; @@ -48,6 +38,7 @@ class MockConnection : public Connection { MOCK_CONST_METHOD0(peerCertificatePresented, bool()); MOCK_METHOD0(uriSanLocalCertificate, std::string()); MOCK_CONST_METHOD0(sha256PeerCertificateDigest, std::string&()); + MOCK_CONST_METHOD0(serialNumberPeerCertificate, std::string()); MOCK_CONST_METHOD0(subjectPeerCertificate, std::string()); MOCK_CONST_METHOD0(uriSanPeerCertificate, std::string()); MOCK_CONST_METHOD0(subjectLocalCertificate, std::string()); diff --git a/test/mocks/stats/mocks.cc b/test/mocks/stats/mocks.cc index 10fef63630558..9ec20c8dec6c0 100644 --- a/test/mocks/stats/mocks.cc +++ b/test/mocks/stats/mocks.cc @@ -79,6 +79,7 @@ MockStore::MockStore() { histograms_.emplace_back(histogram); return *histogram; })); + ON_CALL(*this, statsOptions()).WillByDefault(ReturnRef(stats_options_)); } MockStore::~MockStore() {} diff --git a/test/mocks/stats/mocks.h b/test/mocks/stats/mocks.h index 5237d16ea5577..c98fa2758f1b1 100644 --- a/test/mocks/stats/mocks.h +++ b/test/mocks/stats/mocks.h @@ -145,9 +145,11 @@ class MockStore : public Store { MOCK_CONST_METHOD0(gauges, std::vector()); MOCK_METHOD1(histogram, Histogram&(const std::string& name)); MOCK_CONST_METHOD0(histograms, std::vector()); + MOCK_CONST_METHOD0(statsOptions, const Stats::StatsOptions&()); testing::NiceMock counter_; std::vector> histograms_; + StatsOptionsImpl stats_options_; }; /** diff --git a/test/mocks/tcp/BUILD b/test/mocks/tcp/BUILD new file mode 100644 index 0000000000000..8634b86e9c5c8 --- /dev/null +++ b/test/mocks/tcp/BUILD @@ -0,0 +1,21 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_mock", + "envoy_package", +) + +envoy_package() + +envoy_cc_mock( + name = "tcp_mocks", + srcs = ["mocks.cc"], + hdrs = ["mocks.h"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/tcp:conn_pool_interface", + "//test/mocks/network:network_mocks", + "//test/mocks/upstream:host_mocks", + ], +) diff --git a/test/mocks/tcp/mocks.cc b/test/mocks/tcp/mocks.cc new file mode 100644 index 0000000000000..1374d415dea7f --- /dev/null +++ b/test/mocks/tcp/mocks.cc @@ -0,0 +1,63 @@ +#include "mocks.h" + +#include "gmock/gmock.h" + +using testing::ReturnRef; + +using testing::Invoke; +using testing::ReturnRef; +using testing::_; + +namespace Envoy { +namespace Tcp { +namespace ConnectionPool { + +MockCancellable::MockCancellable() {} +MockCancellable::~MockCancellable() {} + +MockUpstreamCallbacks::MockUpstreamCallbacks() {} +MockUpstreamCallbacks::~MockUpstreamCallbacks() {} + +MockConnectionData::MockConnectionData() {} +MockConnectionData::~MockConnectionData() { + if (release_callback_) { + release_callback_(); + } +} + +MockInstance::MockInstance() { + ON_CALL(*this, newConnection(_)).WillByDefault(Invoke([&](Callbacks& cb) -> Cancellable* { + return newConnectionImpl(cb); + })); +} +MockInstance::~MockInstance() {} + +MockCancellable* MockInstance::newConnectionImpl(Callbacks& cb) { + handles_.emplace_back(); + callbacks_.push_back(&cb); + return &handles_.back(); +} + +void MockInstance::poolFailure(PoolFailureReason reason) { + Callbacks* cb = callbacks_.front(); + callbacks_.pop_front(); + handles_.pop_front(); + + cb->onPoolFailure(reason, host_); +} + +void MockInstance::poolReady(Network::MockClientConnection& conn) { + Callbacks* cb = callbacks_.front(); + callbacks_.pop_front(); + handles_.pop_front(); + + ON_CALL(*connection_data_, connection()).WillByDefault(ReturnRef(conn)); + + connection_data_->release_callback_ = [&]() -> void { released(conn); }; + + cb->onPoolReady(std::move(connection_data_), host_); +} + +} // namespace ConnectionPool +} // namespace Tcp +} // namespace Envoy diff --git a/test/mocks/tcp/mocks.h b/test/mocks/tcp/mocks.h new file mode 100644 index 0000000000000..3fb969f088b2d --- /dev/null +++ b/test/mocks/tcp/mocks.h @@ -0,0 +1,81 @@ +#pragma once + +#include "envoy/tcp/conn_pool.h" + +#include "test/mocks/common.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/upstream/host.h" +#include "test/test_common/printers.h" + +#include "gmock/gmock.h" + +using testing::NiceMock; + +namespace Envoy { +namespace Tcp { +namespace ConnectionPool { + +class MockCancellable : public Cancellable { +public: + MockCancellable(); + ~MockCancellable(); + + // Tcp::ConnectionPool::Cancellable + MOCK_METHOD0(cancel, void()); +}; + +class MockUpstreamCallbacks : public UpstreamCallbacks { +public: + MockUpstreamCallbacks(); + ~MockUpstreamCallbacks(); + + // Tcp::ConnectionPool::UpstreamCallbacks + MOCK_METHOD2(onUpstreamData, void(Buffer::Instance& data, bool end_stream)); + MOCK_METHOD1(onEvent, void(Network::ConnectionEvent event)); + MOCK_METHOD0(onAboveWriteBufferHighWatermark, void()); + MOCK_METHOD0(onBelowWriteBufferLowWatermark, void()); +}; + +class MockConnectionData : public ConnectionData { +public: + MockConnectionData(); + ~MockConnectionData(); + + // Tcp::ConnectionPool::ConnectionData + MOCK_METHOD0(connection, Network::ClientConnection&()); + MOCK_METHOD1(addUpstreamCallbacks, void(ConnectionPool::UpstreamCallbacks&)); + + // If set, invoked in ~MockConnectionData, which indicates that the connection pool + // caller has relased a connection. + std::function release_callback_; +}; + +class MockInstance : public Instance { +public: + MockInstance(); + ~MockInstance(); + + // Tcp::ConnectionPool::Instance + MOCK_METHOD1(addDrainedCallback, void(DrainedCb cb)); + MOCK_METHOD0(drainConnections, void()); + MOCK_METHOD1(newConnection, Cancellable*(Tcp::ConnectionPool::Callbacks& callbacks)); + + MockCancellable* newConnectionImpl(Callbacks& cb); + void poolFailure(PoolFailureReason reason); + void poolReady(Network::MockClientConnection& conn); + + // Invoked when connection_data_, having been assigned via poolReady is released. + MOCK_METHOD1(released, void(Network::MockClientConnection&)); + + std::list> handles_; + std::list callbacks_; + + std::shared_ptr> host_{ + new NiceMock()}; + std::unique_ptr> connection_data_{ + new NiceMock()}; +}; + +} // namespace ConnectionPool +} // namespace Tcp +} // namespace Envoy diff --git a/test/mocks/upstream/BUILD b/test/mocks/upstream/BUILD index 753a2727ba90f..c23d76e9ed2aa 100644 --- a/test/mocks/upstream/BUILD +++ b/test/mocks/upstream/BUILD @@ -45,6 +45,7 @@ envoy_cc_mock( "//include/envoy/upstream:health_checker_interface", "//include/envoy/upstream:load_balancer_interface", "//include/envoy/upstream:upstream_interface", + "//source/common/upstream:health_discovery_service_lib", "//source/common/upstream:upstream_lib", "//test/mocks/config:config_mocks", "//test/mocks/grpc:grpc_mocks", @@ -52,5 +53,6 @@ envoy_cc_mock( "//test/mocks/runtime:runtime_mocks", "//test/mocks/secret:secret_mocks", "//test/mocks/stats:stats_mocks", + "//test/mocks/tcp:tcp_mocks", ], ) diff --git a/test/mocks/upstream/cluster_info.h b/test/mocks/upstream/cluster_info.h index 8c34157daf4d1..d84e2a46ddf9f 100644 --- a/test/mocks/upstream/cluster_info.h +++ b/test/mocks/upstream/cluster_info.h @@ -30,6 +30,7 @@ class MockLoadBalancerSubsetInfo : public LoadBalancerSubsetInfo { envoy::api::v2::Cluster::LbSubsetConfig::LbSubsetFallbackPolicy()); MOCK_CONST_METHOD0(defaultSubset, const ProtobufWkt::Struct&()); MOCK_CONST_METHOD0(subsetKeys, const std::vector>&()); + MOCK_CONST_METHOD0(localityWeightAware, bool()); std::vector> subset_keys_; }; diff --git a/test/mocks/upstream/host.h b/test/mocks/upstream/host.h index 32faed83aea21..3de35b55fb433 100644 --- a/test/mocks/upstream/host.h +++ b/test/mocks/upstream/host.h @@ -78,7 +78,9 @@ class MockHostDescription : public HostDescription { MOCK_CONST_METHOD0(address, Network::Address::InstanceConstSharedPtr()); MOCK_CONST_METHOD0(healthCheckAddress, Network::Address::InstanceConstSharedPtr()); MOCK_CONST_METHOD0(canary, bool()); - MOCK_CONST_METHOD0(metadata, const envoy::api::v2::core::Metadata&()); + MOCK_METHOD1(canary, void(bool new_canary)); + MOCK_CONST_METHOD0(metadata, const std::shared_ptr()); + MOCK_METHOD1(metadata, void(const envoy::api::v2::core::Metadata&)); MOCK_CONST_METHOD0(cluster, const ClusterInfo&()); MOCK_CONST_METHOD0(outlierDetector, Outlier::DetectorHostMonitor&()); MOCK_CONST_METHOD0(healthChecker, HealthCheckHostMonitor&()); @@ -128,7 +130,9 @@ class MockHost : public Host { MOCK_CONST_METHOD0(address, Network::Address::InstanceConstSharedPtr()); MOCK_CONST_METHOD0(healthCheckAddress, Network::Address::InstanceConstSharedPtr()); MOCK_CONST_METHOD0(canary, bool()); - MOCK_CONST_METHOD0(metadata, const envoy::api::v2::core::Metadata&()); + MOCK_METHOD1(canary, void(bool new_canary)); + MOCK_CONST_METHOD0(metadata, const std::shared_ptr()); + MOCK_METHOD1(metadata, void(const envoy::api::v2::core::Metadata&)); MOCK_CONST_METHOD0(cluster, const ClusterInfo&()); MOCK_CONST_METHOD0(counters, std::vector()); MOCK_CONST_METHOD2( diff --git a/test/mocks/upstream/mocks.cc b/test/mocks/upstream/mocks.cc index 874ef72f4483c..f4f1bbc1a8e2c 100644 --- a/test/mocks/upstream/mocks.cc +++ b/test/mocks/upstream/mocks.cc @@ -88,6 +88,7 @@ MockThreadLocalCluster::~MockThreadLocalCluster() {} MockClusterManager::MockClusterManager() { ON_CALL(*this, httpConnPoolForCluster(_, _, _, _)).WillByDefault(Return(&conn_pool_)); + ON_CALL(*this, tcpConnPoolForCluster(_, _, _)).WillByDefault(Return(&tcp_conn_pool_)); ON_CALL(*this, httpAsyncClientForCluster(_)).WillByDefault(ReturnRef(async_client_)); ON_CALL(*this, httpAsyncClientForCluster(_)).WillByDefault((ReturnRef(async_client_))); ON_CALL(*this, bindConfig()).WillByDefault(ReturnRef(bind_config_)); diff --git a/test/mocks/upstream/mocks.h b/test/mocks/upstream/mocks.h index 109cc31c2c678..bfdbd73624fa9 100644 --- a/test/mocks/upstream/mocks.h +++ b/test/mocks/upstream/mocks.h @@ -12,6 +12,7 @@ #include "envoy/upstream/upstream.h" #include "common/common/callback_impl.h" +#include "common/upstream/health_discovery_service.h" #include "common/upstream/upstream_impl.h" #include "test/mocks/config/mocks.h" @@ -20,6 +21,7 @@ #include "test/mocks/runtime/mocks.h" #include "test/mocks/secret/mocks.h" #include "test/mocks/stats/mocks.h" +#include "test/mocks/tcp/mocks.h" #include "test/mocks/upstream/cluster_info.h" #include "gmock/gmock.h" @@ -154,10 +156,16 @@ class MockClusterManagerFactory : public ClusterManagerFactory { ResourcePriority priority, Http::Protocol protocol, const Network::ConnectionSocket::OptionsSharedPtr& options)); - MOCK_METHOD4(clusterFromProto, + MOCK_METHOD4( + allocateTcpConnPool, + Tcp::ConnectionPool::InstancePtr(Event::Dispatcher& dispatcher, HostConstSharedPtr host, + ResourcePriority priority, + const Network::ConnectionSocket::OptionsSharedPtr& options)); + + MOCK_METHOD5(clusterFromProto, ClusterSharedPtr(const envoy::api::v2::Cluster& cluster, ClusterManager& cm, Outlier::EventLoggerSharedPtr outlier_event_logger, - bool added_via_api)); + AccessLog::AccessLogManager& log_manager, bool added_via_api)); MOCK_METHOD3(createCds, CdsApiPtr(const envoy::api::v2::core::ConfigSource& cds_config, @@ -191,6 +199,9 @@ class MockClusterManager : public ClusterManager { Http::ConnectionPool::Instance*(const std::string& cluster, ResourcePriority priority, Http::Protocol protocol, LoadBalancerContext* context)); + MOCK_METHOD3(tcpConnPoolForCluster, + Tcp::ConnectionPool::Instance*(const std::string& cluster, ResourcePriority priority, + LoadBalancerContext* context)); MOCK_METHOD2(tcpConnForCluster_, MockHost::MockCreateConnectionData(const std::string& cluster, LoadBalancerContext* context)); @@ -207,6 +218,7 @@ class MockClusterManager : public ClusterManager { NiceMock conn_pool_; NiceMock async_client_; + NiceMock tcp_conn_pool_; NiceMock thread_local_cluster_; envoy::api::v2::core::BindConfig bind_config_; NiceMock ads_mux_; @@ -232,6 +244,15 @@ class MockHealthChecker : public HealthChecker { std::list callbacks_; }; +class MockHealthCheckEventLogger : public HealthCheckEventLogger { +public: + MOCK_METHOD3(logEjectUnhealthy, void(envoy::data::core::v2alpha::HealthCheckerType, + const HostDescriptionConstSharedPtr&, + envoy::data::core::v2alpha::HealthCheckFailureType)); + MOCK_METHOD3(logAddHealthy, void(envoy::data::core::v2alpha::HealthCheckerType, + const HostDescriptionConstSharedPtr&, bool)); +}; + class MockCdsApi : public CdsApi { public: MockCdsApi(); @@ -253,5 +274,15 @@ class MockClusterUpdateCallbacks : public ClusterUpdateCallbacks { MOCK_METHOD1(onClusterRemoval, void(const std::string& cluster_name)); }; +class MockClusterInfoFactory : public ClusterInfoFactory, Logger::Loggable { +public: + MOCK_METHOD7( + createClusterInfo, + ClusterInfoConstSharedPtr(Runtime::Loader& runtime, const envoy::api::v2::Cluster& cluster, + const envoy::api::v2::core::BindConfig& bind_config, + Stats::Store& stats, Ssl::ContextManager& ssl_context_manager, + Secret::SecretManager& secret_manager, bool added_via_api)); +}; + } // namespace Upstream } // namespace Envoy diff --git a/test/proto/BUILD b/test/proto/BUILD index 3b21fb2d2ea29..fa16742517196 100644 --- a/test/proto/BUILD +++ b/test/proto/BUILD @@ -20,8 +20,9 @@ envoy_proto_library( name = "bookstore_proto", srcs = [":bookstore.proto"], external_deps = [ - "well_known_protos", + "api_httpbody_protos", "http_api_protos", + "well_known_protos", ], ) @@ -32,6 +33,7 @@ envoy_proto_descriptor( ], out = "bookstore.descriptor", external_deps = [ + "api_httpbody_protos", "http_api_protos", "well_known_protos", ], diff --git a/test/proto/bookstore.proto b/test/proto/bookstore.proto index 45ca15bf54c62..6fb65b42d119e 100644 --- a/test/proto/bookstore.proto +++ b/test/proto/bookstore.proto @@ -3,6 +3,7 @@ syntax = "proto3"; package bookstore; import "google/api/annotations.proto"; +import "google/api/httpbody.proto"; import "google/protobuf/empty.proto"; // A simple Bookstore API. @@ -84,6 +85,11 @@ service Bookstore { get: "/authors/{author}" }; } + rpc GetIndex(google.protobuf.Empty) returns (google.api.HttpBody) { + option (google.api.http) = { + get: "/index" + }; + } } // A shelf resource. diff --git a/test/run_envoy_bazel_coverage.sh b/test/run_envoy_bazel_coverage.sh index 5a8900ce39919..a336d6949e44b 100755 --- a/test/run_envoy_bazel_coverage.sh +++ b/test/run_envoy_bazel_coverage.sh @@ -8,6 +8,7 @@ set -e [[ -z "${BAZEL_COVERAGE}" ]] && BAZEL_COVERAGE=bazel [[ -z "${GCOVR}" ]] && GCOVR=gcovr [[ -z "${WORKSPACE}" ]] && WORKSPACE=envoy +[[ -z "${VALIDATE_COVERAGE}" ]] && VALIDATE_COVERAGE=true # This is the target that will be run to generate coverage data. It can be overriden by consumer # projects that want to run coverage on a different/combined target. @@ -56,7 +57,7 @@ COVERAGE_SUMMARY="${COVERAGE_DIR}/coverage_summary.txt" pushd "${GCOVR_DIR}" for f in $(find -L bazel-out/ -name "*.gcno") do - cp --parents "$f" bazel-out/k8-dbg/bin/test/coverage/coverage_tests.runfiles/"${WORKSPACE}" + cp --parents "$f" bazel-out/k8-dbg/bin/"${COVERAGE_TARGET/:/\/}".runfiles/"${WORKSPACE}" done popd @@ -76,13 +77,16 @@ rm "${SRCDIR}"/test/coverage/BUILD [[ -z "${ENVOY_COVERAGE_DIR}" ]] || rsync -av "${COVERAGE_DIR}"/ "${ENVOY_COVERAGE_DIR}" -COVERAGE_VALUE=$(grep -Po 'lines: \K(\d|\.)*' "${COVERAGE_SUMMARY}") -COVERAGE_THRESHOLD=98.0 -COVERAGE_FAILED=$(echo "${COVERAGE_VALUE}<${COVERAGE_THRESHOLD}" | bc) -if test ${COVERAGE_FAILED} -eq 1; then - echo Code coverage ${COVERAGE_VALUE} is lower than limit of ${COVERAGE_THRESHOLD} - exit 1 -else - echo Code coverage ${COVERAGE_VALUE} is good and higher than limit of ${COVERAGE_THRESHOLD} +if [ "$VALIDATE_COVERAGE" == "true" ] +then + COVERAGE_VALUE=$(grep -Po 'lines: \K(\d|\.)*' "${COVERAGE_SUMMARY}") + COVERAGE_THRESHOLD=98.0 + COVERAGE_FAILED=$(echo "${COVERAGE_VALUE}<${COVERAGE_THRESHOLD}" | bc) + if test ${COVERAGE_FAILED} -eq 1; then + echo Code coverage ${COVERAGE_VALUE} is lower than limit of ${COVERAGE_THRESHOLD} + exit 1 + else + echo Code coverage ${COVERAGE_VALUE} is good and higher than limit of ${COVERAGE_THRESHOLD} + fi + echo "HTML coverage report is in ${COVERAGE_DIR}/coverage.html" fi -echo "HTML coverage report is in ${COVERAGE_DIR}/coverage.html" diff --git a/test/server/BUILD b/test/server/BUILD index 723c30ef32d0c..4ceb534c44631 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -73,6 +73,7 @@ envoy_cc_test( "//source/common/stats:stats_lib", "//source/server:hot_restart_lib", "//test/mocks/server:server_mocks", + "//test/test_common:logging_lib", "//test/test_common:threadsafe_singleton_injector_lib", ], ) @@ -112,6 +113,19 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "overload_manager_impl_test", + srcs = ["overload_manager_impl_test.cc"], + deps = [ + "//include/envoy/registry", + "//source/extensions/resource_monitors/common:factory_base_lib", + "//source/server:overload_manager_lib", + "//test/mocks/event:event_mocks", + "//test/test_common:registry_lib", + "//test/test_common:utility_lib", + ], +) + envoy_cc_test( name = "lds_api_test", srcs = ["lds_api_test.cc"], @@ -161,6 +175,7 @@ envoy_cc_fuzz_test( corpus = "server_corpus", deps = [ "//source/common/thread_local:thread_local_lib", + "//source/server:proto_descriptors_lib", "//source/server:server_lib", "//test/integration:integration_lib", "//test/mocks/server:server_mocks", @@ -174,6 +189,7 @@ envoy_cc_test( srcs = ["server_test.cc"], data = [ ":cluster_dupe_bootstrap.yaml", + ":cluster_health_check_bootstrap.yaml", ":empty_bootstrap.yaml", ":node_bootstrap.yaml", "//test/config/integration:server.json", diff --git a/test/server/cluster_health_check_bootstrap.yaml b/test/server/cluster_health_check_bootstrap.yaml new file mode 100644 index 0000000000000..7d928f9ca4335 --- /dev/null +++ b/test/server/cluster_health_check_bootstrap.yaml @@ -0,0 +1,17 @@ +admin: + access_log_path: /dev/null + address: + socket_address: + address: {{ ntop_ip_loopback_address }} + port_value: 0 +static_resources: + clusters: + - name: service_google + connect_timeout: 0.25s + health_checks: + - timeout: {{ health_check_timeout }}s + interval: {{ health_check_interval }}s + unhealthy_threshold: 1 + healthy_threshold: 1 + http_health_check: + path: "/" \ No newline at end of file diff --git a/test/server/config_validation/BUILD b/test/server/config_validation/BUILD index 2925a168f5ac7..c57990c451c21 100644 --- a/test/server/config_validation/BUILD +++ b/test/server/config_validation/BUILD @@ -1,6 +1,10 @@ licenses(["notice"]) # Apache 2 -load("//bazel:envoy_build_system.bzl", "envoy_cc_test", "envoy_package") +load("//bazel:envoy_build_system.bzl", "envoy_cc_fuzz_test", "envoy_cc_test", "envoy_package") +load( + "//source/extensions:all_extensions.bzl", + "envoy_all_extensions", +) envoy_package() @@ -72,3 +76,15 @@ envoy_cc_test( "//test/test_common:network_utility_lib", ], ) + +envoy_cc_fuzz_test( + name = "config_fuzz_test", + srcs = ["config_fuzz_test.cc"], + corpus = "config_corpus", + deps = [ + "//source/server/config_validation:server_lib", + "//test/integration:integration_lib", + "//test/mocks/server:server_mocks", + "//test/test_common:environment_lib", + ] + envoy_all_extensions(), +) diff --git a/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-5697041979146240 b/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-5697041979146240 new file mode 100644 index 0000000000000..6379608692417 --- /dev/null +++ b/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-5697041979146240 @@ -0,0 +1,231 @@ +static_resources { + clusters { + name: "ineasrh_stsB" + eds_cluster_config { + service_name: "\177" + } + connect_timeout { + nanos: 249999905 + } + dns_refresh_rate { + nanos: 249999905 + } + dns_lookup_family: V4_ONLY + load_assignment { + cluster_name: "GGG" + endpoints { + priority: 538970624 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + priority: 246 + } + endpoints { + priority: 538970624 + } + endpoints { + priority: 2 + } + endpoints { + priority: 538970624 + } + endpoints { + priority: 2105354 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + priority: 671091188 + } + endpoints { + priority: 2105354 + } + endpoints { + priority: 11264 + } + endpoints { + lb_endpoints { + } + load_balancing_weight { + value: 2 + } + priority: 11264 + } + endpoints { + priority: 538970624 + } + endpoints { + priority: 538970624 + } + endpoints { + priority: 538970624 + } + endpoints { + priority: 11264 + } + endpoints { + priority: 2105354 + } + endpoints { + priority: 671091190 + } + endpoints { + priority: 671151625 + } + endpoints { + priority: 671151625 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + priority: 671091190 + } + endpoints { + priority: 538976256 + } + endpoints { + priority: 671091188 + } + endpoints { + priority: 671091188 + } + endpoints { + priority: 11264 + } + endpoints { + lb_endpoints { + health_status: DRAINING + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + priority: 11264 + } + endpoints { + priority: 538970624 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 11264 + } + endpoints { + priority: 11264 + } + endpoints { + lb_endpoints { + load_balancing_weight { + value: 2 + } + } + priority: 738208768 + } + endpoints { + priority: 671151625 + } + endpoints { + priority: 671091188 + } + endpoints { + priority: 30 + } + endpoints { + priority: 671151625 + } + endpoints { + lb_endpoints { + } + load_balancing_weight { + value: 64 + } + priority: 2 + } + endpoints { + lb_endpoints { + health_status: DRAINING + } + priority: 11264 + } + endpoints { + lb_endpoints { + } + priority: 11264 + } + endpoints { + locality { + zone: "\177\r" + } + priority: 538970624 + } + endpoints { + priority: 671091190 + } + endpoints { + priority: 244 + } + endpoints { + priority: 538970624 + } + endpoints { + locality { + region: "~" + } + lb_endpoints { + } + priority: 11264 + } + endpoints { + priority: 2 + } + } + } +} +admin { + access_log_path: "/tmp/admin_access.log" + address { + pipe { + path: "*" + } + } +} diff --git a/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-6287096397430784 b/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-6287096397430784 new file mode 100644 index 0000000000000..dd15f1a510fd8 --- /dev/null +++ b/test/server/config_validation/config_corpus/clusterfuzz-testcase-config_fuzz_test-6287096397430784 @@ -0,0 +1,252 @@ +static_resources { + clusters { + name: " " + connect_timeout { + seconds: 2304 + } + per_connection_buffer_limit_bytes { + value: 209 + } + lb_policy: RING_HASH + hosts { + pipe { + path: "z" + } + } + hosts { + pipe { + path: " " + } + } + hosts { + pipe { + path: ";" + } + } + dns_lookup_family: V4_ONLY + outlier_detection { + success_rate_stdev_factor { + value: 268435456 + } + } + } + clusters { + name: "@" + connect_timeout { + seconds: 2304 + } + lb_policy: RING_HASH + hosts { + pipe { + path: "@" + } + } + hosts { + pipe { + path: "X" + } + } + hosts { + pipe { + path: "@" + } + } + dns_lookup_family: V4_ONLY + outlier_detection { + success_rate_stdev_factor { + value: 589951 + } + } + } + clusters { + name: "#" + connect_timeout { + seconds: 2304 + nanos: 235995425 + } + lb_policy: MAGLEV + dns_lookup_family: V4_ONLY + cleanup_interval { + nanos: 235995425 + } + upstream_connection_options { + tcp_keepalive { + keepalive_probes { + value: 589824 + } + } + } + } + clusters { + name: "X" + connect_timeout { + seconds: 2304 + } + outlier_detection { + success_rate_stdev_factor { + value: 589951 + } + } + lb_subset_config { + fallback_policy: ANY_ENDPOINT + subset_selectors { + } + locality_weight_aware: true + } + ring_hash_lb_config { + deprecated_v1 { + use_std_hash { + value: true + } + } + } + } + clusters { + name: "0" + connect_timeout { + seconds: 2304 + } + outlier_detection { + success_rate_stdev_factor { + value: 589951 + } + } + lb_subset_config { + fallback_policy: ANY_ENDPOINT + subset_selectors { + } + locality_weight_aware: true + } + } + clusters { + name: "`" + connect_timeout { + seconds: 2304 + } + lb_policy: RING_HASH + hosts { + pipe { + path: ";" + } + } + hosts { + pipe { + path: ";" + } + } + dns_lookup_family: V4_ONLY + outlier_detection { + success_rate_stdev_factor { + value: 589951 + } + } + lb_subset_config { + default_subset { + fields { + key: "" + value { + bool_value: true + } + } + } + } + upstream_connection_options { + tcp_keepalive { + keepalive_probes { + value: 589824 + } + } + } + close_connections_on_host_health_failure: true + drain_connections_on_host_removal: true + } + clusters { + name: "z" + connect_timeout { + seconds: 2304 + } + hosts { + pipe { + path: "*" + } + } + hosts { + pipe { + path: "5" + } + } + hosts { + pipe { + path: "z" + } + } + hosts { + pipe { + path: "@" + } + } + hosts { + pipe { + path: "z" + } + } + upstream_connection_options { + tcp_keepalive { + keepalive_probes { + value: 589824 + } + } + } + load_assignment { + cluster_name: " " + endpoints { + locality { + region: " " + } + lb_endpoints { + endpoint { + address { + pipe { + path: "\n\000\000\000" + } + } + health_check_config { + port_value: 10878976 + } + } + health_status: TIMEOUT + } + } + endpoints { + lb_endpoints { + endpoint { + health_check_config { + port_value: 41216 + } + } + health_status: TIMEOUT + } + priority: 41216 + } + endpoints { + locality { + region: "\027" + } + lb_endpoints { + } + } + endpoints { + priority: 41216 + } + } + } +} +admin { + access_log_path: "/tmp/admin_access.lss" + address { + pipe { + path: "*" + } + } +} + diff --git a/test/server/config_validation/config_corpus/google_com_proxy.v2.pb_text b/test/server/config_validation/config_corpus/google_com_proxy.v2.pb_text new file mode 100644 index 0000000000000..3585ae9c7e2a7 --- /dev/null +++ b/test/server/config_validation/config_corpus/google_com_proxy.v2.pb_text @@ -0,0 +1,150 @@ +static_resources { + listeners { + name: "listener_0" + address { + socket_address { + address: "0.0.0.0" + port_value: 0 + } + } + filter_chains { + filters { + name: "envoy.http_connection_manager" + config { + fields { + key: "http_filters" + value { + list_value { + values { + struct_value { + fields { + key: "name" + value { + string_value: "envoy.router" + } + } + } + } + } + } + } + fields { + key: "route_config" + value { + struct_value { + fields { + key: "name" + value { + string_value: "local_route" + } + } + fields { + key: "virtual_hosts" + value { + list_value { + values { + struct_value { + fields { + key: "domains" + value { + list_value { + values { + string_value: "*" + } + } + } + } + fields { + key: "name" + value { + string_value: "local_service" + } + } + fields { + key: "routes" + value { + list_value { + values { + struct_value { + fields { + key: "match" + value { + struct_value { + fields { + key: "prefix" + value { + string_value: "/" + } + } + } + } + } + fields { + key: "route" + value { + struct_value { + fields { + key: "cluster" + value { + string_value: "service_google" + } + } + fields { + key: "host_rewrite" + value { + string_value: "www.google.com" + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + } + fields { + key: "stat_prefix" + value { + string_value: "ingress_http" + } + } + } + } + } + } + clusters { + name: "service_google" + type: LOGICAL_DNS + connect_timeout { + nanos: 250000000 + } + hosts { + socket_address { + address: "google.com" + port_value: 0 + } + } + tls_context { + sni: "www.google.com" + } + dns_lookup_family: V4_ONLY + } +} +admin { + access_log_path: "/tmp/admin_access.log" + address { + socket_address { + address: "127.0.0.1" + port_value: 0 + } + } +} diff --git a/test/server/config_validation/config_fuzz_test.cc b/test/server/config_validation/config_fuzz_test.cc new file mode 100644 index 0000000000000..6ad1e6477cd1d --- /dev/null +++ b/test/server/config_validation/config_fuzz_test.cc @@ -0,0 +1,34 @@ +#include + +#include "common/network/address_impl.h" + +#include "server/config_validation/server.h" + +#include "test/fuzz/fuzz_runner.h" +#include "test/integration/server.h" +#include "test/mocks/server/mocks.h" +#include "test/test_common/environment.h" + +namespace Envoy { +namespace Server { +// Derived from //test/server:server_fuzz_test.cc, but starts the server in configuration validation +// mode (quits upon validation of the given config) +DEFINE_PROTO_FUZZER(const envoy::config::bootstrap::v2::Bootstrap& input) { + testing::NiceMock options; + TestComponentFactory component_factory; + + const std::string bootstrap_path = TestEnvironment::temporaryPath("bootstrap.pb_text"); + std::ofstream bootstrap_file(bootstrap_path); + bootstrap_file << input.DebugString(); + options.config_path_ = bootstrap_path; + options.v2_config_only_ = true; + + try { + validateConfig(options, Network::Address::InstanceConstSharedPtr(), component_factory); + } catch (const EnvoyException& ex) { + ENVOY_LOG_MISC(debug, "Controlled EnvoyException exit: {}", ex.what()); + } +} + +} // namespace Server +} // namespace Envoy diff --git a/test/server/configuration_impl_test.cc b/test/server/configuration_impl_test.cc index bec9bb6d05670..eb18e1b86288e 100644 --- a/test/server/configuration_impl_test.cc +++ b/test/server/configuration_impl_test.cc @@ -242,7 +242,7 @@ TEST_F(ConfigurationImplTest, ProtoSpecifiedStatsSink) { envoy::config::bootstrap::v2::Bootstrap bootstrap = TestUtility::parseBootstrapFromJson(json); auto& sink = *bootstrap.mutable_stats_sinks()->Add(); - sink.set_name(Extensions::StatSinks::StatsSinkNames::get().STATSD); + sink.set_name(Extensions::StatSinks::StatsSinkNames::get().Statsd); auto& field_map = *sink.mutable_config()->mutable_fields(); field_map["tcp_cluster_name"].set_string_value("fake_cluster"); diff --git a/test/server/hot_restart_impl_test.cc b/test/server/hot_restart_impl_test.cc index bda4509117b8e..4216a265c8fe6 100644 --- a/test/server/hot_restart_impl_test.cc +++ b/test/server/hot_restart_impl_test.cc @@ -1,10 +1,12 @@ #include "common/api/os_sys_calls_impl.h" +#include "common/common/hex.h" #include "common/stats/stats_impl.h" #include "server/hot_restart_impl.h" #include "test/mocks/api/mocks.h" #include "test/mocks/server/mocks.h" +#include "test/test_common/logging.h" #include "test/test_common/threadsafe_singleton_injector.h" #include "absl/strings/match.h" @@ -14,6 +16,7 @@ using testing::Invoke; using testing::InvokeWithoutArgs; using testing::Return; +using testing::ReturnRef; using testing::WithArg; using testing::_; @@ -33,24 +36,17 @@ class HotRestartImplTest : public testing::Test { return buffer_.data(); })); EXPECT_CALL(os_sys_calls_, bind(_, _, _)); - - Stats::RawStatData::configureForTestsOnly(options_); + EXPECT_CALL(options_, statsOptions()).WillRepeatedly(ReturnRef(stats_options_)); // Test we match the correct stat with empty-slots before, after, or both. hot_restart_.reset(new HotRestartImpl(options_)); hot_restart_->drainParentListeners(); } - void TearDown() { - // Configure it back so that later tests don't get the wonky values - // used here - NiceMock default_options; - Stats::RawStatData::configureForTestsOnly(default_options); - } - Api::MockOsSysCalls os_sys_calls_; TestThreadsafeSingletonInjector os_calls{&os_sys_calls_}; NiceMock options_; + Stats::StatsOptionsImpl stats_options_; std::vector buffer_; std::unique_ptr hot_restart_; }; @@ -68,7 +64,7 @@ TEST_F(HotRestartImplTest, versionString) { version = hot_restart_->version(); EXPECT_TRUE(absl::StartsWith(version, fmt::format("{}.", SharedMemory::VERSION))) << version; max_stats = options_.maxStats(); // Save this so we can double it below. - max_obj_name_length = options_.maxObjNameLength(); + max_obj_name_length = options_.statsOptions().maxObjNameLength(); TearDown(); } @@ -86,7 +82,7 @@ TEST_F(HotRestartImplTest, versionString) { } { - ON_CALL(options_, maxObjNameLength()).WillByDefault(Return(2 * max_obj_name_length)); + stats_options_.max_obj_name_length_ = 2 * max_obj_name_length; setup(); EXPECT_NE(version, hot_restart_->version()) << "Version changes when max-obj-name-length changes"; @@ -94,6 +90,44 @@ TEST_F(HotRestartImplTest, versionString) { } } +// Check consistency of internal raw stat representation by comparing hash of +// memory contents against a previously recorded value. +TEST_F(HotRestartImplTest, Consistency) { + setup(); + + // Generate a stat, encode it to hex, and take the hash of that hex string. We + // expect the hash to vary only when the internal representation of a stat has + // been intentionally changed, in which case SharedMemory::VERSION should be + // incremented as well. + const uint64_t expected_hash = 1874506077228772558; + const uint64_t max_name_length = stats_options_.maxNameLength(); + + const std::string name_1(max_name_length, 'A'); + Stats::RawStatData* stat_1 = hot_restart_->alloc(name_1); + const uint64_t stat_size = sizeof(Stats::RawStatData) + max_name_length; + const std::string stat_hex_dump_1 = Hex::encode(reinterpret_cast(stat_1), stat_size); + EXPECT_EQ(HashUtil::xxHash64(stat_hex_dump_1), expected_hash); + EXPECT_EQ(name_1, stat_1->key()); + hot_restart_->free(*stat_1); +} + +TEST_F(HotRestartImplTest, RawAlloc) { + setup(); + + Stats::RawStatData* stat_1 = hot_restart_->alloc("ref_name"); + ASSERT_NE(stat_1, nullptr); + Stats::RawStatData* stat_2 = hot_restart_->alloc("ref_name"); + ASSERT_NE(stat_2, nullptr); + Stats::RawStatData* stat_3 = hot_restart_->alloc("not_ref_name"); + ASSERT_NE(stat_3, nullptr); + EXPECT_EQ(stat_1, stat_2); + EXPECT_NE(stat_1, stat_3); + EXPECT_NE(stat_2, stat_3); + hot_restart_->free(*stat_1); + hot_restart_->free(*stat_2); + hot_restart_->free(*stat_3); +} + TEST_F(HotRestartImplTest, crossAlloc) { setup(); @@ -120,16 +154,6 @@ TEST_F(HotRestartImplTest, crossAlloc) { EXPECT_EQ(stat5, stat5_prime); } -TEST_F(HotRestartImplTest, truncateKey) { - setup(); - - std::string key1(Stats::RawStatData::maxNameLength(), 'a'); - Stats::RawStatData* stat1 = hot_restart_->alloc(key1); - std::string key2 = key1 + "a"; - Stats::RawStatData* stat2 = hot_restart_->alloc(key2); - EXPECT_EQ(stat1, stat2); -} - TEST_F(HotRestartImplTest, allocFail) { EXPECT_CALL(options_, maxStats()).WillRepeatedly(Return(2)); setup(); @@ -150,12 +174,15 @@ class HotRestartImplAlignmentTest : public HotRestartImplTest, public testing::WithParamInterface { public: HotRestartImplAlignmentTest() : name_len_(8 + GetParam()) { + stats_options_.max_obj_name_length_ = name_len_; + EXPECT_CALL(options_, statsOptions()).WillRepeatedly(ReturnRef(stats_options_)); EXPECT_CALL(options_, maxStats()).WillRepeatedly(Return(num_stats_)); - EXPECT_CALL(options_, maxObjNameLength()).WillRepeatedly(Return(name_len_)); + setup(); - EXPECT_EQ(name_len_, Stats::RawStatData::maxObjNameLength()); + EXPECT_EQ(name_len_ + stats_options_.maxStatSuffixLength(), stats_options_.maxNameLength()); } + Stats::StatsOptionsImpl stats_options_; static const uint64_t num_stats_ = 8; const uint64_t name_len_; }; @@ -185,7 +212,7 @@ TEST_P(HotRestartImplAlignmentTest, objectOverlap) { "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz" "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", i) - .substr(0, Stats::RawStatData::maxNameLength()); + .substr(0, stats_options_.maxNameLength()); TestStat ts; ts.stat_ = hot_restart_->alloc(name); ts.name_ = ts.stat_->name_; @@ -193,7 +220,7 @@ TEST_P(HotRestartImplAlignmentTest, objectOverlap) { // If this isn't true then the hard coded part of the name isn't long enough to make the test // valid. - EXPECT_EQ(ts.name_.size(), Stats::RawStatData::maxNameLength()); + EXPECT_EQ(ts.name_.size(), stats_options_.maxNameLength()); stats.push_back(ts); } diff --git a/test/server/http/BUILD b/test/server/http/BUILD index eabf36be9b5e1..c2e35b7ac745c 100644 --- a/test/server/http/BUILD +++ b/test/server/http/BUILD @@ -17,6 +17,8 @@ envoy_cc_test( "//source/common/http:message_lib", "//source/common/json:json_loader_lib", "//source/common/profiler:profiler_lib", + "//source/common/protobuf", + "//source/common/protobuf:utility_lib", "//source/common/stats:thread_local_store_lib", "//source/server/http:admin_lib", "//test/mocks/runtime:runtime_mocks", diff --git a/test/server/http/admin_test.cc b/test/server/http/admin_test.cc index 0845d225eaaf4..66dcf511f742a 100644 --- a/test/server/http/admin_test.cc +++ b/test/server/http/admin_test.cc @@ -7,6 +7,8 @@ #include "common/http/message_impl.h" #include "common/json/json_loader.h" #include "common/profiler/profiler.h" +#include "common/protobuf/protobuf.h" +#include "common/protobuf/utility.h" #include "common/stats/stats_impl.h" #include "common/stats/thread_local_store.h" @@ -29,26 +31,18 @@ using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Ref; +using testing::Return; +using testing::ReturnPointee; +using testing::ReturnRef; using testing::_; namespace Envoy { namespace Server { -class AdminStatsTest : public testing::TestWithParam, - public Stats::RawStatDataAllocator { +class AdminStatsTest : public testing::TestWithParam { public: -public: - AdminStatsTest() { - ON_CALL(*this, alloc(_)) - .WillByDefault(Invoke( - [this](const std::string& name) -> Stats::RawStatData* { return alloc_.alloc(name); })); - - ON_CALL(*this, free(_)).WillByDefault(Invoke([this](Stats::RawStatData& data) -> void { - return alloc_.free(data); - })); - - EXPECT_CALL(*this, alloc("stats.overflow")); - store_.reset(new Stats::ThreadLocalStoreImpl(*this)); + AdminStatsTest() : alloc_(options_) { + store_ = std::make_unique(options_, alloc_); store_->addSink(sink_); } @@ -59,12 +53,10 @@ class AdminStatsTest : public testing::TestWithParam main_thread_dispatcher_; NiceMock tls_; - Stats::TestAllocator alloc_; + Stats::StatsOptionsImpl options_; + Stats::MockedTestAllocator alloc_; Stats::MockSink sink_; std::unique_ptr store_; }; @@ -115,7 +107,7 @@ TEST_P(AdminStatsTest, StatsAsJson) { store_->mergeHistograms([]() -> void {}); - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(alloc_, free(_)); std::map all_stats; @@ -251,7 +243,7 @@ TEST_P(AdminStatsTest, UsedOnlyStatsAsJson) { store_->mergeHistograms([]() -> void {}); - EXPECT_CALL(*this, free(_)); + EXPECT_CALL(alloc_, free(_)); std::map all_stats; @@ -411,15 +403,15 @@ TEST_P(AdminInstanceTest, AdminProfiler) { #endif -TEST_P(AdminInstanceTest, MutatesWarnWithGet) { +TEST_P(AdminInstanceTest, MutatesErrorWithGet) { Buffer::OwnedImpl data; Http::HeaderMapImpl header_map; const std::string path("/healthcheck/fail"); // TODO(jmarantz): the call to getCallback should be made to fail, but as an interim we will // just issue a warning, so that scripts using curl GET comamnds to mutate state can be fixed. - EXPECT_LOG_CONTAINS("warning", + EXPECT_LOG_CONTAINS("error", "admin path \"" + path + "\" mutates state, method=GET rather than POST", - EXPECT_EQ(Http::Code::OK, getCallback(path, header_map, data))); + EXPECT_EQ(Http::Code::BadRequest, getCallback(path, header_map, data))); } TEST_P(AdminInstanceTest, AdminBadProfiler) { @@ -549,7 +541,7 @@ TEST_P(AdminInstanceTest, ConfigDump) { } )EOF"; EXPECT_EQ(Http::Code::OK, getCallback("/config_dump", header_map, response)); - std::string output = TestUtility::bufferToString(response); + std::string output = response.toString(); EXPECT_EQ(expected_json, output); } @@ -616,7 +608,7 @@ TEST_P(AdminInstanceTest, Runtime) { EXPECT_CALL(loader, snapshot()).WillRepeatedly(testing::ReturnPointee(&snapshot)); EXPECT_CALL(server_, runtime()).WillRepeatedly(testing::ReturnPointee(&loader)); EXPECT_EQ(Http::Code::OK, getCallback("/runtime", header_map, response)); - EXPECT_EQ(expected_json, TestUtility::bufferToString(response)); + EXPECT_EQ(expected_json, response.toString()); } TEST_P(AdminInstanceTest, RuntimeModify) { @@ -632,16 +624,16 @@ TEST_P(AdminInstanceTest, RuntimeModify) { overrides["nothing"] = ""; EXPECT_CALL(loader, mergeValues(overrides)).Times(1); EXPECT_EQ(Http::Code::OK, - getCallback("/runtime_modify?foo=bar&x=42¬hing=", header_map, response)); - EXPECT_EQ("OK\n", TestUtility::bufferToString(response)); + postCallback("/runtime_modify?foo=bar&x=42¬hing=", header_map, response)); + EXPECT_EQ("OK\n", response.toString()); } TEST_P(AdminInstanceTest, RuntimeModifyNoArguments) { Http::HeaderMapImpl header_map; Buffer::OwnedImpl response; - EXPECT_EQ(Http::Code::BadRequest, getCallback("/runtime_modify", header_map, response)); - EXPECT_TRUE(absl::StartsWith(TestUtility::bufferToString(response), "usage:")); + EXPECT_EQ(Http::Code::BadRequest, postCallback("/runtime_modify", header_map, response)); + EXPECT_TRUE(absl::StartsWith(response.toString(), "usage:")); } TEST_P(AdminInstanceTest, TracingStatsDisabled) { @@ -651,6 +643,110 @@ TEST_P(AdminInstanceTest, TracingStatsDisabled) { } } +TEST_P(AdminInstanceTest, ClustersJson) { + Upstream::ClusterManager::ClusterInfoMap cluster_map; + ON_CALL(server_.cluster_manager_, clusters()).WillByDefault(ReturnPointee(&cluster_map)); + + NiceMock cluster; + cluster_map.emplace(cluster.info_->name_, cluster); + + NiceMock outlier_detector; + ON_CALL(Const(cluster), outlierDetector()).WillByDefault(Return(&outlier_detector)); + ON_CALL(outlier_detector, successRateEjectionThreshold()).WillByDefault(Return(6.0)); + + ON_CALL(*cluster.info_, addedViaApi()).WillByDefault(Return(true)); + + Upstream::MockHostSet* host_set = cluster.priority_set_.getMockHostSet(0); + auto host = std::make_shared>(); + + envoy::api::v2::core::Locality locality; + locality.set_region("test_region"); + locality.set_zone("test_zone"); + locality.set_sub_zone("test_sub_zone"); + ON_CALL(*host, locality()).WillByDefault(ReturnRef(locality)); + + host_set->hosts_.emplace_back(host); + Network::Address::InstanceConstSharedPtr address = + Network::Utility::resolveUrl("tcp://1.2.3.4:80"); + ON_CALL(*host, address()).WillByDefault(Return(address)); + + Stats::IsolatedStoreImpl store; + store.counter("test_counter").add(10); + store.gauge("test_gauge").set(11); + ON_CALL(*host, gauges()).WillByDefault(Invoke([&store]() { return store.gauges(); })); + ON_CALL(*host, counters()).WillByDefault(Invoke([&store]() { return store.counters(); })); + + ON_CALL(*host, healthFlagGet(Upstream::Host::HealthFlag::FAILED_ACTIVE_HC)) + .WillByDefault(Return(true)); + ON_CALL(*host, healthFlagGet(Upstream::Host::HealthFlag::FAILED_OUTLIER_CHECK)) + .WillByDefault(Return(true)); + ON_CALL(*host, healthFlagGet(Upstream::Host::HealthFlag::FAILED_EDS_HEALTH)) + .WillByDefault(Return(false)); + + ON_CALL(host->outlier_detector_, successRate()).WillByDefault(Return(43.2)); + + Buffer::OwnedImpl response; + Http::HeaderMapImpl header_map; + EXPECT_EQ(Http::Code::OK, getCallback("/clusters?format=json", header_map, response)); + std::string output_json = response.toString(); + envoy::admin::v2alpha::Clusters output_proto; + MessageUtil::loadFromJson(output_json, output_proto); + + const std::string expected_json = R"EOF({ + "cluster_statuses": [ + { + "name": "fake_cluster", + "success_rate_ejection_threshold": { + "value": 6 + }, + "added_via_api": true, + "host_statuses": [ + { + "address": { + "socket_address": { + "protocol": "TCP", + "address": "1.2.3.4", + "port_value": 80 + } + }, + "stats": { + "test_counter": { + "value": "10", + "type": "COUNTER" + }, + "test_gauge": { + "value": "11", + "type": "GAUGE" + }, + }, + "health_status": { + "eds_health_status": "HEALTHY", + "failed_active_health_check": true, + "failed_outlier_check": true + }, + "success_rate": { + "value": 43.2 + } + } + ] + } + ] +} +)EOF"; + + envoy::admin::v2alpha::Clusters expected_proto; + MessageUtil::loadFromJson(expected_json, expected_proto); + + // Ensure the protos created from each JSON are equivalent. + EXPECT_THAT(output_proto, ProtoEq(expected_proto)); + + // Ensure that the normal text format is used by default. + EXPECT_EQ(Http::Code::OK, getCallback("/clusters", header_map, response)); + std::string text_output = response.toString(); + envoy::admin::v2alpha::Clusters failed_conversion_proto; + EXPECT_THROW(MessageUtil::loadFromJson(text_output, failed_conversion_proto), EnvoyException); +} + TEST_P(AdminInstanceTest, GetRequest) { Http::HeaderMapImpl response_headers; std::string body; @@ -685,6 +781,7 @@ TEST_P(AdminInstanceTest, PostRequest) { class PrometheusStatsFormatterTest : public testing::Test { protected: + PrometheusStatsFormatterTest() /*: alloc_(stats_options_)*/ {} void addCounter(const std::string& name, std::vector cluster_tags) { std::string tname = std::string(name); counters_.push_back(alloc_.makeCounter(name, std::move(tname), std::move(cluster_tags))); @@ -695,7 +792,8 @@ class PrometheusStatsFormatterTest : public testing::Test { gauges_.push_back(alloc_.makeGauge(name, std::move(tname), std::move(cluster_tags))); } - Stats::HeapRawStatDataAllocator alloc_; + Stats::StatsOptionsImpl stats_options_; + Stats::HeapStatDataAllocator alloc_; std::vector counters_; std::vector gauges_; }; @@ -708,13 +806,12 @@ TEST_F(PrometheusStatsFormatterTest, MetricName) { } TEST_F(PrometheusStatsFormatterTest, FormattedTags) { - // If value has - then it should be replaced by _ . std::vector tags; Stats::Tag tag1 = {"a.tag-name", "a.tag-value"}; - Stats::Tag tag2 = {"another_tag_name", "another.tag-value"}; + Stats::Tag tag2 = {"another_tag_name", "another_tag-value"}; tags.push_back(tag1); tags.push_back(tag2); - std::string expected = "a_tag_name=\"a_tag_value\",another_tag_name=\"another_tag_value\""; + std::string expected = "a_tag_name=\"a.tag-value\",another_tag_name=\"another_tag-value\""; auto actual = PrometheusStatsFormatter::formattedTags(tags); EXPECT_EQ(expected, actual); } diff --git a/test/server/listener_manager_impl_test.cc b/test/server/listener_manager_impl_test.cc index 7f2f1cf646005..949d84787d8f4 100644 --- a/test/server/listener_manager_impl_test.cc +++ b/test/server/listener_manager_impl_test.cc @@ -23,6 +23,7 @@ #include "test/test_common/threadsafe_singleton_injector.h" #include "test/test_common/utility.h" +#include "absl/strings/match.h" #include "gtest/gtest.h" using testing::InSequence; @@ -137,31 +138,49 @@ class ListenerManagerImplWithRealFiltersTest : public ListenerManagerImplTest { context); })); socket_.reset(new NiceMock()); + address_.reset(new Network::Address::Ipv4Instance("127.0.0.1", 1234)); } const Network::FilterChain* - findFilterChain(const std::string& server_name, bool expect_server_name_match, + findFilterChain(uint16_t destination_port, bool expect_destination_port_match, + const std::string& destination_address, bool expect_destination_address_match, + const std::string& server_name, bool expect_server_name_match, const std::string& transport_protocol, bool expect_transport_protocol_match, const std::vector& application_protocols) { - EXPECT_CALL(*socket_, requestedServerName()).WillOnce(Return(absl::string_view(server_name))); + const int times = expect_destination_port_match ? 2 : 1; + if (absl::StartsWith(destination_address, "/")) { + address_.reset(new Network::Address::PipeInstance(destination_address)); + } else { + address_.reset(new Network::Address::Ipv4Instance(destination_address, destination_port)); + } + EXPECT_CALL(*socket_, localAddress()).Times(times).WillRepeatedly(ReturnRef(address_)); + + if (expect_destination_address_match) { + EXPECT_CALL(*socket_, requestedServerName()).WillOnce(Return(absl::string_view(server_name))); + } else { + EXPECT_CALL(*socket_, requestedServerName()).Times(0); + } + if (expect_server_name_match) { EXPECT_CALL(*socket_, detectedTransportProtocol()) .WillOnce(Return(absl::string_view(transport_protocol))); - if (expect_transport_protocol_match) { - EXPECT_CALL(*socket_, requestedApplicationProtocols()) - .WillOnce(ReturnRef(application_protocols)); - } else { - EXPECT_CALL(*socket_, requestedApplicationProtocols()).Times(0); - } } else { EXPECT_CALL(*socket_, detectedTransportProtocol()).Times(0); + } + + if (expect_transport_protocol_match) { + EXPECT_CALL(*socket_, requestedApplicationProtocols()) + .WillOnce(ReturnRef(application_protocols)); + } else { EXPECT_CALL(*socket_, requestedApplicationProtocols()).Times(0); } + return manager_->listeners().back().get().filterChainManager().findFilterChain(*socket_); } private: std::unique_ptr socket_; + Network::Address::InstanceConstSharedPtr address_; }; class MockLdsApi : public LdsApi { @@ -232,7 +251,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SslContext) { manager_->addOrUpdateListener(parseListenerFromJson(json), "", true); EXPECT_EQ(1U, manager_->listeners().size()); - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); } @@ -387,6 +406,66 @@ TEST_F(ListenerManagerImplTest, AddListenerAddressNotMatching) { EXPECT_CALL(*listener_foo, onDestroy()); } +// Make sure that a listener creation does not fail on IPv4 ony setups when FilterChainMatch is not +// specified and we try to create default CidrRange. See convertDestinationIPsMapToTrie function for +// more details. +TEST_F(ListenerManagerImplTest, AddListenerOnIpv4OnlySetups) { + InSequence s; + + NiceMock os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + const std::string listener_foo_json = R"EOF( + { + "name": "foo", + "address": "tcp://127.0.0.1:1234", + "filters": [], + "drain_type": "default" + } + )EOF"; + + ListenerHandle* listener_foo = expectListenerCreate(false); + + EXPECT_CALL(os_sys_calls, socket(AF_INET, _, 0)).WillOnce(Return(5)); + EXPECT_CALL(os_sys_calls, socket(AF_INET6, _, 0)).WillOnce(Return(-1)); + + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + + EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromJson(listener_foo_json), "", true)); + checkStats(1, 0, 0, 0, 1, 0); + EXPECT_CALL(*listener_foo, onDestroy()); +} + +// Make sure that a listener creation does not fail on IPv6 ony setups when FilterChainMatch is not +// specified and we try to create default CidrRange. See convertDestinationIPsMapToTrie function for +// more details. +TEST_F(ListenerManagerImplTest, AddListenerOnIpv6OnlySetups) { + InSequence s; + + NiceMock os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + const std::string listener_foo_json = R"EOF( + { + "name": "foo", + "address": "tcp://[::0001]:1234", + "filters": [], + "drain_type": "default" + } + )EOF"; + + ListenerHandle* listener_foo = expectListenerCreate(false); + + EXPECT_CALL(os_sys_calls, socket(AF_INET, _, 0)).WillOnce(Return(-1)); + EXPECT_CALL(os_sys_calls, socket(AF_INET6, _, 0)).WillOnce(Return(5)); + + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + + EXPECT_TRUE(manager_->addOrUpdateListener(parseListenerFromJson(listener_foo_json), "", true)); + checkStats(1, 0, 0, 0, 1, 0); + EXPECT_CALL(*listener_foo, onDestroy()); +} + // Make sure that a listener that is not modifiable cannot be updated or removed. TEST_F(ListenerManagerImplTest, UpdateRemoveNotModifiableListener) { ON_CALL(system_time_source_, currentTime()) @@ -1033,6 +1112,90 @@ TEST_F(ListenerManagerImplTest, EarlyShutdown) { manager_->stopWorkers(); } +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationPortMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + destination_port: 8080 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to unknown port - no match. + auto filter_chain = findFilterChain(1234, false, "127.0.0.1", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); + + // IPv4 client connects to valid port - using 1st filter chain. + filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // UDS client - no match. + filter_chain = findFilterChain(0, false, "/tmp/test.sock", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); +} + +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + prefix_ranges: { address_prefix: 127.0.0.0, prefix_len: 8 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to unknown IP - no match. + auto filter_chain = findFilterChain(1234, true, "1.2.3.4", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); + + // IPv4 client connects to valid IP - using 1st filter chain. + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // UDS client - no match. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", false, "", false, "tls", false, {}); + EXPECT_EQ(filter_chain, nullptr); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1057,15 +1220,17 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithServerNamesM EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI - no match. - auto filter_chain = findFilterChain("", false, "tls", false, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client without matching SNI - no match. - filter_chain = findFilterChain("www.example.com", false, "tls", false, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "www.example.com", false, "tls", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with matching SNI - using 1st filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1099,11 +1264,12 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithTransportPro EXPECT_EQ(1U, manager_->listeners().size()); // TCP client - no match. - auto filter_chain = findFilterChain("", true, "raw_buffer", false, {}); + auto filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "raw_buffer", false, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1137,11 +1303,12 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without ALPN - no match. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with "http/1.1" ALPN - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1151,6 +1318,158 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithApplicationP EXPECT_EQ(server_names.front(), "server1.example.com"); } +TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinationPortMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + # empty + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" } + - filter_chain_match: + destination_port: 8080 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + - filter_chain_match: + destination_port: 8081 + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to default port - using 1st filter chain. + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); + + // IPv4 client connects to port 8080 - using 2nd filter chain. + filter_chain = findFilterChain(8080, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // IPv4 client connects to port 8081 - using 3nd filter chain. + filter_chain = findFilterChain(8081, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 2); + EXPECT_EQ(server_names.front(), "*.example.com"); + + // UDS client - using 1st filter chain. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); +} + +TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + # empty + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_uri_key.pem" } + - filter_chain_match: + prefix_ranges: { address_prefix: 192.168.0.1, prefix_len: 32 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } + - filter_chain_match: + prefix_ranges: { address_prefix: 192.168.0.0, prefix_len: 16 } + tls_context: + common_tls_context: + tls_certificates: + - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_cert.pem" } + private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_multiple_dns_key.pem" } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + // IPv4 client connects to default IP - using 1st filter chain. + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + auto ssl_socket = dynamic_cast(transport_socket.get()); + auto uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); + + // IPv4 client connects to exact IP match - using 2nd filter chain. + filter_chain = findFilterChain(1234, true, "192.168.0.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + auto server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 1); + EXPECT_EQ(server_names.front(), "server1.example.com"); + + // IPv4 client connects to wildcard IP match - using 3nd filter chain. + filter_chain = findFilterChain(1234, true, "192.168.1.1", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + server_names = ssl_socket->dnsSansLocalCertificate(); + EXPECT_EQ(server_names.size(), 2); + EXPECT_EQ(server_names.front(), "*.example.com"); + + // UDS client - using 1st filter chain. + filter_chain = findFilterChain(0, true, "/tmp/test.sock", true, "", true, "tls", true, {}); + ASSERT_NE(filter_chain, nullptr); + EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); + transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); + ssl_socket = dynamic_cast(transport_socket.get()); + uri = ssl_socket->uriSanLocalCertificate(); + EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1198,7 +1517,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1207,7 +1526,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(uri, "spiffe://lyft.com/test-team"); // TLS client with exact SNI match - using 2nd filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1217,7 +1537,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(server_names.front(), "server1.example.com"); // TLS client with wildcard SNI match - using 3nd filter chain. - filter_chain = findFilterChain("server2.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server2.example.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1227,7 +1548,8 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithServerNam EXPECT_EQ(server_names.front(), "*.example.com"); // TLS client with wildcard SNI match - using 3nd filter chain. - filter_chain = findFilterChain("www.wildcard.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "www.wildcard.com", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1263,12 +1585,13 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithTransport EXPECT_EQ(1U, manager_->listeners().size()); // TCP client - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "raw_buffer", true, {}); + auto filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "raw_buffer", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client - using 2nd filter chain. - filter_chain = findFilterChain("", true, "tls", true, {}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1304,12 +1627,13 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithApplicati EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without ALPN - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with "h2,http/1.1" ALPN - using 2nd filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1347,21 +1671,24 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithMultipleR EXPECT_EQ(1U, manager_->listeners().size()); // TLS client without SNI and ALPN - using 1st filter chain. - auto filter_chain = findFilterChain("", true, "tls", true, {}); + auto filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with exact SNI match but without ALPN - no match (SNI blackholed by configuration). - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", true, {}); EXPECT_EQ(filter_chain, nullptr); // TLS client with ALPN match but without SNI - using 1st filter chain. - filter_chain = findFilterChain("", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = + findFilterChain(1234, true, "127.0.0.1", true, "", true, "tls", true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_FALSE(filter_chain->transportSocketFactory().implementsSecureTransport()); // TLS client with exact SNI match and ALPN match - using 2nd filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {"h2", "http/1.1"}); + filter_chain = findFilterChain(1234, true, "127.0.0.1", true, "server1.example.com", true, "tls", + true, {"h2", "http/1.1"}); ASSERT_NE(filter_chain, nullptr); EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); @@ -1443,6 +1770,23 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, EXPECT_EQ(1U, manager_->listeners().size()); } +TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithInvalidDestinationIPMatch) { + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: 127.0.0.1, port_value: 1234 } + listener_filters: + - name: "envoy.listener.tls_inspector" + config: {} + filter_chains: + - filter_chain_match: + prefix_ranges: { address_prefix: a.b.c.d, prefix_len: 32 } + )EOF", + Network::Address::IpVersion::v4); + + EXPECT_THROW_WITH_MESSAGE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true), + EnvoyException, "malformed IP address: a.b.c.d"); +} + TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithInvalidServerNamesMatch) { const std::string yaml = TestEnvironment::substitute(R"EOF( address: @@ -1459,7 +1803,7 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithInvalidServe EXPECT_THROW_WITH_MESSAGE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true), EnvoyException, "error adding listener '127.0.0.1:1234': partial wildcards are not " - "supported in \"server_names\" (or the deprecated \"sni_domains\")"); + "supported in \"server_names\""); } TEST_F(ListenerManagerImplWithRealFiltersTest, MultipleFilterChainsWithSameMatch) { @@ -1590,67 +1934,6 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, CustomTransportProtocolWithSniWit EXPECT_TRUE(filterChainFactory.createListenerFilterChain(manager)); } -// Copy of the SingleFilterChainWithServerNamesMatch to make sure it behaves the same. -TEST_F(ListenerManagerImplWithRealFiltersTest, SingleFilterChainWithDeprecatedSniDomainsMatch) { - const std::string yaml = TestEnvironment::substitute(R"EOF( - address: - socket_address: { address: 127.0.0.1, port_value: 1234 } - listener_filters: - - name: "envoy.listener.tls_inspector" - config: {} - filter_chains: - - filter_chain_match: - sni_domains: "server1.example.com" - tls_context: - common_tls_context: - tls_certificates: - - certificate_chain: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_cert.pem" } - private_key: { filename: "{{ test_rundir }}/test/common/ssl/test_data/san_dns_key.pem" } - )EOF", - Network::Address::IpVersion::v4); - - EXPECT_CALL(server_.random_, uuid()); - EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); - manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); - EXPECT_EQ(1U, manager_->listeners().size()); - - // TLS client without SNI - no match. - auto filter_chain = findFilterChain("", false, "tls", false, {}); - EXPECT_EQ(filter_chain, nullptr); - - // TLS client without matching SNI - no match. - filter_chain = findFilterChain("www.example.com", false, "tls", false, {}); - EXPECT_EQ(filter_chain, nullptr); - - // TLS client with matching SNI - using 1st filter chain. - filter_chain = findFilterChain("server1.example.com", true, "tls", true, {}); - ASSERT_NE(filter_chain, nullptr); - EXPECT_TRUE(filter_chain->transportSocketFactory().implementsSecureTransport()); - auto transport_socket = filter_chain->transportSocketFactory().createTransportSocket(); - auto ssl_socket = dynamic_cast(transport_socket.get()); - auto server_names = ssl_socket->dnsSansLocalCertificate(); - EXPECT_EQ(server_names.size(), 1); - EXPECT_EQ(server_names.front(), "server1.example.com"); -} - -TEST_F(ListenerManagerImplWithRealFiltersTest, DeprecatedSniDomainsAndServerNamesUsedTogether) { - const std::string yaml = TestEnvironment::substitute(R"EOF( - address: - socket_address: { address: 127.0.0.1, port_value: 1234 } - filter_chains: - - filter_chain_match: - server_names: "example.com" - sni_domains: "www.example.com" - )EOF", - Network::Address::IpVersion::v4); - - EXPECT_THROW_WITH_MESSAGE( - manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true), EnvoyException, - "error adding listener '127.0.0.1:1234': both \"server_names\" and the deprecated " - "\"sni_domains\" are used, please merge the list of expected server names into " - "\"server_names\" and remove \"sni_domains\""); -} - TEST_F(ListenerManagerImplWithRealFiltersTest, TlsCertificateInline) { const std::string yaml = R"EOF( address: @@ -1983,6 +2266,143 @@ TEST_F(ListenerManagerImplWithRealFiltersTest, OriginalDstTestFilterOptionFail) EXPECT_EQ(0U, manager_->listeners().size()); } +class OriginalDstTestFilterIPv6 + : public Extensions::ListenerFilters::OriginalDst::OriginalDstFilter { + Network::Address::InstanceConstSharedPtr getOriginalDst(int) override { + return Network::Address::InstanceConstSharedPtr{ + new Network::Address::Ipv6Instance("1::2", 2345)}; + } +}; + +TEST_F(ListenerManagerImplWithRealFiltersTest, OriginalDstTestFilterIPv6) { + static int fd; + fd = -1; + EXPECT_CALL(*listener_factory_.socket_, fd()).WillOnce(Return(0)); + + class OriginalDstTestConfigFactory : public Configuration::NamedListenerFilterConfigFactory { + public: + // NamedListenerFilterConfigFactory + Network::ListenerFilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message&, + Configuration::ListenerFactoryContext& context) override { + auto option = std::make_unique(); + EXPECT_CALL(*option, setOption(_, envoy::api::v2::core::SocketOption::STATE_PREBIND)) + .WillOnce(Return(true)); + EXPECT_CALL(*option, setOption(_, envoy::api::v2::core::SocketOption::STATE_BOUND)) + .WillOnce(Invoke( + [](Network::Socket& socket, envoy::api::v2::core::SocketOption::SocketState) -> bool { + fd = socket.fd(); + return true; + })); + context.addListenSocketOption(std::move(option)); + return [](Network::ListenerFilterManager& filter_manager) -> void { + filter_manager.addAcceptFilter(std::make_unique()); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return "test.listener.original_dstipv6"; } + }; + + /** + * Static registration for the original dst filter. @see RegisterFactory. + */ + static Registry::RegisterFactory + registered_; + + const std::string yaml = TestEnvironment::substitute(R"EOF( + address: + socket_address: { address: ::0001, port_value: 1111 } + filter_chains: {} + listener_filters: + - name: "test.listener.original_dstipv6" + config: {} + )EOF", + Network::Address::IpVersion::v6); + + EXPECT_CALL(server_.random_, uuid()); + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true); + EXPECT_EQ(1U, manager_->listeners().size()); + + Network::ListenerConfig& listener = manager_->listeners().back().get(); + + Network::FilterChainFactory& filterChainFactory = listener.filterChainFactory(); + Network::MockListenerFilterManager manager; + + NiceMock callbacks; + Network::AcceptedSocketImpl socket( + -1, std::make_unique("::0001", 1234), + std::make_unique("::0001", 5678)); + + EXPECT_CALL(callbacks, socket()).WillOnce(Invoke([&]() -> Network::ConnectionSocket& { + return socket; + })); + + EXPECT_CALL(manager, addAcceptFilter_(_)) + .WillOnce(Invoke([&](Network::ListenerFilterPtr& filter) -> void { + EXPECT_EQ(Network::FilterStatus::Continue, filter->onAccept(callbacks)); + })); + + EXPECT_TRUE(filterChainFactory.createListenerFilterChain(manager)); + EXPECT_TRUE(socket.localAddressRestored()); + EXPECT_EQ("[1::2]:2345", socket.localAddress()->asString()); + EXPECT_NE(fd, -1); +} + +TEST_F(ListenerManagerImplWithRealFiltersTest, OriginalDstTestFilterOptionFailIPv6) { + class OriginalDstTestConfigFactory : public Configuration::NamedListenerFilterConfigFactory { + public: + // NamedListenerFilterConfigFactory + Network::ListenerFilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message&, + Configuration::ListenerFactoryContext& context) override { + auto option = std::make_unique(); + EXPECT_CALL(*option, setOption(_, envoy::api::v2::core::SocketOption::STATE_PREBIND)) + .WillOnce(Return(false)); + context.addListenSocketOption(std::move(option)); + return [](Network::ListenerFilterManager& filter_manager) -> void { + filter_manager.addAcceptFilter(std::make_unique()); + }; + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return "testfail.listener.original_dstipv6"; } + }; + + /** + * Static registration for the original dst filter. @see RegisterFactory. + */ + static Registry::RegisterFactory + registered_; + + const std::string yaml = TestEnvironment::substitute(R"EOF( + name: "socketOptionFailListener" + address: + socket_address: { address: ::0001, port_value: 1111 } + filter_chains: {} + listener_filters: + - name: "testfail.listener.original_dstipv6" + config: {} + )EOF", + Network::Address::IpVersion::v6); + + EXPECT_CALL(listener_factory_, createListenSocket(_, _, true)); + + EXPECT_THROW_WITH_MESSAGE(manager_->addOrUpdateListener(parseListenerFromV2Yaml(yaml), "", true), + EnvoyException, + "MockListenerComponentFactory: Setting socket options failed"); + EXPECT_EQ(0U, manager_->listeners().size()); +} + // Validate that when neither transparent nor freebind is not set in the // Listener, we see no socket option set. TEST_F(ListenerManagerImplWithRealFiltersTest, TransparentFreebindListenerDisabled) { diff --git a/test/server/options_impl_test.cc b/test/server/options_impl_test.cc index e76fc76e22a7c..d5d7a6a5f4f05 100644 --- a/test/server/options_impl_test.cc +++ b/test/server/options_impl_test.cc @@ -37,16 +37,6 @@ std::unique_ptr createOptionsImpl(const std::string& args) { } // namespace TEST(OptionsImplTest, HotRestartVersion) { - // There's an evil static local in - // Stats::RawStatsData::initializeAndGetMutableMaxObjNameLength, which causes - // problems when all test.cc files are linked together for coverage-testing. - // This resets the static to the default options-value of 60. Note; this is only - // needed in coverage tests. - { - auto options = createOptionsImpl("envoy"); - Stats::RawStatData::configureForTestsOnly(*options); - } - EXPECT_THROW_WITH_REGEX(createOptionsImpl("envoy --hot-restart-version"), NoServingException, "NoServingException"); } @@ -60,6 +50,27 @@ TEST(OptionsImplTest, InvalidCommandLine) { "Couldn't find match for argument"); } +TEST(OptionsImplTest, v1Allowed) { + std::unique_ptr options = createOptionsImpl( + "envoy --mode validate --concurrency 2 -c hello --admin-address-path path --restart-epoch 1 " + "--local-address-ip-version v6 -l info --service-cluster cluster --service-node node " + "--service-zone zone --file-flush-interval-msec 9000 --drain-time-s 60 --log-format [%v] " + "--parent-shutdown-time-s 90 --log-path /foo/bar --allow-deprecated-v1-api " + "--disable-hot-restart"); + EXPECT_EQ(Server::Mode::Validate, options->mode()); + EXPECT_FALSE(options->v2ConfigOnly()); +} + +TEST(OptionsImplTest, v1Disallowed) { + std::unique_ptr options = createOptionsImpl( + "envoy --mode validate --concurrency 2 -c hello --admin-address-path path --restart-epoch 1 " + "--local-address-ip-version v6 -l info --service-cluster cluster --service-node node " + "--service-zone zone --file-flush-interval-msec 9000 --drain-time-s 60 --log-format [%v] " + "--parent-shutdown-time-s 90 --log-path /foo/bar --disable-hot-restart"); + EXPECT_EQ(Server::Mode::Validate, options->mode()); + EXPECT_TRUE(options->v2ConfigOnly()); +} + TEST(OptionsImplTest, All) { std::unique_ptr options = createOptionsImpl( "envoy --mode validate --concurrency 2 -c hello --admin-address-path path --restart-epoch 1 " @@ -91,7 +102,11 @@ TEST(OptionsImplTest, All) { TEST(OptionsImplTest, SetAll) { std::unique_ptr options = createOptionsImpl("envoy -c hello"); bool v2_config_only = options->v2ConfigOnly(); - bool hot_restart_disabled = options->v2ConfigOnly(); + bool hot_restart_disabled = options->hotRestartDisabled(); + Stats::StatsOptionsImpl stats_options; + stats_options.max_obj_name_length_ = 54321; + stats_options.max_stat_suffix_length_ = 1234; + options->setBaseId(109876); options->setConcurrency(42); options->setConfigPath("foo"); @@ -111,7 +126,7 @@ TEST(OptionsImplTest, SetAll) { options->setServiceNodeName("node_foo"); options->setServiceZone("zone_foo"); options->setMaxStats(12345); - options->setMaxObjNameLength(54321); + options->setStatsOptions(stats_options); options->setHotRestartDisabled(!options->hotRestartDisabled()); EXPECT_EQ(109876, options->baseId()); @@ -133,7 +148,8 @@ TEST(OptionsImplTest, SetAll) { EXPECT_EQ("node_foo", options->serviceNodeName()); EXPECT_EQ("zone_foo", options->serviceZone()); EXPECT_EQ(12345U, options->maxStats()); - EXPECT_EQ(54321U, options->maxObjNameLength()); + EXPECT_EQ(stats_options.max_obj_name_length_, options->statsOptions().maxObjNameLength()); + EXPECT_EQ(stats_options.max_stat_suffix_length_, options->statsOptions().maxStatSuffixLength()); EXPECT_EQ(!hot_restart_disabled, options->hotRestartDisabled()); } diff --git a/test/server/overload_manager_impl_test.cc b/test/server/overload_manager_impl_test.cc new file mode 100644 index 0000000000000..352a9bf6ec8b1 --- /dev/null +++ b/test/server/overload_manager_impl_test.cc @@ -0,0 +1,247 @@ +#include "envoy/server/resource_monitor.h" +#include "envoy/server/resource_monitor_config.h" + +#include "server/overload_manager_impl.h" + +#include "extensions/resource_monitors/common/factory_base.h" + +#include "test/mocks/event/mocks.h" +#include "test/test_common/registry.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::NiceMock; +using testing::_; + +namespace Envoy { +namespace Server { +namespace { + +class FakeResourceMonitor : public ResourceMonitor { +public: + FakeResourceMonitor(Event::Dispatcher& dispatcher) + : success_(true), pressure_(0), error_("fake error"), dispatcher_(dispatcher) {} + + void setPressure(double pressure) { + success_ = true; + pressure_ = pressure; + } + + void setError() { success_ = false; } + + void updateResourceUsage(ResourceMonitor::Callbacks& callbacks) override { + if (success_) { + Server::ResourceUsage usage; + usage.resource_pressure_ = pressure_; + dispatcher_.post([&, usage]() { callbacks.onSuccess(usage); }); + } else { + EnvoyException& error = error_; + dispatcher_.post([&, error]() { callbacks.onFailure(error); }); + } + } + +private: + bool success_; + double pressure_; + EnvoyException error_; + Event::Dispatcher& dispatcher_; +}; + +class FakeResourceMonitorFactory : public Extensions::ResourceMonitors::Common::FactoryBase< + envoy::config::overload::v2alpha::EmptyConfig> { +public: + FakeResourceMonitorFactory(const std::string& name) : FactoryBase(name), monitor_(nullptr) {} + + ResourceMonitorPtr createResourceMonitorFromProtoTyped( + const envoy::config::overload::v2alpha::EmptyConfig&, + Server::Configuration::ResourceMonitorFactoryContext& context) override { + auto monitor = std::make_unique(context.dispatcher()); + monitor_ = monitor.get(); + return std::move(monitor); + } + + FakeResourceMonitor* monitor_; // not owned +}; + +class OverloadManagerImplTest : public testing::Test { +protected: + OverloadManagerImplTest() + : factory1_("envoy.resource_monitors.fake_resource1"), + factory2_("envoy.resource_monitors.fake_resource2"), register_factory1_(factory1_), + register_factory2_(factory2_) {} + + void setDispatcherExpectation() { + EXPECT_CALL(dispatcher_, createTimer_(_)).WillOnce(Invoke([&](Event::TimerCb cb) { + timer_cb_ = cb; + return new NiceMock(); + })); + } + + envoy::config::overload::v2alpha::OverloadManager parseConfig(const std::string& config) { + envoy::config::overload::v2alpha::OverloadManager proto; + bool success = Protobuf::TextFormat::ParseFromString(config, &proto); + ASSERT(success); + return proto; + } + + FakeResourceMonitorFactory factory1_; + FakeResourceMonitorFactory factory2_; + Registry::InjectFactory register_factory1_; + Registry::InjectFactory register_factory2_; + NiceMock dispatcher_; + Event::TimerCb timer_cb_; +}; + +TEST_F(OverloadManagerImplTest, CallbackOnlyFiresWhenStateChanges) { + setDispatcherExpectation(); + + const std::string config = R"EOF( + refresh_interval { + seconds: 1 + } + resource_monitors { + name: "envoy.resource_monitors.fake_resource1" + } + resource_monitors { + name: "envoy.resource_monitors.fake_resource2" + } + actions { + name: "envoy.overload_actions.dummy_action" + triggers { + name: "envoy.resource_monitors.fake_resource1" + threshold { + value: 0.9 + } + } + triggers { + name: "envoy.resource_monitors.fake_resource2" + threshold { + value: 0.8 + } + } + } + )EOF"; + + OverloadManagerImpl manager(dispatcher_, parseConfig(config)); + bool is_active = false; + int cb_count = 0; + manager.registerForAction("envoy.overload_actions.dummy_action", dispatcher_, + [&](OverloadActionState state) { + is_active = state == OverloadActionState::Active; + cb_count++; + }); + manager.registerForAction("envoy.overload_actions.unknown_action", dispatcher_, + [&](OverloadActionState) { ASSERT(false); }); + manager.start(); + + factory1_.monitor_->setPressure(0.5); + timer_cb_(); + EXPECT_FALSE(is_active); + EXPECT_EQ(0, cb_count); + + factory1_.monitor_->setPressure(0.95); + timer_cb_(); + EXPECT_TRUE(is_active); + EXPECT_EQ(1, cb_count); + + // Callback should not be invoked if action active state has not changed + factory1_.monitor_->setPressure(0.94); + timer_cb_(); + EXPECT_TRUE(is_active); + EXPECT_EQ(1, cb_count); + + // Different triggers firing but overall action remains active so no callback expected + factory1_.monitor_->setPressure(0.5); + factory2_.monitor_->setPressure(0.9); + timer_cb_(); + EXPECT_TRUE(is_active); + EXPECT_EQ(1, cb_count); + + factory2_.monitor_->setPressure(0.4); + timer_cb_(); + EXPECT_FALSE(is_active); + EXPECT_EQ(2, cb_count); + + factory1_.monitor_->setPressure(0.95); + factory1_.monitor_->setError(); + timer_cb_(); + EXPECT_FALSE(is_active); + EXPECT_EQ(2, cb_count); +} + +TEST_F(OverloadManagerImplTest, DuplicateResourceMonitor) { + const std::string config = R"EOF( + resource_monitors { + name: "envoy.resource_monitors.fake_resource1" + } + resource_monitors { + name: "envoy.resource_monitors.fake_resource1" + } + )EOF"; + + EXPECT_THROW_WITH_REGEX(OverloadManagerImpl(dispatcher_, parseConfig(config)), EnvoyException, + "Duplicate resource monitor .*"); +} + +TEST_F(OverloadManagerImplTest, DuplicateOverloadAction) { + const std::string config = R"EOF( + actions { + name: "envoy.overload_actions.dummy_action" + } + actions { + name: "envoy.overload_actions.dummy_action" + } + )EOF"; + + EXPECT_THROW_WITH_REGEX(OverloadManagerImpl(dispatcher_, parseConfig(config)), EnvoyException, + "Duplicate overload action .*"); +} + +TEST_F(OverloadManagerImplTest, UnknownTrigger) { + const std::string config = R"EOF( + actions { + name: "envoy.overload_actions.dummy_action" + triggers { + name: "envoy.resource_monitors.fake_resource1" + threshold { + value: 0.9 + } + } + } + )EOF"; + + EXPECT_THROW_WITH_REGEX(OverloadManagerImpl(dispatcher_, parseConfig(config)), EnvoyException, + "Unknown trigger resource .*"); +} + +TEST_F(OverloadManagerImplTest, DuplicateTrigger) { + const std::string config = R"EOF( + resource_monitors { + name: "envoy.resource_monitors.fake_resource1" + } + actions { + name: "envoy.overload_actions.dummy_action" + triggers { + name: "envoy.resource_monitors.fake_resource1" + threshold { + value: 0.9 + } + } + triggers { + name: "envoy.resource_monitors.fake_resource1" + threshold { + value: 0.8 + } + } + } + )EOF"; + + EXPECT_THROW_WITH_REGEX(OverloadManagerImpl(dispatcher_, parseConfig(config)), EnvoyException, + "Duplicate trigger .*"); +} +} // namespace +} // namespace Server +} // namespace Envoy diff --git a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5366294281977856 b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5366294281977856 new file mode 100644 index 0000000000000..675806abca08d --- /dev/null +++ b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5366294281977856 @@ -0,0 +1,126 @@ +static_resources { + clusters { + name: "9" + connect_timeout { + nanos: 1 + } + hosts { + pipe { + path: "N" + } + } + hosts { + pipe { + path: "n" + } + } + hosts { + pipe { + path: "=" + } + } + hosts { + pipe { + path: "s" + } + } + hosts { + pipe { + path: "n" + } + } + hosts { + pipe { + path: "N" + } + } + hosts { + pipe { + path: "W" + } + } + hosts { + pipe { + path: "=" + } + } + hosts { + pipe { + path: "N" + } + } + hosts { + pipe { + path: "n" + } + } + health_checks { + timeout { + nanos: 1 + } + interval { + nanos: 1 + } + unhealthy_threshold { + } + healthy_threshold { + value: 1701650432 + } + http_health_check { + path: "~" + request_headers_to_add { + } + request_headers_to_add { + header { + value: "W" + } + } + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + use_http2: true + } + } + health_checks { + timeout { + nanos: 1 + } + interval { + nanos: 1 + } + unhealthy_threshold { + } + healthy_threshold { + } + http_health_check { + path: "E" + request_headers_to_add { + } + request_headers_to_add { + } + request_headers_to_add { + } + } + } + } +} +admin { + access_log_path: "@\'" + address { + socket_address { + address: "::" + port_value: 0 + } + } +} diff --git a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5988544525893632 b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5988544525893632 new file mode 100644 index 0000000000000..5c8b2ec2c49e0 --- /dev/null +++ b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-5988544525893632 @@ -0,0 +1,29 @@ +static_resources { + clusters { + name: "-2353373969551157135775236" + connect_timeout { + seconds: 12884901890 + } + hosts { + pipe { + path: "@" + } + } + outlier_detection { + } + common_lb_config { + healthy_panic_threshold { + value: nan + } + } + } +} +admin { + access_log_path: "@r" + address { + pipe { + path: "W" + } + } +} + diff --git a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6036175623028736 b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6036175623028736 new file mode 100644 index 0000000000000..9161317f6aff9 --- /dev/null +++ b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6036175623028736 @@ -0,0 +1,26 @@ +dynamic_resources { + ads_config { + api_type: GRPC + grpc_services { + google_grpc { + target_uri: "\177\177" + stat_prefix: "\177\001D\177" + } + timeout { + seconds: 2048 + } + initial_metadata { + value: "\177\177\177\177" + } + } + } +} +flags_path: "\'" +admin { + access_log_path: "@" + address { + pipe { + path: "^" + } + } +} diff --git a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6419204524736512 b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6419204524736512 new file mode 100644 index 0000000000000..24707716b0f17 --- /dev/null +++ b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6419204524736512 @@ -0,0 +1,22 @@ +static_resources { + clusters { + name: "`" + connect_timeout { + nanos: 20 + } + load_assignment { + cluster_name: "`" + endpoints { + priority: 1030831324 + } + } + } +} +admin { + access_log_path: "@@" + address { + pipe { + path: "`" + } + } +} diff --git a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6610050496856064 b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6610050496856064 index 64ceff65aa686..3314510bb3fde 100644 --- a/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6610050496856064 +++ b/test/server/server_corpus/clusterfuzz-testcase-server_fuzz_test-6610050496856064 @@ -28,7 +28,7 @@ static_resources { } filter_chains { filter_chain_match { - sni_domains: "6e702f1f66d415068aabbc60377ad67a326b6b2b" + server_names: "6e702f1f66d415068aabbc60377ad67a326b6b2b" } } filter_chains { diff --git a/test/server/server_fuzz_test.cc b/test/server/server_fuzz_test.cc index 48906b07bddc8..6b8d057ce682a 100644 --- a/test/server/server_fuzz_test.cc +++ b/test/server/server_fuzz_test.cc @@ -3,6 +3,7 @@ #include "common/network/address_impl.h" #include "common/thread_local/thread_local_impl.h" +#include "server/proto_descriptors.h" #include "server/server.h" #include "server/test_hooks.h" @@ -24,6 +25,8 @@ DEFINE_PROTO_FUZZER(const envoy::config::bootstrap::v2::Bootstrap& input) { TestComponentFactory component_factory; ThreadLocal::InstanceImpl thread_local_instance; + RELEASE_ASSERT(Envoy::Server::validateProtoDescriptors(), ""); + { const std::string bootstrap_path = TestEnvironment::temporaryPath("bootstrap.pb_text"); std::ofstream bootstrap_file(bootstrap_path); diff --git a/test/server/server_test.cc b/test/server/server_test.cc index 72710b45aa25c..ec0c56d040dbe 100644 --- a/test/server/server_test.cc +++ b/test/server/server_test.cc @@ -117,6 +117,22 @@ class ServerInstanceImplTest : public testing::TestWithParamapi().fileExists("/dev/null")); } + void initializeWithHealthCheckParams(const std::string& bootstrap_path, const double timeout, + const double interval) { + options_.config_path_ = TestEnvironment::temporaryFileSubstitute( + bootstrap_path, + {{"health_check_timeout", fmt::format("{}", timeout).c_str()}, + {"health_check_interval", fmt::format("{}", interval).c_str()}}, + TestEnvironment::PortMap{}, version_); + server_.reset(new InstanceImpl( + options_, + Network::Address::InstanceConstSharedPtr(new Network::Address::Ipv4Instance("127.0.0.1")), + hooks_, restart_, stats_store_, fakelock_, component_factory_, + std::make_unique>(), thread_local_)); + + EXPECT_TRUE(server_->api().fileExists("/dev/null")); + } + Network::Address::IpVersion version_; testing::NiceMock options_; DefaultTestHooks hooks_; @@ -189,6 +205,41 @@ TEST_P(ServerInstanceImplTest, BootstrapClusterManagerInitializationFail) { "cluster manager: duplicate cluster 'service_google'"); } +// Test for protoc-gen-validate constraint on invalid timeout entry of a health check config entry. +TEST_P(ServerInstanceImplTest, BootstrapClusterHealthCheckInvalidTimeout) { + options_.v2_config_only_ = true; + EXPECT_THROW_WITH_REGEX( + initializeWithHealthCheckParams("test/server/cluster_health_check_bootstrap.yaml", 0, 0.25), + EnvoyException, + "HealthCheckValidationError.Timeout: \\[\"value must be greater than \" \"0s\"\\]"); +} + +// Test for protoc-gen-validate constraint on invalid interval entry of a health check config entry. +TEST_P(ServerInstanceImplTest, BootstrapClusterHealthCheckInvalidInterval) { + options_.v2_config_only_ = true; + EXPECT_THROW_WITH_REGEX( + initializeWithHealthCheckParams("test/server/cluster_health_check_bootstrap.yaml", 0.5, 0), + EnvoyException, + "HealthCheckValidationError.Interval: \\[\"value must be greater than \" \"0s\"\\]"); +} + +// Test for protoc-gen-validate constraint on invalid timeout and interval entry of a health check +// config entry. +TEST_P(ServerInstanceImplTest, BootstrapClusterHealthCheckInvalidTimeoutAndInterval) { + options_.v2_config_only_ = true; + EXPECT_THROW_WITH_REGEX( + initializeWithHealthCheckParams("test/server/cluster_health_check_bootstrap.yaml", 0, 0), + EnvoyException, + "HealthCheckValidationError.Timeout: \\[\"value must be greater than \" \"0s\"\\]"); +} + +// Test for protoc-gen-validate constraint on valid interval entry of a health check config entry. +TEST_P(ServerInstanceImplTest, BootstrapClusterHealthCheckValidTimeoutAndInterval) { + options_.v2_config_only_ = true; + EXPECT_NO_THROW(initializeWithHealthCheckParams("test/server/cluster_health_check_bootstrap.yaml", + 0.25, 0.5)); +} + // Negative test for protoc-gen-validate constraints. TEST_P(ServerInstanceImplTest, ValidateFail) { options_.service_cluster_name_ = "some_cluster_name"; diff --git a/test/server/utility.h b/test/server/utility.h index 9f9376e62cb1c..76838f2c79c21 100644 --- a/test/server/utility.h +++ b/test/server/utility.h @@ -10,7 +10,8 @@ namespace { inline envoy::api::v2::Listener parseListenerFromJson(const std::string& json_string) { envoy::api::v2::Listener listener; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Config::LdsJson::translateListener(*json_object_ptr, listener); + Stats::StatsOptionsImpl stats_options; + Config::LdsJson::translateListener(*json_object_ptr, listener, stats_options); return listener; } diff --git a/test/test_common/environment.cc b/test/test_common/environment.cc index 5082efd0580c7..a761afcf38492 100644 --- a/test/test_common/environment.cc +++ b/test/test_common/environment.cc @@ -33,7 +33,7 @@ namespace { void createParentPath(const std::string& path) { #ifdef __APPLE__ // No support in Clang OS X libc++ today for std::filesystem. - RELEASE_ASSERT(::system(("mkdir -p $(dirname " + path + ")").c_str()) == 0); + RELEASE_ASSERT(::system(("mkdir -p $(dirname " + path + ")").c_str()) == 0, ""); #else // We don't want to rely on mkdir etc. if we can avoid it, since it might not // exist in some environments such as ClusterFuzz. @@ -51,7 +51,7 @@ std::string getOrCreateUnixDomainSocketDirectory() { // for the sun_path limit on sockaddr_un, since TEST_TMPDIR as generated by // Bazel may be too long. char test_udsdir[] = "/tmp/envoy_test_uds.XXXXXX"; - RELEASE_ASSERT(::mkdtemp(test_udsdir) != nullptr); + RELEASE_ASSERT(::mkdtemp(test_udsdir) != nullptr, ""); return std::string(test_udsdir); } @@ -81,7 +81,7 @@ absl::optional TestEnvironment::getOptionalEnvVar(const std::string std::string TestEnvironment::getCheckedEnvVar(const std::string& var) { auto optional = getOptionalEnvVar(var); - RELEASE_ASSERT(optional.has_value()); + RELEASE_ASSERT(optional.has_value(), ""); return optional.value(); } @@ -182,7 +182,7 @@ std::string TestEnvironment::readFileToStringForTest(const std::string& filename std::ifstream file(filename); if (file.fail()) { std::cerr << "failed to open: " << filename << std::endl; - RELEASE_ASSERT(false); + RELEASE_ASSERT(false, ""); } std::stringstream file_string_stream; @@ -240,7 +240,7 @@ void TestEnvironment::exec(const std::vector& args) { } if (::system(cmd.str().c_str()) != 0) { std::cerr << "Failed " << cmd.str() << "\n"; - RELEASE_ASSERT(false); + RELEASE_ASSERT(false, ""); } } @@ -251,7 +251,7 @@ std::string TestEnvironment::writeStringToFileForTest(const std::string& filenam unlink(out_path.c_str()); { std::ofstream out_file(out_path); - RELEASE_ASSERT(!out_file.fail()); + RELEASE_ASSERT(!out_file.fail(), ""); out_file << contents; } return out_path; diff --git a/test/test_common/environment.h b/test/test_common/environment.h index 2e680e98cb5f7..993b20cfe7947 100644 --- a/test/test_common/environment.h +++ b/test/test_common/environment.h @@ -101,10 +101,12 @@ class TestEnvironment { /** * Prefix a given path with the Unix Domain Socket temporary directory. * @param path path suffix. + * @param abstract_namespace true if an abstract namespace should be returned. * @return std::string path qualified with the Unix Domain Socket temporary directory. */ - static std::string unixDomainSocketPath(const std::string& path) { - return unixDomainSocketDirectory() + "/" + path; + static std::string unixDomainSocketPath(const std::string& path, + bool abstract_namespace = false) { + return (abstract_namespace ? "@" : "") + unixDomainSocketDirectory() + "/" + path; } /** diff --git a/test/test_common/network_utility.cc b/test/test_common/network_utility.cc index 5dad1523c407e..76e190170ddb1 100644 --- a/test/test_common/network_utility.cc +++ b/test/test_common/network_utility.cc @@ -27,29 +27,24 @@ Address::InstanceConstSharedPtr findOrCheckFreePort(Address::InstanceConstShared return nullptr; } const int fd = addr_port->socket(type); - if (fd < 0) { - const int err = errno; - ADD_FAILURE() << "socket failed for '" << addr_port->asString() - << "' with error: " << strerror(err) << " (" << err << ")"; - return nullptr; - } ScopedFdCloser closer(fd); // Not setting REUSEADDR, therefore if the address has been recently used we won't reuse it here. // However, because we're going to use the address while checking if it is available, we'll need // to set REUSEADDR on listener sockets created by tests using an address validated by this means. - int rc = addr_port->bind(fd); + Api::SysCallResult result = addr_port->bind(fd); + int err; const char* failing_fn = nullptr; - if (rc != 0) { + if (result.rc_ != 0) { + err = result.errno_; failing_fn = "bind"; } else if (type == Address::SocketType::Stream) { // Try listening on the port also, if the type is TCP. - rc = ::listen(fd, 1); - if (rc != 0) { + if (::listen(fd, 1) != 0) { + err = errno; failing_fn = "listen"; } } if (failing_fn != nullptr) { - const int err = errno; if (err == EADDRINUSE) { // The port is already in use. Perfectly normal. return nullptr; @@ -148,39 +143,28 @@ Address::InstanceConstSharedPtr getAnyAddress(const Address::IpVersion version, bool supportsIpVersion(const Address::IpVersion version) { Address::InstanceConstSharedPtr addr = getCanonicalLoopbackAddress(version); const int fd = addr->socket(Address::SocketType::Stream); - if (fd < 0) { - // Socket creation failed. - return false; - } - if (0 != addr->bind(fd)) { + if (0 != addr->bind(fd).rc_) { // Socket bind failed. - RELEASE_ASSERT(::close(fd) == 0); + RELEASE_ASSERT(::close(fd) == 0, ""); return false; } - RELEASE_ASSERT(::close(fd) == 0); + RELEASE_ASSERT(::close(fd) == 0, ""); return true; } std::pair bindFreeLoopbackPort(Address::IpVersion version, Address::SocketType type) { Address::InstanceConstSharedPtr addr = getCanonicalLoopbackAddress(version); - const char* failing_fn = nullptr; const int fd = addr->socket(type); - if (fd < 0) { - failing_fn = "socket"; - } else if (0 != addr->bind(fd)) { - failing_fn = "bind"; - } else { - return std::make_pair(Address::addressFromFd(fd), fd); - } - const int err = errno; - if (fd >= 0) { + Api::SysCallResult result = addr->bind(fd); + if (0 != result.rc_) { close(fd); + std::string msg = fmt::format("bind failed for address {} with error: {} ({})", + addr->asString(), strerror(result.errno_), result.errno_); + ADD_FAILURE() << msg; + throw EnvoyException(msg); } - std::string msg = fmt::format("{} failed for address {} with error: {} ({})", failing_fn, - addr->asString(), strerror(err), err); - ADD_FAILURE() << msg; - throw EnvoyException(msg); + return std::make_pair(Address::addressFromFd(fd), fd); } TransportSocketPtr createRawBufferSocket() { return std::make_unique(); } diff --git a/test/test_common/utility.cc b/test/test_common/utility.cc index 3dfdfeb0dfa05..66553e7f08926 100644 --- a/test/test_common/utility.cc +++ b/test/test_common/utility.cc @@ -143,7 +143,8 @@ envoy::config::bootstrap::v2::Bootstrap TestUtility::parseBootstrapFromJson(const std::string& json_string) { envoy::config::bootstrap::v2::Bootstrap bootstrap; auto json_object_ptr = Json::Factory::loadFromString(json_string); - Config::BootstrapJson::translateBootstrap(*json_object_ptr, bootstrap); + Stats::StatsOptionsImpl stats_options; + Config::BootstrapJson::translateBootstrap(*json_object_ptr, bootstrap, stats_options); return bootstrap; } @@ -182,6 +183,8 @@ void ConditionalInitializer::waitReady() { ScopedFdCloser::ScopedFdCloser(int fd) : fd_(fd) {} ScopedFdCloser::~ScopedFdCloser() { ::close(fd_); } +constexpr std::chrono::milliseconds TestUtility::DefaultTimeout; + namespace Http { // Satisfy linker @@ -225,4 +228,24 @@ bool TestHeaderMapImpl::has(const std::string& key) { return get(LowerCaseString bool TestHeaderMapImpl::has(const LowerCaseString& key) { return get(key) != nullptr; } } // namespace Http + +namespace Stats { + +MockedTestAllocator::MockedTestAllocator(const StatsOptions& stats_options) + : alloc_(stats_options) { + ON_CALL(*this, alloc(_)).WillByDefault(Invoke([this](absl::string_view name) -> RawStatData* { + return alloc_.alloc(name); + })); + + ON_CALL(*this, free(_)).WillByDefault(Invoke([this](RawStatData& data) -> void { + return alloc_.free(data); + })); + + EXPECT_CALL(*this, alloc(absl::string_view("stats.overflow"))); +} + +MockedTestAllocator::~MockedTestAllocator() {} + +} // namespace Stats + } // namespace Envoy diff --git a/test/test_common/utility.h b/test/test_common/utility.h index ddecc92b61b0a..2cdb7117bf3ca 100644 --- a/test/test_common/utility.h +++ b/test/test_common/utility.h @@ -26,6 +26,8 @@ using testing::AssertionFailure; using testing::AssertionResult; using testing::AssertionSuccess; +using testing::Invoke; +using testing::_; namespace Envoy { #define EXPECT_THROW_WITH_MESSAGE(statement, expected_exception, message) \ @@ -75,6 +77,14 @@ namespace Envoy { EXPECT_DEATH(statement, message); \ } while (false) +#define VERIFY_ASSERTION(statement) \ + do { \ + ::testing::AssertionResult status = statement; \ + if (!status) { \ + return status; \ + } \ + } while (false) + // Random number generator which logs its seed to stderr. To repeat a test run with a non-zero seed // one can run the test with --test_arg=--gtest_random_seed=[seed] class TestRandomGenerator { @@ -98,16 +108,6 @@ class TestUtility { */ static bool buffersEqual(const Buffer::Instance& lhs, const Buffer::Instance& rhs); - /** - * Convert a buffer to a string. - * @param buffer supplies the buffer to convert. - * @return std::string the converted string. - */ - static std::string bufferToString(const Buffer::OwnedImpl& buffer) { - // TODO(jmarantz): remove this indirection and update all ~53 call sites. - return buffer.toString(); - } - /** * Feed a buffer with random characters. * @param buffer supplies the buffer to be fed. @@ -154,12 +154,10 @@ class TestUtility { * * @param lhs proto on LHS. * @param rhs proto on RHS. - * @return bool indicating whether the protos are equal. Type name and string serialization are - * used for equality testing. + * @return bool indicating whether the protos are equal. */ static bool protoEqual(const Protobuf::Message& lhs, const Protobuf::Message& rhs) { - return lhs.GetTypeName() == rhs.GetTypeName() && - lhs.SerializeAsString() == rhs.SerializeAsString(); + return Protobuf::util::MessageDifferencer::Equivalent(lhs, rhs); } /** @@ -275,6 +273,8 @@ class TestUtility { } return result; } + + static constexpr std::chrono::milliseconds DefaultTimeout = std::chrono::milliseconds(10000); }; /** @@ -345,19 +345,22 @@ class TestHeaderMapImpl : public HeaderMapImpl { } // namespace Http namespace Stats { + /** * This is a heap test allocator that works similar to how the shared memory allocator works in * terms of reference counting, etc. */ class TestAllocator : public RawStatDataAllocator { public: + TestAllocator(const StatsOptions& stats_options) : stats_options_(stats_options) {} ~TestAllocator() { EXPECT_TRUE(stats_.empty()); } - RawStatData* alloc(const std::string& name) override { - CSmartPtr& stat_ref = stats_[name]; + RawStatData* alloc(absl::string_view name) override { + CSmartPtr& stat_ref = stats_[std::string(name)]; if (!stat_ref) { - stat_ref.reset(static_cast(::calloc(RawStatData::size(), 1))); - stat_ref->initialize(name); + stat_ref.reset(static_cast( + ::calloc(RawStatData::structSizeWithOptions(stats_options_), 1))); + stat_ref->initialize(name, stats_options_); } else { stat_ref->ref_count_++; } @@ -378,6 +381,18 @@ class TestAllocator : public RawStatDataAllocator { private: static void freeAdapter(RawStatData* data) { ::free(data); } std::unordered_map> stats_; + const StatsOptions& stats_options_; +}; + +class MockedTestAllocator : public RawStatDataAllocator { +public: + MockedTestAllocator(const StatsOptions& stats_options); + virtual ~MockedTestAllocator(); + + MOCK_METHOD1(alloc, RawStatData*(absl::string_view name)); + MOCK_METHOD1(free, void(RawStatData& data)); + + TestAllocator alloc_; }; } // namespace Stats diff --git a/test/tools/router_check/BUILD b/test/tools/router_check/BUILD index 90b29481696f0..0b0ab09cbc3a6 100644 --- a/test/tools/router_check/BUILD +++ b/test/tools/router_check/BUILD @@ -27,6 +27,7 @@ envoy_cc_test_library( "//source/common/http:headers_lib", "//source/common/json:json_loader_lib", "//source/common/router:config_lib", + "//source/common/stats:stats_lib", "//test/mocks/server:server_mocks", "//test/test_common:printers_lib", "//test/test_common:utility_lib", diff --git a/test/tools/router_check/router.cc b/test/tools/router_check/router.cc index 9947c285b0c1c..1d9dbbfe5d03c 100644 --- a/test/tools/router_check/router.cc +++ b/test/tools/router_check/router.cc @@ -7,6 +7,7 @@ #include "common/network/utility.h" #include "common/request_info/request_info_impl.h" +#include "common/stats/stats_impl.h" #include "test/test_common/printers.h" @@ -44,7 +45,9 @@ RouterCheckTool RouterCheckTool::create(const std::string& router_config_json) { // TODO(hennna): Allow users to load a full config and extract the route configuration from it. Json::ObjectSharedPtr loader = Json::Factory::loadFromFile(router_config_json); envoy::api::v2::RouteConfiguration route_config; - Config::RdsJson::translateRouteConfiguration(*loader, route_config); + // TODO(ambuc): Add a CLI option to allow for a maxStatNameLength constraint + Stats::StatsOptionsImpl stats_options; + Config::RdsJson::translateRouteConfiguration(*loader, route_config, stats_options); std::unique_ptr> factory_context( std::make_unique>()); diff --git a/test/tools/schema_validator/BUILD b/test/tools/schema_validator/BUILD index bea8d2dd0fbb0..5d67c3230eb56 100644 --- a/test/tools/schema_validator/BUILD +++ b/test/tools/schema_validator/BUILD @@ -27,6 +27,7 @@ envoy_cc_test_library( "//source/common/config:rds_json_lib", "//source/common/json:json_loader_lib", "//source/common/router:config_lib", + "//source/common/stats:stats_lib", "//test/mocks/runtime:runtime_mocks", "//test/mocks/upstream:upstream_mocks", "//test/test_common:printers_lib", diff --git a/test/tools/schema_validator/validator.cc b/test/tools/schema_validator/validator.cc index fde2decc8eeb4..b7f32f050d3a2 100644 --- a/test/tools/schema_validator/validator.cc +++ b/test/tools/schema_validator/validator.cc @@ -1,6 +1,7 @@ #include "test/tools/schema_validator/validator.h" #include "common/router/config_impl.h" +#include "common/stats/stats_impl.h" #include "test/test_common/printers.h" @@ -16,7 +17,7 @@ const std::string& Schema::toString(Type type) { return ROUTE; } - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } Options::Options(int argc, char** argv) { @@ -55,11 +56,14 @@ void Validator::validate(const std::string& json_path, Schema::Type schema_type) // Construct a envoy::api::v2::RouteConfiguration to validate the Route configuration and // ignore the output since nothing will consume it. envoy::api::v2::RouteConfiguration route_config; - Config::RdsJson::translateRouteConfiguration(*loader, route_config); + // TODO(ambuc): Add a CLI option to the schema_validator to allow for a maxStatNameLength + // constraint + Stats::StatsOptionsImpl stats_options; + Config::RdsJson::translateRouteConfiguration(*loader, route_config, stats_options); break; } default: - NOT_REACHED; + NOT_REACHED_GCOVR_EXCL_LINE; } } diff --git a/tools/BUILD b/tools/BUILD index f1896114d4e11..87880b11b6d5d 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -41,6 +41,7 @@ envoy_cc_binary( "//source/common/config:bootstrap_json_lib", "//source/common/json:json_loader_lib", "//source/common/protobuf:utility_lib", + "//source/common/stats:stats_lib", "@envoy_api//envoy/config/bootstrap/v2:bootstrap_cc", ], ) diff --git a/tools/check_format.py b/tools/check_format.py index 56f25bc609693..5cd498832baa9 100755 --- a/tools/check_format.py +++ b/tools/check_format.py @@ -13,13 +13,15 @@ EXCLUDED_PREFIXES = ("./generated/", "./thirdparty/", "./build", "./.git/", "./bazel-", "./bazel/external", "./.cache", + "./source/extensions/extensions_build_config.bzl", "./tools/testdata/check_format/") -SUFFIXES = (".cc", ".h", "BUILD", ".md", ".rst", ".proto") +SUFFIXES = (".cc", ".h", "BUILD", ".bzl", ".md", ".rst", ".proto") DOCS_SUFFIX = (".md", ".rst") PROTO_SUFFIX = (".proto") # Files in these paths can make reference to protobuf stuff directly -GOOGLE_PROTOBUF_WHITELIST = ('ci/prebuilt', 'source/common/protobuf', 'api/test') +GOOGLE_PROTOBUF_WHITELIST = ("ci/prebuilt", "source/common/protobuf", "api/test") +REPOSITORIES_BZL = "bazel/repositories.bzl" CLANG_FORMAT_PATH = os.getenv("CLANG_FORMAT", "clang-format-5.0") BUILDIFIER_PATH = os.getenv("BUILDIFIER_BIN", "$GOPATH/bin/buildifier") @@ -60,7 +62,7 @@ def checkNamespace(file_path): # To avoid breaking the Lyft import, we just check for path inclusion here. def whitelistedForProtobufDeps(file_path): - return (file_path.endswith(PROTO_SUFFIX) or + return (file_path.endswith(PROTO_SUFFIX) or file_path.endswith(REPOSITORIES_BZL) or \ any(path_segment in file_path for path_segment in GOOGLE_PROTOBUF_WHITELIST)) def findSubstringAndReturnError(pattern, file_path, error_message): @@ -83,6 +85,9 @@ def isBuildFile(file_path): return True return False +def isSkylarkFile(file_path): + return file_path.endswith(".bzl") + def hasInvalidAngleBracketDirectory(line): if not line.startswith(INCLUDE_ANGLE): return False @@ -145,36 +150,35 @@ def checkBuildLine(line, file_path, reportError): if not whitelistedForProtobufDeps(file_path) and '"protobuf"' in line: reportError("unexpected direct external dependency on protobuf, use " "//source/common/protobuf instead.") - if envoy_build_rule_check and '@envoy//' in line: + if envoy_build_rule_check and not isSkylarkFile(file_path) and '@envoy//' in line: reportError("Superfluous '@envoy//' prefix") -def fixBuildLine(line): - if envoy_build_rule_check: +def fixBuildLine(line, file_path): + if envoy_build_rule_check and not isSkylarkFile(file_path): line = line.replace('@envoy//', '//') return line def fixBuildPath(file_path): for line in fileinput.input(file_path, inplace=True): - sys.stdout.write(fixBuildLine(line)) + sys.stdout.write(fixBuildLine(line, file_path)) error_messages = [] # TODO(htuch): Add API specific BUILD fixer script. - if not isApiFile(file_path): - if os.system( - "%s %s %s" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path)) != 0: + if not isApiFile(file_path) and not isSkylarkFile(file_path): + if os.system("%s %s %s" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path)) != 0: error_messages += ["envoy_build_fixer rewrite failed for file: %s" % file_path] - if os.system("%s -mode=fix %s" % (BUILDIFIER_PATH, file_path)) != 0: - error_messages += ["buildifier rewrite failed for file: %s" % file_path] + + if os.system("%s -mode=fix %s" % (BUILDIFIER_PATH, file_path)) != 0: + error_messages += ["buildifier rewrite failed for file: %s" % file_path] return error_messages def checkBuildPath(file_path): error_messages = [] - if not isApiFile(file_path): + if not isApiFile(file_path) and not isSkylarkFile(file_path): command = "%s %s | diff %s -" % (ENVOY_BUILD_FIXER_PATH, file_path, file_path) - error_messages += executeCommand( - command, "envoy_build_fixer check failed", file_path) + error_messages += executeCommand(command, "envoy_build_fixer check failed", file_path) - command = "cat %s | %s -mode=fix | diff %s -" % (file_path, BUILDIFIER_PATH, file_path) + command = "%s -mode=diff %s" % (BUILDIFIER_PATH, file_path) error_messages += executeCommand(command, "buildifier check failed", file_path) error_messages += checkFileContents(file_path, checkBuildLine) return error_messages @@ -184,18 +188,20 @@ def fixSourcePath(file_path): sys.stdout.write(fixSourceLine(line)) error_messages = [] - if not file_path.endswith(DOCS_SUFFIX) and not file_path.endswith(PROTO_SUFFIX): - error_messages += fixHeaderOrder(file_path) + if not file_path.endswith(DOCS_SUFFIX): + if not file_path.endswith(PROTO_SUFFIX): + error_messages += fixHeaderOrder(file_path) error_messages += clangFormat(file_path) return error_messages def checkSourcePath(file_path): error_messages = checkFileContents(file_path, checkSourceLine) - if not file_path.endswith(DOCS_SUFFIX) and not file_path.endswith(PROTO_SUFFIX): - error_messages += checkNamespace(file_path) - command = ("%s %s | diff %s -" % (HEADER_ORDER_PATH, file_path, file_path)) - error_messages += executeCommand(command, "header_order.py check failed", file_path) + if not file_path.endswith(DOCS_SUFFIX): + if not file_path.endswith(PROTO_SUFFIX): + error_messages += checkNamespace(file_path) + command = ("%s %s | diff %s -" % (HEADER_ORDER_PATH, file_path, file_path)) + error_messages += executeCommand(command, "header_order.py check failed", file_path) command = ("%s %s | diff %s -" % (CLANG_FORMAT_PATH, file_path, file_path)) error_messages += executeCommand(command, "clang-format check failed", file_path) @@ -245,7 +251,7 @@ def checkFormat(file_path): # Apply fixes first, if asked, and then run checks. If we wind up attempting to fix # an issue, but there's still an error, that's a problem. try_to_fix = operation_type == "fix" - if isBuildFile(file_path): + if isBuildFile(file_path) or isSkylarkFile(file_path): if try_to_fix: error_messages += fixBuildPath(file_path) error_messages += checkBuildPath(file_path) diff --git a/tools/check_format_test_helper.py b/tools/check_format_test_helper.py index 7ac9f5af94ea3..17ad7edf9f583 100755 --- a/tools/check_format_test_helper.py +++ b/tools/check_format_test_helper.py @@ -134,6 +134,7 @@ def checkFileExpectingOK(filename): errors += fixFileExpectingSuccess("header_order.cc") errors += fixFileExpectingSuccess("license.BUILD") errors += fixFileExpectingSuccess("bad_envoy_build_sys_ref.BUILD") + errors += fixFileExpectingSuccess("proto_format.proto") errors += fixFileExpectingFailure("no_namespace_envoy.cc", "Unable to find Envoy namespace or NOLINT(namespace-envoy)") errors += fixFileExpectingFailure("mutex.cc", @@ -161,6 +162,8 @@ def checkFileExpectingOK(filename): errors += checkFileExpectingError("license.BUILD", "envoy_build_fixer check failed") errors += checkFileExpectingError("bad_envoy_build_sys_ref.BUILD", "Superfluous '@envoy//' prefix") + errors += checkFileExpectingError("proto_format.proto", "clang-format check failed") + errors += checkFileExpectingOK("ok_file.cc") errors += fixFileExpectingFailure("proto.BUILD", diff --git a/tools/envoy_build_fixer.py b/tools/envoy_build_fixer.py index a2d283a7fc0c3..fb687ceac063a 100755 --- a/tools/envoy_build_fixer.py +++ b/tools/envoy_build_fixer.py @@ -10,14 +10,20 @@ '\n' 'envoy_package()\n') - def FixBuild(path): with open(path, 'r') as f: outlines = [LICENSE_STRING] + first = True in_load = False seen_ebs = False seen_epkg = False for line in f: + if line.startswith('licenses'): + continue + if first: + if line != '\n': + outlines.append('\n') + first = False if line.startswith('package(') and not path.endswith( 'bazel/BUILD') and not path.endswith( 'ci/prebuilt/BUILD') and not path.endswith( @@ -41,8 +47,7 @@ def FixBuild(path): outlines.append(line) outlines.append(ENVOY_PACKAGE_STRING) continue - if not line.startswith('licenses'): - outlines.append(line) + outlines.append(line) return ''.join(outlines) diff --git a/tools/protodoc/protodoc.bzl b/tools/protodoc/protodoc.bzl index dc55a5f047564..c7ab5c8948890 100644 --- a/tools/protodoc/protodoc.bzl +++ b/tools/protodoc/protodoc.bzl @@ -10,10 +10,14 @@ def _proto_path(proto): path = proto.path root = proto.root.path ws = proto.owner.workspace_root - if path.startswith(root): path = path[len(root):] - if path.startswith("/"): path = path[1:] - if path.startswith(ws): path = path[len(ws):] - if path.startswith("/"): path = path[1:] + if path.startswith(root): + path = path[len(root):] + if path.startswith("/"): + path = path[1:] + if path.startswith(ws): + path = path[len(ws):] + if path.startswith("/"): + path = path[1:] return path # Bazel aspect (https://docs.bazel.build/versions/master/skylark/aspects.html) @@ -31,26 +35,30 @@ def _proto_doc_aspect_impl(target, ctx): for dep in ctx.rule.attr.deps: transitive_outputs = transitive_outputs | dep.output_groups["rst"] proto_sources = target.proto.direct_sources + # If this proto_library doesn't actually name any sources, e.g. //api:api, # but just glues together other libs, we just need to follow the graph. if not proto_sources: - return [OutputGroupInfo(rst=transitive_outputs)] + return [OutputGroupInfo(rst = transitive_outputs)] + # Figure out the set of import paths. Ideally we would use descriptor sets # built by proto_library, which avoid having to do nasty path mangling, but # these don't include source_code_info, which we need for comment # extractions. See https://github.com/bazelbuild/bazel/issues/3971. import_paths = [] for f in target.proto.transitive_sources: - if f.root.path: - import_path = f.root.path + "/" + f.owner.workspace_root - else: - import_path = f.owner.workspace_root - if import_path: - import_paths += [import_path] + if f.root.path: + import_path = f.root.path + "/" + f.owner.workspace_root + else: + import_path = f.owner.workspace_root + if import_path: + import_paths += [import_path] + # The outputs live in the ctx.label's package root. We add some additional # path information to match with protoc's notion of path relative locations. outputs = [ctx.actions.declare_file(ctx.label.name + "/" + _proto_path(f) + ".rst") for f in proto_sources] + # Create the protoc command-line args. ctx_path = ctx.label.package + "/" + ctx.label.name output_path = outputs[0].root.path + "/" + outputs[0].owner.workspace_root + "/" + ctx_path @@ -58,23 +66,30 @@ def _proto_doc_aspect_impl(target, ctx): args += ["-I" + import_path for import_path in import_paths] args += ["--plugin=protoc-gen-protodoc=" + ctx.executable._protodoc.path, "--protodoc_out=" + output_path] args += [_proto_path(src) for src in target.proto.direct_sources] - ctx.action(executable=ctx.executable._protoc, - arguments=args, - inputs=[ctx.executable._protodoc] + target.proto.transitive_sources.to_list(), - outputs=outputs, - mnemonic="ProtoDoc", - use_default_shell_env=True) + ctx.action( + executable = ctx.executable._protoc, + arguments = args, + inputs = [ctx.executable._protodoc] + target.proto.transitive_sources.to_list(), + outputs = outputs, + mnemonic = "ProtoDoc", + use_default_shell_env = True, + ) transitive_outputs = depset(outputs) | transitive_outputs - return [OutputGroupInfo(rst=transitive_outputs)] + return [OutputGroupInfo(rst = transitive_outputs)] -proto_doc_aspect = aspect(implementation = _proto_doc_aspect_impl, +proto_doc_aspect = aspect( + implementation = _proto_doc_aspect_impl, attr_aspects = ["deps"], attrs = { - "_protoc": attr.label(default=Label("@com_google_protobuf//:protoc"), - executable=True, - cfg="host"), - "_protodoc": attr.label(default=Label("//tools/protodoc"), - executable=True, - cfg="host"), - } + "_protoc": attr.label( + default = Label("@com_google_protobuf//:protoc"), + executable = True, + cfg = "host", + ), + "_protodoc": attr.label( + default = Label("//tools/protodoc"), + executable = True, + cfg = "host", + ), + }, ) diff --git a/tools/socket_passing.py b/tools/socket_passing.py index d8372e46e2fde..cdbbcf0a9d99b 100755 --- a/tools/socket_passing.py +++ b/tools/socket_passing.py @@ -13,6 +13,7 @@ import httplib import json import os.path +import re import sys import time @@ -20,38 +21,56 @@ # with failure if the file is not found. ADMIN_FILE_TIMEOUT_SECS = 20 -def GenerateNewConfig(original_json, admin_address, updated_json): +# Because the hot restart files are yaml but yaml support is not included in +# python by default, we parse this fairly manually. +def GenerateNewConfig(original_yaml, admin_address, updated_json): # Get original listener addresses - with open(original_json, 'r') as original_json_file: - # Import original config file in order to get a deterministic output. This - # allows us to diff the original config file and the updated config file - # output from this script to check for any changes. - parsed_json = json.load(original_json_file, object_pairs_hook=OrderedDict) - original_listeners = parsed_json['listeners'] - - sys.stdout.write('Admin address is ' + admin_address + '\n') - try: - admin_conn = httplib.HTTPConnection(admin_address) - admin_conn.request('GET', '/listeners') - admin_response = admin_conn.getresponse() - if not admin_response.status == 200: - return False - discovered_listeners = json.loads(admin_response.read()) - except Exception as e: - sys.stderr.write('Cannot connect to admin: %s\n' % e) - return False - else: - if len(discovered_listeners) != len(original_listeners): + with open(original_yaml, 'r') as original_file: + sys.stdout.write('Admin address is ' + admin_address + '\n') + try: + admin_conn = httplib.HTTPConnection(admin_address) + admin_conn.request('GET', '/listeners') + admin_response = admin_conn.getresponse() + if not admin_response.status == 200: + return False + discovered_listeners = json.loads(admin_response.read()) + except Exception as e: + sys.stderr.write('Cannot connect to admin: %s\n' % e) return False - for discovered, original in zip(discovered_listeners, original_listeners): - if discovered.startswith('/'): - original['address'] = 'unix://' + discovered - else: - original['address'] = 'tcp://' + discovered - with open(updated_json, 'w') as outfile: - json.dump(OrderedDict(parsed_json), outfile, indent=2, separators=(',',':')) - finally: - admin_conn.close() + else: + raw_yaml = original_file.readlines() + index = 0; + for discovered in discovered_listeners: + replaced = False; + if discovered.startswith('/'): + for index in range(index + 1, len(raw_yaml) - 1): + if 'pipe:' in raw_yaml[index] and 'path:' in raw_yaml[index + 1]: + raw_yaml[index + 1] = re.sub( + 'path:.*', 'path: "' + discovered + '"', raw_yaml[index + 1]) + replaced = True + break + else: + addr, _, port = discovered.rpartition(':') + if addr[0] == '[': + addr = addr[1:-1] # strip [] from ipv6 address. + for index in range(index + 1, len(raw_yaml) - 2): + if ('socket_address:' in raw_yaml[index] and 'address:' in raw_yaml[index + 1] + and'port_value:' in raw_yaml[index + 2]): + raw_yaml[index + 1] = re.sub( + 'address:.*', 'address: "' + addr + '"', raw_yaml[index + 1]) + raw_yaml[index + 2] = re.sub( + 'port_value:.*', 'port_value: ' + port, raw_yaml[index + 2]) + replaced = True + break + if replaced: + sys.stderr.write('replaced listener at line ' + str(index) + ' with ' + discovered + '\n') + else: + sys.stderr.write('Failed to replace a discovered listener ' + discovered + '\n') + return False; + with open(updated_json, 'w') as outfile: + outfile.writelines(raw_yaml) + finally: + admin_conn.close() return True diff --git a/tools/testdata/check_format/proto_format.proto b/tools/testdata/check_format/proto_format.proto new file mode 100644 index 0000000000000..b31d9dfa47db3 --- /dev/null +++ b/tools/testdata/check_format/proto_format.proto @@ -0,0 +1 @@ +// This commment is too long for the line-limit built into our clang configuration, so it will need to be wrapped. diff --git a/tools/testdata/check_format/proto_format.proto.gold b/tools/testdata/check_format/proto_format.proto.gold new file mode 100644 index 0000000000000..db5b8dea9cc7b --- /dev/null +++ b/tools/testdata/check_format/proto_format.proto.gold @@ -0,0 +1,2 @@ +// This commment is too long for the line-limit built into our clang +// configuration, so it will need to be wrapped. diff --git a/tools/v1_to_bootstrap.cc b/tools/v1_to_bootstrap.cc index 20ec6e5ea516f..e712e1bcb2c26 100644 --- a/tools/v1_to_bootstrap.cc +++ b/tools/v1_to_bootstrap.cc @@ -13,6 +13,7 @@ #include "common/config/bootstrap_json.h" #include "common/json/json_loader.h" #include "common/protobuf/utility.h" +#include "common/stats/stats_impl.h" // NOLINT(namespace-envoy) int main(int argc, char** argv) { @@ -23,7 +24,8 @@ int main(int argc, char** argv) { envoy::config::bootstrap::v2::Bootstrap bootstrap; auto config_json = Envoy::Json::Factory::loadFromFile(argv[1]); - Envoy::Config::BootstrapJson::translateBootstrap(*config_json, bootstrap); + Envoy::Stats::StatsOptionsImpl stats_options; + Envoy::Config::BootstrapJson::translateBootstrap(*config_json, bootstrap, stats_options); Envoy::MessageUtil::validate(bootstrap); std::cout << Envoy::MessageUtil::getJsonStringFromMessage(bootstrap, true); diff --git a/windows/setup/workstation_setup.ps1 b/windows/setup/workstation_setup.ps1 new file mode 100644 index 0000000000000..bf253519602ee --- /dev/null +++ b/windows/setup/workstation_setup.ps1 @@ -0,0 +1,54 @@ +$ErrorActionPreference = "Stop"; +$ProgressPreference="SilentlyContinue" + +trap { $host.SetShouldExit(1) } + +Start-BitsTransfer "https://aka.ms/vs/15/release/vs_buildtools.exe" "$env:TEMP\vs_buildtools.exe" + +# Install VS Build Tools in a directory without spaces to work around: https://github.com/bazelbuild/bazel/issues/4496 +# otherwise none of the go code will build (c++ is fine) + +$vsInstallDir="c:\VSBuildTools\2017" +echo "Installing VS Build Tools..." +cmd.exe /s /c "$env:TEMP\vs_buildtools.exe --installPath $vsInstallDir --passive --wait --norestart --nocache --add Microsoft.VisualStudio.Component.VC.CoreBuildTools --add Microsoft.VisualStudio.Component.VC.Redist.14.Latest --add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 --add Microsoft.VisualStudio.Component.Windows10SDK --add Microsoft.VisualStudio.Component.Windows10SDK.17134" + +if ($LASTEXITCODE -ne 0) { + echo "VS Build Tools install failed: $LASTEXITCODE" + exit $LASTEXITCODE +} +Remove-Item "$env:TEMP\vs_buildtools.exe" +echo "Done" + +Set-ExecutionPolicy Bypass -Scope Process -Force; iex ((New-Object System.Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')) + +choco install make bazel cmake ninja git -y +if ($LASTEXITCODE -ne 0) { + echo "choco install failed: $LASTEXITCODE" + exit $LASTEXITCODE +} + +$envoyBazelRootDir = "c:\_eb" + +$env:ENVOY_BAZEL_ROOT=$envoyBazelRootDir +setx ENVOY_BAZEL_ROOT $envoyBazelRootDir > $nul +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} + +$env:PATH ="$env:PATH;c:\tools\msys64\usr\bin;c:\make\bin;c:\Program Files\CMake\bin;C:\Python27;c:\programdata\chocolatey\bin;C:\Program Files\Git\bin" +setx PATH $env:PATH > $nul +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} + +$env:BAZEL_VC="$vsInstallDir\VC" +setx BAZEL_VC $env:BAZEL_VC > $nul +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} + +$env:BAZEL_SH="C:\tools\msys64\usr\bin\bash.exe" +setx BAZEL_SH $env:BAZEL_SH > $nul +if ($LASTEXITCODE -ne 0) { + exit $LASTEXITCODE +} diff --git a/windows/tools/bazel.rc b/windows/tools/bazel.rc new file mode 100644 index 0000000000000..234d5bc4a85a3 --- /dev/null +++ b/windows/tools/bazel.rc @@ -0,0 +1,7 @@ +# Windows/Envoy specific Bazel build/test options. + +// TODO: remove experimental_shortened_obj_file_path for Bazel 0.16.0 +build --experimental_shortened_obj_file_path +build --define signal_trace=disabled +build --define hot_restart=disabled +build --define tcmalloc=disabled