diff --git a/balancer/ringhash/ringhash_e2e_test.go b/balancer/ringhash/ringhash_e2e_test.go index 98c5b6ff54b6..b901e971344e 100644 --- a/balancer/ringhash/ringhash_e2e_test.go +++ b/balancer/ringhash/ringhash_e2e_test.go @@ -299,9 +299,6 @@ func setupManagementServerAndResolver(t *testing.T) (*e2e.ManagementServer, stri bc := e2e.DefaultBootstrapContents(t, nodeID, xdsServer.Address) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } r, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/internal/testutils/xds/e2e/bootstrap.go b/internal/testutils/xds/e2e/bootstrap.go index d902e94a5144..768d32b01c8f 100644 --- a/internal/testutils/xds/e2e/bootstrap.go +++ b/internal/testutils/xds/e2e/bootstrap.go @@ -133,7 +133,8 @@ func DefaultBootstrapContents(t *testing.T, nodeID, serverURI string) []byte { bs, err := bootstrap.NewContentsForTesting(bootstrap.ConfigOptionsForTesting{ Servers: []byte(fmt.Sprintf(`[{ "server_uri": "passthrough:///%s", - "channel_creds": [{"type": "insecure"}] + "channel_creds": [{"type": "insecure"}], + "server_features": ["trusted_xds_server"] }]`, serverURI)), Node: []byte(fmt.Sprintf(`{"id": "%s"}`, nodeID)), CertificateProviders: cpc, diff --git a/internal/testutils/xds/e2e/clientresources.go b/internal/testutils/xds/e2e/clientresources.go index be375d40851c..6ce8402e1958 100644 --- a/internal/testutils/xds/e2e/clientresources.go +++ b/internal/testutils/xds/e2e/clientresources.go @@ -650,6 +650,8 @@ type BackendOptions struct { HealthStatus v3corepb.HealthStatus // Weight sets the backend weight. Defaults to 1. Weight uint32 + // Hostname sets the endpoint hostname for authority rewriting. + Hostname string // Metadata sets the LB endpoint metadata (envoy.lb FilterMetadata field). // See https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/core/v3/base.proto#envoy-v3-api-msg-config-core-v3-metadata Metadata map[string]any @@ -721,6 +723,7 @@ func EndpointResourceWithOptions(opts EndpointOptions) *v3endpointpb.ClusterLoad PortSpecifier: &v3corepb.SocketAddress_PortValue{PortValue: b.Ports[0]}, }, }}, + Hostname: b.Hostname, AdditionalAddresses: additionalAddresses, }}, HealthStatus: b.HealthStatus, diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 6daf1e002dc2..143eddfd7414 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -574,9 +574,14 @@ type CallHdr struct { DoneFunc func() // called when the stream is finished - // Authority is used to explicitly override the `:authority` header. If set, - // this value takes precedence over the Host field and will be used as the - // value for the `:authority` header. + // Authority is used to explicitly override the `:authority` header. + // + // This value comes from one of two sources: + // 1. The `CallAuthority` call option, if specified by the user. + // 2. An override provided by the LB picker (e.g. xDS authority rewriting). + // + // The `CallAuthority` call option always takes precedence over the LB + // picker override. Authority string } diff --git a/internal/xds/balancer/cdsbalancer/cdsbalancer_security_test.go b/internal/xds/balancer/cdsbalancer/cdsbalancer_security_test.go index f26571fd8cfa..5ee152208eba 100644 --- a/internal/xds/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/internal/xds/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -68,9 +68,6 @@ import ( func setupForSecurityTests(t *testing.T, bootstrapContents []byte, clientCreds, serverCreds credentials.TransportCredentials) (*grpc.ClientConn, string) { t.Helper() - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } r, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/internal/xds/balancer/cdsbalancer/cdsbalancer_test.go b/internal/xds/balancer/cdsbalancer/cdsbalancer_test.go index 3bd1a49c73aa..c2a739995f46 100644 --- a/internal/xds/balancer/cdsbalancer/cdsbalancer_test.go +++ b/internal/xds/balancer/cdsbalancer/cdsbalancer_test.go @@ -247,9 +247,6 @@ func setupWithManagementServer(t *testing.T, lis net.Listener, onStreamRequest f nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } r, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -654,7 +651,7 @@ func (s) TestClusterUpdate_SuccessWithLRS(t *testing.T) { ServiceName: serviceName, EnableLRS: true, }) - lrsServerCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: fmt.Sprintf("passthrough:///%s", mgmtServer.Address)}) + lrsServerCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: fmt.Sprintf("passthrough:///%s", mgmtServer.Address), ServerFeatures: []string{"trusted_xds_server"}}) if err != nil { t.Fatalf("Failed to create LRS server config for testing: %v", err) } diff --git a/internal/xds/balancer/cdsbalancer/e2e_test/balancer_test.go b/internal/xds/balancer/cdsbalancer/e2e_test/balancer_test.go index 5db316dc2e9e..51bdd538131f 100644 --- a/internal/xds/balancer/cdsbalancer/e2e_test/balancer_test.go +++ b/internal/xds/balancer/cdsbalancer/e2e_test/balancer_test.go @@ -74,9 +74,6 @@ import ( func setupAndDial(t *testing.T, bootstrapContents []byte) (*grpc.ClientConn, func()) { t.Helper() // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } r, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("xDS resolver creation failed: %v", err) diff --git a/internal/xds/balancer/clusterimpl/clusterimpl.go b/internal/xds/balancer/clusterimpl/clusterimpl.go index b5dc77371548..a7ac3032b922 100644 --- a/internal/xds/balancer/clusterimpl/clusterimpl.go +++ b/internal/xds/balancer/clusterimpl/clusterimpl.go @@ -45,6 +45,7 @@ import ( "google.golang.org/grpc/internal/xds/clients" "google.golang.org/grpc/internal/xds/clients/lrsclient" "google.golang.org/grpc/internal/xds/xdsclient" + "google.golang.org/grpc/internal/xds/xdsclient/xdsresource" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) @@ -420,6 +421,7 @@ type scWrapper struct { // locality needs to be atomic because it can be updated while being read by // the picker. locality atomic.Pointer[clients.Locality] + hostname string } func (scw *scWrapper) updateLocalityID(lID clients.Locality) { @@ -442,6 +444,9 @@ func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer } var sc balancer.SubConn scw := &scWrapper{} + if len(addrs) > 0 { + scw.hostname = xdsresource.Hostname(addrs[0]) + } oldListener := opts.StateListener opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state, oldListener) diff --git a/internal/xds/balancer/clusterimpl/picker.go b/internal/xds/balancer/clusterimpl/picker.go index d766a09a6963..b73b108e9011 100644 --- a/internal/xds/balancer/clusterimpl/picker.go +++ b/internal/xds/balancer/clusterimpl/picker.go @@ -31,6 +31,7 @@ import ( xdsinternal "google.golang.org/grpc/internal/xds" "google.golang.org/grpc/internal/xds/clients" "google.golang.org/grpc/internal/xds/xdsclient" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -145,6 +146,14 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // If locality ID isn't found in the wrapper, an empty locality ID will // be used. lID = scw.localityID() + + if scw.hostname != "" && autoHostRewriteEnabled(info.Ctx) { + if pr.Metadata == nil { + pr.Metadata = metadata.Pairs(":authority", scw.hostname) + } else { + pr.Metadata.Set(":authority", scw.hostname) + } + } } if err != nil { @@ -199,20 +208,20 @@ func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // route's autoHostRewrite in the RPC context. type autoHostRewriteKey struct{} -// autoHostRewrite retrieves the autoHostRewrite value from the provided context. -func autoHostRewrite(ctx context.Context) bool { +// autoHostRewriteEnabled retrieves the autoHostRewrite value from the provided context. +func autoHostRewriteEnabled(ctx context.Context) bool { v, _ := ctx.Value(autoHostRewriteKey{}).(bool) return v } -// AutoHostRewriteForTesting returns the value of autoHostRewrite field; +// AutoHostRewriteEnabledForTesting returns the value of autoHostRewrite field; // to be used for testing only. -func AutoHostRewriteForTesting(ctx context.Context) bool { - return autoHostRewrite(ctx) +func AutoHostRewriteEnabledForTesting(ctx context.Context) bool { + return autoHostRewriteEnabled(ctx) } -// SetAutoHostRewrite adds the autoHostRewrite value to the context for +// EnableAutoHostRewrite adds the autoHostRewrite value to the context for // the xds_cluster_impl LB policy to pick. -func SetAutoHostRewrite(ctx context.Context, autohostRewrite bool) context.Context { - return context.WithValue(ctx, autoHostRewriteKey{}, autohostRewrite) +func EnableAutoHostRewrite(ctx context.Context) context.Context { + return context.WithValue(ctx, autoHostRewriteKey{}, true) } diff --git a/internal/xds/balancer/clusterimpl/tests/balancer_test.go b/internal/xds/balancer/clusterimpl/tests/balancer_test.go index 2bb398845458..f0418aba3030 100644 --- a/internal/xds/balancer/clusterimpl/tests/balancer_test.go +++ b/internal/xds/balancer/clusterimpl/tests/balancer_test.go @@ -20,6 +20,7 @@ package clusterimpl_test import ( "context" + "crypto/tls" "encoding/json" "errors" "fmt" @@ -41,11 +42,13 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/stub" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/testutils/xds/e2e" "google.golang.org/grpc/internal/testutils/xds/fakeserver" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" @@ -63,6 +66,7 @@ import ( v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" v3pickfirstpb "github.com/envoyproxy/go-control-plane/envoy/extensions/load_balancing_policies/pick_first/v3" v3lrspb "github.com/envoyproxy/go-control-plane/envoy/service/load_stats/v3" + xdscreds "google.golang.org/grpc/credentials/xds" testgrpc "google.golang.org/grpc/interop/grpc_testing" testpb "google.golang.org/grpc/interop/grpc_testing" "google.golang.org/protobuf/types/known/structpb" @@ -95,12 +99,8 @@ func (s) TestConfigUpdateWithSameLoadReportingServerConfig(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -206,9 +206,6 @@ func (s) TestLoadReportingPickFirstMultiLocality(t *testing.T) { bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -377,12 +374,8 @@ func (s) TestCircuitBreaking(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -574,12 +567,8 @@ func (s) TestDropByCategory(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -717,12 +706,8 @@ func (s) TestCircuitBreakingLogicalDNS(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -836,12 +821,8 @@ func (s) TestLRSLogicalDNS(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -928,12 +909,8 @@ func (s) TestReResolutionAfterTransientFailure(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -1048,12 +1025,8 @@ func (s) TestUpdateLRSServerToNil(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -1135,12 +1108,8 @@ func (s) TestChildPolicyChangeOnConfigUpdate(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -1258,12 +1227,8 @@ func (s) TestFailedToParseChildPolicyConfig(t *testing.T) { // Create bootstrap configuration pointing to the above management server. nodeID := uuid.New().String() bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - testutils.CreateBootstrapFileForTesting(t, bc) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -1316,3 +1281,190 @@ func (s) TestFailedToParseChildPolicyConfig(t *testing.T) { t.Fatal("EmptyCall RPC succeeded when expected to fail") } } + +// setupManagementServerAndResolver sets up an xDS management server and returns +// the management server, resolver builder and Node ID. +func setupManagementServerAndResolver(t *testing.T) (*e2e.ManagementServer, resolver.Builder, string) { + t.Helper() + + nodeID := uuid.New().String() + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{}) + contents := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) + + // Create an xDS resolver with the above bootstrap configuration. + resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(contents) + if err != nil { + t.Fatalf("Failed to create xDS resolver for testing: %v", err) + } + + return mgmtServer, resolverBuilder, nodeID +} + +// configureXDSResources configures the management server with a route that +// enables auto_host_rewrite and an endpoint with the specified hostname. +func configureXDSResources(ctx context.Context, t *testing.T, mgmtServer *e2e.ManagementServer, nodeID string, serverAddr string, endpointHostname string, secLevel e2e.SecurityLevel) { + t.Helper() + + const ( + serviceName = "my-test-xds-service" + routeName = "route-my-test-xds-service" + clusterName = "cluster-my-test-xds-service" + endpointName = "endpoints-my-test-xds-service" + ) + + resources := e2e.DefaultClientResources(e2e.ResourceParams{ + DialTarget: serviceName, + NodeID: nodeID, + Host: "localhost", + Port: testutils.ParsePort(t, serverAddr), + SecLevel: secLevel, + }) + + // Set the endpoint hostname for authority rewriting. + resources.Endpoints[0].Endpoints[0].LbEndpoints[0].GetEndpoint().Hostname = endpointHostname + + // Modify the route to enable AutoHostRewrite. + resources.Routes[0].VirtualHosts[0].Routes[0].GetRoute().HostRewriteSpecifier = &v3routepb.RouteAction_AutoHostRewrite{ + AutoHostRewrite: &wrapperspb.BoolValue{Value: true}, + } + + if err := mgmtServer.Update(ctx, resources); err != nil { + t.Fatal(err) + } +} + +// TestAuthorityOverriding verifies that the :authority header is correctly +// rewritten to the endpoint's hostname. Also verifies that CallAuthority +// call option takes precedence. +func (s) TestAuthorityOverriding(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSAuthorityRewrite, true) + mgmtServer, resolverBuilder, nodeID := setupManagementServerAndResolver(t) + + // Start a server backend exposing the test service. + var gotAuthority string + f := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if authVals := md.Get(":authority"); len(authVals) > 0 { + gotAuthority = authVals[0] + } + } + return &testpb.Empty{}, nil + }, + } + server := stubserver.StartTestService(t, f) + defer server.Stop() + + const xdsAuthorityOverride = "rewritten.example.com" + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + configureXDSResources(ctx, t, mgmtServer, nodeID, server.Address, xdsAuthorityOverride, e2e.SecurityLevelNone) + + // Create a ClientConn and make a successful RPC. + cc, err := grpc.NewClient("xds:///my-test-xds-service", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(resolverBuilder)) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("client.EmptyCall() failed: %v", err) + } + + if gotAuthority != xdsAuthorityOverride { + t.Errorf("invalid authority got: %q, want: %q", gotAuthority, xdsAuthorityOverride) + } + + // The authority specified via the `CallAuthority` CallOption takes the + // highest precedence when determining the `:authority` header. + const userAuthorityOverride = "user-override.com" + if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.CallAuthority(userAuthorityOverride)); err != nil { + t.Fatalf("client.EmptyCall() failed: %v", err) + } + + if gotAuthority != userAuthorityOverride { + t.Errorf("Server received authority %q, want %q (user override)", gotAuthority, userAuthorityOverride) + } +} + +// TestAuthorityOverridingWithTLS verifies the interaction between xDS Authority +// Rewriting and TLS Secure Naming. It ensures that when the :authority header +// is rewritten by the clusterimpl picker, the new authority is correctly +// validated against the server's TLS certificate before the RPC proceeds. +// Also check that RPC fails when the rewritten authority does not match the +// server's certificate due to secure naming validation. +func (s) TestAuthorityOverridingWithTLS(t *testing.T) { + tests := []struct { + name string + xdsAuthorityOverride string + wantSuccess bool + }{ + { + name: "Valid_Authority_Rewrite", + xdsAuthorityOverride: "x.test.example.com", + wantSuccess: true, + }, + { + name: "Authority_Rewrite_Mismatch", + xdsAuthorityOverride: "xyz.exmaple.com", + wantSuccess: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.XDSAuthorityRewrite, true) + mgmtServer, resolverBuilder, nodeID := setupManagementServerAndResolver(t) + + serverCreds := testutils.CreateServerTLSCredentials(t, tls.RequireAndVerifyClientCert) + + // Start a server backend exposing the test service. + var gotAuthority string + f := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromIncomingContext(ctx); ok { + if authVals := md.Get(":authority"); len(authVals) > 0 { + gotAuthority = authVals[0] + } + } + return &testpb.Empty{}, nil + }, + } + f.StartServer(grpc.Creds(serverCreds)) + defer f.Stop() + + clientCreds, err := xdscreds.NewClientCredentials(xdscreds.ClientOptions{FallbackCreds: insecure.NewCredentials()}) + if err != nil { + t.Fatalf("Failed to create client credentials: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + configureXDSResources(ctx, t, mgmtServer, nodeID, f.Address, test.xdsAuthorityOverride, e2e.SecurityLevelMTLS) + + // Create ClientConn with TLS + cc, err := grpc.NewClient("xds:///my-test-xds-service", grpc.WithTransportCredentials(clientCreds), grpc.WithResolvers(resolverBuilder)) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + peer := &peer.Peer{} + _, err = client.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(peer)) + + if test.wantSuccess { + if err != nil { + t.Fatalf("RPC failed unexpectedly: %v", err) + } + if gotAuthority != test.xdsAuthorityOverride { + t.Errorf("invalid authority got: %q, want: %q", gotAuthority, test.xdsAuthorityOverride) + } + } else { + if status.Code(err) != codes.Unavailable { + t.Fatalf("Expected TLS failure due to authority mismatch, got: %q want: %q", codes.Unavailable, status.Code(err)) + } + } + }) + } +} diff --git a/internal/xds/httpfilter/fault/fault_test.go b/internal/xds/httpfilter/fault/fault_test.go index 2612094baa5a..1f77ad3943aa 100644 --- a/internal/xds/httpfilter/fault/fault_test.go +++ b/internal/xds/httpfilter/fault/fault_test.go @@ -461,9 +461,6 @@ func (s) TestFaultInjection_Unary(t *testing.T) { fs, nodeID, port, bc := clientSetup(t) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -551,9 +548,6 @@ func (s) TestFaultInjection_Unary(t *testing.T) { func (s) TestFaultInjection_MaxActiveFaults(t *testing.T) { fs, nodeID, port, bc := clientSetup(t) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/internal/xds/resolver/helpers_test.go b/internal/xds/resolver/helpers_test.go index 174b03883445..2f0049188460 100644 --- a/internal/xds/resolver/helpers_test.go +++ b/internal/xds/resolver/helpers_test.go @@ -103,9 +103,6 @@ func buildResolverForTarget(t *testing.T, target resolver.Target, bootstrapConte var builder resolver.Builder if bootstrapContents != nil { // Create an xDS resolver with the provided bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } var err error builder, err = internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { diff --git a/internal/xds/resolver/serviceconfig.go b/internal/xds/resolver/serviceconfig.go index e04163666dcb..40a423f1f1e2 100644 --- a/internal/xds/resolver/serviceconfig.go +++ b/internal/xds/resolver/serviceconfig.go @@ -199,7 +199,9 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name) lbCtx = iringhash.SetXDSRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) - lbCtx = clusterimpl.SetAutoHostRewrite(lbCtx, rt.autoHostRewrite) + if rt.autoHostRewrite { + lbCtx = clusterimpl.EnableAutoHostRewrite(lbCtx) + } config := &iresolver.RPCConfig{ // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. diff --git a/internal/xds/resolver/xds_http_filters_test.go b/internal/xds/resolver/xds_http_filters_test.go index 9d2c5acf0b43..55e2060256b5 100644 --- a/internal/xds/resolver/xds_http_filters_test.go +++ b/internal/xds/resolver/xds_http_filters_test.go @@ -268,9 +268,6 @@ func (s) TestXDSResolverHTTPFilters_AllOverrides(t *testing.T) { // management server. nodeID := uuid.New().String() bootstrapContents := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -526,9 +523,6 @@ func (s) TestXDSResolverHTTPFilters_NewStreamError(t *testing.T) { // management server. nodeID := uuid.New().String() bootstrapContents := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/internal/xds/resolver/xds_resolver_test.go b/internal/xds/resolver/xds_resolver_test.go index 16d6fd76b1f9..ebe225b3b822 100644 --- a/internal/xds/resolver/xds_resolver_test.go +++ b/internal/xds/resolver/xds_resolver_test.go @@ -95,9 +95,6 @@ func (s) TestResolverBuilder_AuthorityNotDefinedInBootstrap(t *testing.T) { contents := e2e.DefaultBootstrapContents(t, "node-id", "dummy-management-server") // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(contents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -310,10 +307,6 @@ func (s) TestNoMatchingVirtualHost(t *testing.T) { target := resolver.Target{URL: *testutils.MustParseURL("xds:///" + defaultTestServiceName)} // Create an xDS resolver with the provided bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } - builder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -383,10 +376,6 @@ func (s) TestResolverBadServiceUpdate_NACKedWithoutCache(t *testing.T) { target := resolver.Target{URL: *testutils.MustParseURL("xds:///" + defaultTestServiceName)} // Create an xDS resolver with the provided bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } - builder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -1443,7 +1432,7 @@ func (s) TestResolver_AutoHostRewrite(t *testing.T) { t.Fatalf("cs.SelectConfig(): %v", err) } - gotAutoHostRewrite := clusterimpl.AutoHostRewriteForTesting(res.Context) + gotAutoHostRewrite := clusterimpl.AutoHostRewriteEnabledForTesting(res.Context) if gotAutoHostRewrite != tt.wantAutoHostRewrite { t.Fatalf("Got autoHostRewrite: %v, want: %v", gotAutoHostRewrite, tt.wantAutoHostRewrite) } diff --git a/internal/xds/xdsclient/tests/resource_update_test.go b/internal/xds/xdsclient/tests/resource_update_test.go index 87d5653fa006..59513304ab53 100644 --- a/internal/xds/xdsclient/tests/resource_update_test.go +++ b/internal/xds/xdsclient/tests/resource_update_test.go @@ -875,7 +875,7 @@ func (s) TestHandleClusterResponseFromManagementServer(t *testing.T) { // server at that point, hence we do it here before verifying the // received update. if test.wantErr == "" { - serverCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: fmt.Sprintf("passthrough:///%s", mgmtServer.Address)}) + serverCfg, err := bootstrap.ServerConfigForTesting(bootstrap.ServerConfigTestingOptions{URI: fmt.Sprintf("passthrough:///%s", mgmtServer.Address), ServerFeatures: []string{"trusted_xds_server"}}) if err != nil { t.Fatalf("Failed to create server config for testing: %v", err) } diff --git a/internal/xds/xdsclient/xdsresource/unmarshal_eds.go b/internal/xds/xdsclient/xdsresource/unmarshal_eds.go index 4ec133249bbe..becd61de845b 100644 --- a/internal/xds/xdsclient/xdsresource/unmarshal_eds.go +++ b/internal/xds/xdsclient/xdsresource/unmarshal_eds.go @@ -51,10 +51,10 @@ func setHostname(endpoint resolver.Endpoint, hostname string) resolver.Endpoint return endpoint } -// HostnameFromEndpoint returns the hostname attribute of endpoint. If this -// attribute is not set, it returns the empty string. -func HostnameFromEndpoint(endpoint resolver.Endpoint) string { - hostname, _ := endpoint.Attributes.Value(hostnameKeyType{}).(string) +// Hostname returns the hostname from the BalancerAttributes of the given +// Address. If this attribute is not set, it returns the empty string. +func Hostname(addr resolver.Address) string { + hostname, _ := addr.BalancerAttributes.Value(hostnameKeyType{}).(string) return hostname } diff --git a/stream.go b/stream.go index ec9577b2789c..b2699d8c9563 100644 --- a/stream.go +++ b/stream.go @@ -537,8 +537,16 @@ func (a *csAttempt) newStream() error { md, _ := metadata.FromOutgoingContext(a.ctx) md = metadata.Join(md, a.pickResult.Metadata) a.ctx = metadata.NewOutgoingContext(a.ctx, md) - } + // If the `CallAuthority` CallOption is not set, check if the LB picker + // has provided an authority override in the PickResult metadata and + // apply it, as specified in gRFC A81. + if cs.callInfo.authority == "" { + if authMD := a.pickResult.Metadata.Get(":authority"); len(authMD) > 0 { + cs.callHdr.Authority = authMD[0] + } + } + } s, err := a.transport.NewStream(a.ctx, cs.callHdr) if err != nil { nse, ok := err.(*transport.NewStreamError) diff --git a/test/xds/xds_client_ack_nack_test.go b/test/xds/xds_client_ack_nack_test.go index 4ff6272e7b7a..05f5a1bd59a3 100644 --- a/test/xds/xds_client_ack_nack_test.go +++ b/test/xds/xds_client_ack_nack_test.go @@ -131,9 +131,6 @@ func (s) TestClientResourceVersionAfterStreamRestart(t *testing.T) { bootstrapContents := e2e.DefaultBootstrapContents(t, nodeID, managementServer.Address) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/test/xds/xds_client_certificate_providers_test.go b/test/xds/xds_client_certificate_providers_test.go index 03bcd603c812..b7b2fc4610df 100644 --- a/test/xds/xds_client_certificate_providers_test.go +++ b/test/xds/xds_client_certificate_providers_test.go @@ -129,9 +129,6 @@ func (s) TestClientSideXDS_WithNoCertificateProvidersInBootstrap_Failure(t *test } // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolverBuilder, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bc) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/test/xds/xds_client_federation_test.go b/test/xds/xds_client_federation_test.go index df4a7a1383bb..815b520f4710 100644 --- a/test/xds/xds_client_federation_test.go +++ b/test/xds/xds_client_federation_test.go @@ -87,9 +87,6 @@ func (s) TestClientSideFederation(t *testing.T) { t.Fatalf("Failed to create bootstrap file: %v", err) } - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) @@ -183,9 +180,6 @@ func (s) TestClientSideFederationWithOnlyXDSTPStyleLDS(t *testing.T) { t.Fatalf("Failed to create bootstrap file: %v", err) } - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } resolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/test/xds/xds_client_ignore_resource_deletion_test.go b/test/xds/xds_client_ignore_resource_deletion_test.go index 4459ed299844..dd55adb1865d 100644 --- a/test/xds/xds_client_ignore_resource_deletion_test.go +++ b/test/xds/xds_client_ignore_resource_deletion_test.go @@ -301,9 +301,6 @@ func generateBootstrapContents(t *testing.T, serverURI string, ignoreResourceDel // as parameter. func xdsResolverBuilder(t *testing.T, bs []byte) resolver.Builder { t.Helper() - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsR, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bs) if err != nil { t.Fatalf("Creating xDS resolver for testing failed for config %q: %v", string(bs), err) diff --git a/test/xds/xds_security_config_nack_test.go b/test/xds/xds_security_config_nack_test.go index f1e512a2261c..0efc080251b7 100644 --- a/test/xds/xds_security_config_nack_test.go +++ b/test/xds/xds_security_config_nack_test.go @@ -329,9 +329,6 @@ func (s) TestUnmarshalCluster_WithUpdateValidatorFunc(t *testing.T) { bootstrapContents := e2e.DefaultBootstrapContents(t, nodeID, managementServer.Address) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err) diff --git a/test/xds/xds_server_integration_test.go b/test/xds/xds_server_integration_test.go index 1dcfcc1eee83..ff56c9ea97e1 100644 --- a/test/xds/xds_server_integration_test.go +++ b/test/xds/xds_server_integration_test.go @@ -327,9 +327,6 @@ func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { bootstrapContents := e2e.DefaultBootstrapContents(t, nodeID, managementServer.Address) // Create an xDS resolver with the above bootstrap configuration. - if internal.NewXDSResolverWithConfigForTesting == nil { - t.Fatalf("internal.NewXDSResolverWithConfigForTesting is nil") - } xdsResolver, err := internal.NewXDSResolverWithConfigForTesting.(func([]byte) (resolver.Builder, error))(bootstrapContents) if err != nil { t.Fatalf("Failed to create xDS resolver for testing: %v", err)