From 85616132ce31d191fe151b02adb673a37fda8882 Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 16 Dec 2025 09:52:16 -0700 Subject: [PATCH 1/2] feat: Use configured DNS name to lookup instance IP address When a custom DNS name is used to connect to a Cloud SQL instance, the dialer should first attempt to resolve the custom DNS name to an IP address and use that for the connection. If the lookup fails, the dialer should fall back to using the IP address from the instance metadata. This change modifies the dialer to: - Use the configured resolver to look up the host's IP address. - Use the IP address from the A record of the custom DNS name if available. - Fall back to the IP address from the instance metadata if the A record is not available. --- dialer.go | 25 +++++ dialer_test.go | 160 +++++++++++++++++++++++------ internal/cloudsql/resolver.go | 20 ++-- internal/cloudsql/resolver_test.go | 4 + monitored_cache_test.go | 4 +- options.go | 3 +- 6 files changed, 171 insertions(+), 45 deletions(-) diff --git a/dialer.go b/dialer.go index b064af28..0e392406 100644 --- a/dialer.go +++ b/dialer.go @@ -185,6 +185,7 @@ type Dialer struct { // resolver converts instance names into DNS names. resolver instance.ConnectionNameResolver + dnsResolver cloudsql.NetResolver failoverPeriod time.Duration // metadataExchangeDisabled true when the dialer should never @@ -213,6 +214,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { logger: nullLogger{}, useragents: []string{userAgent}, failoverPeriod: cloudsql.FailoverPeriod, + dnsResolver: net.DefaultResolver, } for _, opt := range opts { opt(cfg) @@ -321,6 +323,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { dialerID: uuid.New().String(), iamTokenProvider: cfg.iamLoginTokenProvider, dialFunc: cfg.dialFunc, + dnsResolver: cfg.dnsResolver, resolver: r, failoverPeriod: cfg.failoverPeriod, metadataExchangeDisabled: cfg.metadataExchangeDisabled, @@ -407,6 +410,28 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn d.removeCached(ctx, cn, c, err) return nil, err } + + // If the connector is configured with a custom DNS name, attempt to use + // that DNS name to connect to the instance. Fall back to the metadata IP + // address if the DNS name does not resolve to an IP address. + if cn.HasDomainName() { + addrs, err := d.dnsResolver.LookupHost(ctx, cn.DomainName()) + if err != nil { + d.logger.Debugf(ctx, + "[%v] custom DNS name %q did not resolve to an IP address: %v, using %s from instance metadata", + cn.String(), cn.DomainName(), err, addr) + } else if len(addrs) == 0 { + d.logger.Debugf(ctx, + "[%v] custom DNS name %q resolved but returned no entries, using %s from instance metadata", + cn.String(), cn.DomainName(), addr) + } else { + d.logger.Debugf(ctx, + "[%v] custom DNS name %q resolved to %q, using it to connect", + cn.String(), cn.DomainName(), addrs[0]) + addr = addrs[0] + } + } + addr = net.JoinHostPort(addr, serverProxyPort) f := d.dialFunc if cfg.dialFunc != nil { diff --git a/dialer_test.go b/dialer_test.go index fd13e64f..2689caba 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -74,6 +74,13 @@ func testSucessfulDialWithInstanceName( } } +// withMockDNSResolver replaces net.DefaultResolver with a mock resolver +func withMockDNSResolver(r cloudsql.NetResolver) Option { + return func(d *dialerConfig) { + d.dnsResolver = r + } +} + // setupConfig holds all the configuration to use when setting up a dialer. type setupConfig struct { testInstance mock.FakeCSQLInstance @@ -1017,15 +1024,23 @@ func TestDialerInitializesLazyCache(t *testing.T) { } } -type fakeResolver struct { - entries map[string]instance.ConnName +type mockNetResolver struct { + txtEntries map[string]string + hostEntries map[string]string +} + +func (r *mockNetResolver) LookupTXT(_ context.Context, name string) ([]string, error) { + if val, ok := r.txtEntries[name]; ok { + return []string{val}, nil + } + return nil, fmt.Errorf("no resolution for %q", name) } -func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) { - if val, ok := r.entries[name]; ok { - return val, nil +func (r *mockNetResolver) LookupHost(_ context.Context, name string) ([]string, error) { + if val, ok := r.hostEntries[name]; ok { + return []string{val}, nil } - return instance.ConnName{}, fmt.Errorf("no resolution for %q", name) + return nil, fmt.Errorf("no resolution for %q", name) } func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { @@ -1034,8 +1049,6 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"), mock.WithDNSMapping("db2.example.com", "INSTANCE", "CUSTOM_SAN"), ) - wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") - wantName2, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db2.example.com") // This will create 2 separate connectionInfoCache entries, one for // each DNS name. d := setupDialer(t, setupConfig{ @@ -1045,13 +1058,55 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { mock.CreateEphemeralSuccess(inst, 2), }, dialerOptions: []Option{ + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "db.example.com": "my-project:my-region:my-instance", + "db2.example.com": "my-project:my-region:my-instance", + }, + }), + WithDNSResolver(), WithTokenSource(mock.EmptyTokenSource{}), - WithResolver(&fakeResolver{ - entries: map[string]instance.ConnName{ - "db.example.com": wantName, - "db2.example.com": wantName2, + }, + }) + + testSuccessfulDial( + context.Background(), t, d, + "db.example.com", + ) + testSuccessfulDial( + context.Background(), t, d, + "db2.example.com", + ) +} + +func TestDialerSuccessfullyDialsDnsTxtRecordWithCustomARecords(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"), + mock.WithDNSMapping("db2.example.com", "INSTANCE", "CUSTOM_SAN"), + ) + + // This will create 2 separate connectionInfoCache entries, one for + // each DNS name. + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 2), + mock.CreateEphemeralSuccess(inst, 2), + }, + dialerOptions: []Option{ + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "db.example.com": "my-project:my-region:my-instance", + "db2.example.com": "my-project:my-region:my-instance", + }, + hostEntries: map[string]string{ + "db.example.com": "127.0.0.1", + "db2.example.com": "127.0.0.2", }, }), + WithTokenSource(mock.EmptyTokenSource{}), + WithDNSResolver(), }, }) @@ -1065,7 +1120,45 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) { ) } -func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { +func TestDialerFailsDnsTxtRecordWithInvalidCustomARecords(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"), + ) + + // This will create 2 separate connectionInfoCache entries, one for + // each DNS name. + d := setupDialer(t, setupConfig{ + testInstance: inst, + reqs: []*mock.Request{ + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + dialerOptions: []Option{ + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "db.example.com": "my-project:my-region:my-instance", + }, + hostEntries: map[string]string{ + "db.example.com": "1.1.1.1", + }, + }), + WithTokenSource(mock.EmptyTokenSource{}), + WithDNSResolver(), + }, + }) + ctx, cancelFn := context.WithTimeout(context.Background(), 1*time.Second) + defer cancelFn() + _, err := d.Dial(ctx, "db.example.com") + // Expect an error due to the timeout. + if err == nil { + t.Fatal("Dial should have failed due to bad IP address") + } + t.Log("timeout", err) + +} + +func TestDialerFailsDNSTxtRecordMissing(t *testing.T) { inst := mock.NewFakeCSQLInstance( "my-project", "my-region", "my-instance", ) @@ -1073,8 +1166,9 @@ func TestDialerFailsDnsTxtRecordMissing(t *testing.T) { testInstance: inst, reqs: []*mock.Request{}, dialerOptions: []Option{ + withMockDNSResolver(&mockNetResolver{}), WithTokenSource(mock.EmptyTokenSource{}), - WithResolver(&fakeResolver{}), + WithDNSResolver(), }, }) _, err := d.Dial(context.Background(), "doesnt-exist.example.com") @@ -1106,6 +1200,10 @@ func (r *changingResolver) Resolve(ctx context.Context, name string) (instance.C } } +func (r *changingResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + return net.DefaultResolver.LookupHost(ctx, host) +} + func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) { // At first, the resolver will resolve // update.example.com to "my-instance" @@ -1334,7 +1432,6 @@ func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) { ) } - wantName, _ := instance.ParseConnNameWithDomainName(tc.icn, tc.dn) d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ @@ -1342,13 +1439,13 @@ func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) { mock.CreateEphemeralSuccess(inst, 1), }, dialerOptions: []Option{ - WithTokenSource(mock.EmptyTokenSource{}), - WithResolver(&fakeResolver{ - entries: map[string]instance.ConnName{ - "db.example.com": wantName, - "my-project:my-region:my-instance": wantName, + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "db.example.com": "my-project:my-region:my-instance", }, }), + WithTokenSource(mock.EmptyTokenSource{}), + WithDNSResolver(), }, }) dnOrIcn := tc.icn @@ -1375,8 +1472,6 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) { ) // Resolve the dns name 'bad.example.com' to the the instance. - wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "bad.example.com") - d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ @@ -1384,12 +1479,13 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) { mock.CreateEphemeralSuccess(inst, 1), }, dialerOptions: []Option{ - WithTokenSource(mock.EmptyTokenSource{}), - WithResolver(&fakeResolver{ - entries: map[string]instance.ConnName{ - "bad.example.com": wantName, + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "bad.example.com": "my-project:my-region:my-instance", }, }), + WithTokenSource(mock.EmptyTokenSource{}), + WithDNSResolver(), }, }) @@ -1415,8 +1511,6 @@ func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) { ) // resolve db.example.com to the same instance - wantName, _ := instance.ParseConnNameWithDomainName("myProject:myRegion:myInstance", "db.example.com") - d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ @@ -1425,13 +1519,13 @@ func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) { }, dialerOptions: []Option{ - WithTokenSource(mock.EmptyTokenSource{}), - WithResolver(&fakeResolver{ - entries: map[string]instance.ConnName{ - "db.example.com": wantName, - "myProject:myRegion:myInstance": wantName, + withMockDNSResolver(&mockNetResolver{ + txtEntries: map[string]string{ + "db.example.com": "myProject:myRegion:myInstance", }, }), + WithTokenSource(mock.EmptyTokenSource{}), + WithDNSResolver(), }, }) diff --git a/internal/cloudsql/resolver.go b/internal/cloudsql/resolver.go index 1d911890..2d93d7f4 100644 --- a/internal/cloudsql/resolver.go +++ b/internal/cloudsql/resolver.go @@ -17,19 +17,12 @@ package cloudsql import ( "context" "fmt" - "net" "sort" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" ) -// DNSResolver uses the default net.Resolver to find -// TXT records containing an instance name for a DNS record. -var DNSResolver = &DNSInstanceConnectionNameResolver{ - dnsResolver: net.DefaultResolver, -} - // DefaultResolver simply parses instance names. var DefaultResolver = &ConnNameResolver{} @@ -44,19 +37,26 @@ func (r *ConnNameResolver) Resolve(_ context.Context, icn string) (instanceName return instance.ParseConnName(icn) } -// netResolver groups the methods on net.Resolver that are used by the DNS +// NetResolver groups the methods on net.Resolver that are used by the DNS // resolver implementation. This allows an application to replace the default // net.DefaultResolver with a custom implementation. For example: the // application may need to connect to a specific DNS server using a specially // configured instance of net.Resolver. -type netResolver interface { +type NetResolver interface { LookupTXT(ctx context.Context, name string) ([]string, error) + LookupHost(ctx context.Context, name string) ([]string, error) +} + +// NewDNSResolver returns a new DNSInstanceConnectionNameResolver with the +// provided resolver. +func NewDNSResolver(r NetResolver) *DNSInstanceConnectionNameResolver { + return &DNSInstanceConnectionNameResolver{dnsResolver: r} } // DNSInstanceConnectionNameResolver can resolve domain names into instance names using // TXT records in DNS. Implements InstanceConnectionNameResolver type DNSInstanceConnectionNameResolver struct { - dnsResolver netResolver + dnsResolver NetResolver } // Resolve returns the instance name, possibly using DNS. This will return an diff --git a/internal/cloudsql/resolver_test.go b/internal/cloudsql/resolver_test.go index c4186384..a5065ac4 100644 --- a/internal/cloudsql/resolver_test.go +++ b/internal/cloudsql/resolver_test.go @@ -35,6 +35,10 @@ func (r *fakeResolver) LookupTXT(_ context.Context, name string) (addrs []string return nil, fmt.Errorf("no resolution for %v", name) } +func (r *fakeResolver) LookupHost(_ context.Context, name string) (addrs []string, err error) { + return nil, fmt.Errorf("no resolution for %v", name) +} + func TestDNSInstanceNameResolver_Lookup_Success_TxtRecord(t *testing.T) { want, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") diff --git a/monitored_cache_test.go b/monitored_cache_test.go index 3394ab24..ae193e76 100644 --- a/monitored_cache_test.go +++ b/monitored_cache_test.go @@ -23,6 +23,7 @@ import ( "time" "cloud.google.com/go/cloudsqlconn/instance" + "cloud.google.com/go/cloudsqlconn/internal/cloudsql" ) type testLog struct { @@ -35,7 +36,8 @@ func (l *testLog) Debugf(_ context.Context, f string, args ...interface{}) { func TestMonitoredCache_purgeClosedConns(t *testing.T) { cn, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com") - c := newMonitoredCache(&spyConnectionInfoCache{}, cn, 10*time.Millisecond, &fakeResolver{entries: map[string]instance.ConnName{"db.example.com": cn}}, &testLog{t: t}) + r := cloudsql.NewDNSResolver(&mockNetResolver{txtEntries: map[string]string{"db.example.com": "my-project:my-region:my-instance"}}) + c := newMonitoredCache(&spyConnectionInfoCache{}, cn, 10*time.Millisecond, r, &testLog{t: t}) // Add connections c.mu.Lock() diff --git a/options.go b/options.go index f1276d64..7f08e9d7 100644 --- a/options.go +++ b/options.go @@ -59,6 +59,7 @@ type dialerConfig struct { resolver instance.ConnectionNameResolver failoverPeriod time.Duration metadataExchangeDisabled bool + dnsResolver cloudsql.NetResolver // err tracks any dialer options that may have failed. err error } @@ -315,7 +316,7 @@ func WithResolver(r instance.ConnectionNameResolver) Option { // - Value: `my-project:region:my-instance` – This is the instance name func WithDNSResolver() Option { return func(d *dialerConfig) { - d.resolver = cloudsql.DNSResolver + d.resolver = cloudsql.NewDNSResolver(d.dnsResolver) } } From 4673f9c05a1c9a931d6e91e3786d4c7982ee300f Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 16 Dec 2025 21:20:00 -0700 Subject: [PATCH 2/2] test: Fix incorrect mock expectation in dialer test --- dialer_test.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dialer_test.go b/dialer_test.go index 2689caba..7b557b00 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1514,8 +1514,8 @@ func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) { d := setupDialer(t, setupConfig{ testInstance: inst, reqs: []*mock.Request{ - mock.InstanceGetSuccess(inst, 1), - mock.CreateEphemeralSuccess(inst, 1), + mock.InstanceGetSuccess(inst, 2), + mock.CreateEphemeralSuccess(inst, 2), }, dialerOptions: []Option{ @@ -1526,8 +1526,7 @@ func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) { }), WithTokenSource(mock.EmptyTokenSource{}), WithDNSResolver(), - }, - }) + }}) tcs := []struct { desc string