diff --git a/credentials/google/google.go b/credentials/google/google.go index fbdf7dc2997a..2f1ac6740654 100644 --- a/credentials/google/google.go +++ b/credentials/google/google.go @@ -39,6 +39,9 @@ var logger = grpclog.Component("credentials") type DefaultCredentialsOptions struct { // PerRPCCreds is a per RPC credentials that is passed to a bundle. PerRPCCreds credentials.PerRPCCredentials + // ALTSPerRPCCreds is a per RPC credentials that, if specified, will + // supercede PerRPCCreds above for and only for ALTS connections. + ALTSPerRPCCreds credentials.PerRPCCredentials } // NewDefaultCredentialsWithOptions returns a credentials bundle that is @@ -55,6 +58,12 @@ func NewDefaultCredentialsWithOptions(opts DefaultCredentialsOptions) credential logger.Warningf("NewDefaultCredentialsWithOptions: failed to create application oauth: %v", err) } } + if opts.ALTSPerRPCCreds != nil { + opts.PerRPCCreds = &dualPerRPCCreds{ + perRPCCreds: opts.PerRPCCreds, + altsPerRPCCreds: opts.ALTSPerRPCCreds, + } + } c := &creds{opts: opts} bundle, err := c.NewWithMode(internal.CredsBundleModeFallback) if err != nil { @@ -143,3 +152,27 @@ func (c *creds) NewWithMode(mode string) (credentials.Bundle, error) { return newCreds, nil } + +// dualPerRPCCreds implements credentials.PerRPCCredentials by embedding the +// fallback PerRPCCredentials and the ALTS one. It pickes one of them based on +// the channel type. +type dualPerRPCCreds struct { + perRPCCreds credentials.PerRPCCredentials + altsPerRPCCreds credentials.PerRPCCredentials +} + +func (d *dualPerRPCCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ri, ok := credentials.RequestInfoFromContext(ctx) + if !ok { + return nil, fmt.Errorf("request info not found from context") + } + if authType := ri.AuthInfo.AuthType(); authType == "alts" { + return d.altsPerRPCCreds.GetRequestMetadata(ctx, uri...) + } + // This ensures backward compatibility even if authType is not "tls". + return d.perRPCCreds.GetRequestMetadata(ctx, uri...) +} + +func (d *dualPerRPCCreds) RequireTransportSecurity() bool { + return d.altsPerRPCCreds.RequireTransportSecurity() || d.perRPCCreds.RequireTransportSecurity() +} diff --git a/credentials/google/google_test.go b/credentials/google/google_test.go index f9353df80f5b..7b2910a9e0a9 100644 --- a/credentials/google/google_test.go +++ b/credentials/google/google_test.go @@ -23,6 +23,7 @@ import ( "net" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/credentials" icredentials "google.golang.org/grpc/internal/credentials" "google.golang.org/grpc/internal/grpctest" @@ -59,6 +60,18 @@ func (t *testAuthInfo) AuthType() string { return t.typ } +type testPerRPCCreds struct { + md map[string]string +} + +func (c *testPerRPCCreds) RequireTransportSecurity() bool { + return true +} + +func (c *testPerRPCCreds) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { + return c.md, nil +} + var ( testTLS = &testCreds{typ: "tls"} testALTS = &testCreds{typ: "alts"} @@ -161,3 +174,88 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) { } } } + +func TestDefaultCredentialsWithOptions(t *testing.T) { + md1 := map[string]string{"foo": "tls"} + md2 := map[string]string{"foo": "alts"} + tests := []struct { + desc string + defaultCredsOpts DefaultCredentialsOptions + authInfo credentials.AuthInfo + wantedMetadata map[string]string + }{ + { + desc: "no ALTSPerRPCCreds with tls channel", + defaultCredsOpts: DefaultCredentialsOptions{ + PerRPCCreds: &testPerRPCCreds{ + md: md1, + }, + }, + authInfo: &testAuthInfo{typ: "tls"}, + wantedMetadata: md1, + }, + { + desc: "no ALTSPerRPCCreds with alts channel", + defaultCredsOpts: DefaultCredentialsOptions{ + PerRPCCreds: &testPerRPCCreds{ + md: md1, + }, + }, + authInfo: &testAuthInfo{typ: "alts"}, + wantedMetadata: md1, + }, + { + desc: "ALTSPerRPCCreds specified with tls channel", + defaultCredsOpts: DefaultCredentialsOptions{ + PerRPCCreds: &testPerRPCCreds{ + md: md1, + }, + ALTSPerRPCCreds: &testPerRPCCreds{ + md: md2, + }, + }, + authInfo: &testAuthInfo{typ: "tls"}, + wantedMetadata: md1, + }, + { + desc: "ALTSPerRPCCreds specified with alts channel", + defaultCredsOpts: DefaultCredentialsOptions{ + PerRPCCreds: &testPerRPCCreds{ + md: md1, + }, + ALTSPerRPCCreds: &testPerRPCCreds{ + md: md2, + }, + }, + authInfo: &testAuthInfo{typ: "alts"}, + wantedMetadata: md2, + }, + { + desc: "ALTSPerRPCCreds specified with unknown channel", + defaultCredsOpts: DefaultCredentialsOptions{ + PerRPCCreds: &testPerRPCCreds{ + md: md1, + }, + ALTSPerRPCCreds: &testPerRPCCreds{ + md: md2, + }, + }, + authInfo: &testAuthInfo{typ: "foo"}, + wantedMetadata: md1, + }, + } + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + bundle := NewDefaultCredentialsWithOptions(tc.defaultCredsOpts) + ri := credentials.RequestInfo{AuthInfo: tc.authInfo} + ctx := icredentials.NewRequestInfoContext(context.Background(), ri) + got, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx, "uri") + if err != nil { + t.Fatalf("Bundle's PerRPCCredentials().GetRequestMetadata() unexpected error = %v", err) + } + if diff := cmp.Diff(got, tc.wantedMetadata); diff != "" { + t.Errorf("Unexpected request metadata from bundle's PerRPCCredentials. Diff (-got +want):\n%v", diff) + } + }) + } +}