diff --git a/lib/service/discovery.go b/lib/service/discovery.go index f452fb8b83be3..16f521e9dd919 100644 --- a/lib/service/discovery.go +++ b/lib/service/discovery.go @@ -75,11 +75,6 @@ func (process *TeleportProcess) initDiscoveryService() error { return trace.Wrap(err, "failed to build access graph configuration") } - publicProxyAddress, err := process.publicProxyAddr(accessPoint) - if err != nil { - return trace.Wrap(err, "failed to determine the public proxy address") - } - discoveryService, err := discovery.New(process.ExitContext(), &discovery.Config{ IntegrationOnlyCredentials: process.integrationOnlyCredentials(), Matchers: discovery.Matchers{ @@ -89,17 +84,16 @@ func (process *TeleportProcess) initDiscoveryService() error { Kubernetes: process.Config.Discovery.KubernetesMatchers, AccessGraph: process.Config.Discovery.AccessGraph, }, - DiscoveryGroup: process.Config.Discovery.DiscoveryGroup, - Emitter: asyncEmitter, - AccessPoint: accessPoint, - ServerID: conn.HostUUID(), - Log: process.logger, - ClusterName: conn.ClusterName(), - ClusterFeatures: process.GetClusterFeatures, - PollInterval: process.Config.Discovery.PollInterval, - GetClientCert: conn.ClientGetCertificate, - AccessGraphConfig: accessGraphCfg, - PublicProxyAddress: publicProxyAddress, + DiscoveryGroup: process.Config.Discovery.DiscoveryGroup, + Emitter: asyncEmitter, + AccessPoint: accessPoint, + ServerID: conn.HostUUID(), + Log: process.logger, + ClusterName: conn.ClusterName(), + ClusterFeatures: process.GetClusterFeatures, + PollInterval: process.Config.Discovery.PollInterval, + GetClientCert: conn.ClientGetCertificate, + AccessGraphConfig: accessGraphCfg, }) if err != nil { return trace.Wrap(err) @@ -136,41 +130,6 @@ func (process *TeleportProcess) initDiscoveryService() error { return nil } -type proxiesGetter interface { - GetProxies() ([]types.Server, error) -} - -func (process *TeleportProcess) publicProxyAddr(accessPoint proxiesGetter) (string, error) { - // If the proxy server is explicitly set, use that. - if !process.Config.ProxyServer.IsEmpty() { - return process.Config.ProxyServer.String(), nil - } - - // If DiscoveryService is running alongside a Proxy, use the first - // public address of the Proxy. - if process.Config.Proxy.Enabled { - for _, proxyAddr := range process.Config.Proxy.PublicAddrs { - if !proxyAddr.IsEmpty() { - return proxyAddr.String(), nil - } - } - } - - proxies, err := accessPoint.GetProxies() - if err != nil { - return "", trace.Wrap(err) - } - for _, proxy := range proxies { - for _, proxyAddr := range proxy.GetPublicAddrs() { - if proxyAddr != "" { - return proxyAddr, nil - } - } - } - - return "", trace.NotFound("could not find the public proxy address for server discovery") -} - // integrationOnlyCredentials indicates whether the DiscoveryService must only use Cloud APIs credentials using an integration. // // If Auth is running alongside this DiscoveryService and License is Cloud, then this process is running in Teleport's Cloud Infra. diff --git a/lib/service/discovery_test.go b/lib/service/discovery_test.go index 8aa30de15cce4..226b308765634 100644 --- a/lib/service/discovery_test.go +++ b/lib/service/discovery_test.go @@ -27,13 +27,11 @@ import ( "github.com/stretchr/testify/require" clusterconfigpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/clusterconfig/v1" - "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/modules/modulestest" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/srv/discovery" - "github.com/gravitational/teleport/lib/utils" ) func TestTeleportProcessIntegrationsOnly(t *testing.T) { @@ -167,81 +165,6 @@ func TestTeleportProcess_initDiscoveryService(t *testing.T) { }) } } -func TestProcessPublicProxyAddr(t *testing.T) { - proxyServerWithPublicAddr := func(addr string) *types.ServerV2 { - return &types.ServerV2{ - Spec: types.ServerSpecV2{ - PublicAddrs: []string{addr}, - }, - } - } - - tests := []struct { - name string - config *servicecfg.Config - proxyGetter proxiesGetter - wantAddr string - errCheck require.ErrorAssertionFunc - }{ - { - name: "proxy server was set in config", - config: &servicecfg.Config{ - ProxyServer: utils.NetAddr{Addr: "proxy.example.com:3080"}, - }, - wantAddr: "proxy.example.com:3080", - errCheck: require.NoError, - }, - { - name: "proxy is running alongside discovery service", - config: &servicecfg.Config{ - Proxy: servicecfg.ProxyConfig{ - Enabled: true, - PublicAddrs: []utils.NetAddr{ - {Addr: "public.proxy.com:443"}, - }, - }, - }, - wantAddr: "public.proxy.com:443", - errCheck: require.NoError, - }, - { - name: "discovery service is running alongside auth, (no proxy server defined and no proxy service enabled)", - config: &servicecfg.Config{}, - proxyGetter: &mockProxyGetter{ - servers: []types.Server{proxyServerWithPublicAddr("proxy.example:8080")}, - }, - wantAddr: "proxy.example:8080", - errCheck: require.NoError, - }, - { - name: "no proxy available", - config: &servicecfg.Config{}, - proxyGetter: &mockProxyGetter{ - servers: []types.Server{}, - }, - errCheck: require.Error, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - process := &TeleportProcess{ - Config: tt.config, - } - addr, err := process.publicProxyAddr(tt.proxyGetter) - tt.errCheck(t, err) - require.Equal(t, tt.wantAddr, addr) - }) - } -} - -type mockProxyGetter struct { - servers []types.Server -} - -func (f *mockProxyGetter) GetProxies() ([]types.Server, error) { - return f.servers, nil -} type fakeClient struct { authclient.ClientI diff --git a/lib/srv/discovery/config_test.go b/lib/srv/discovery/config_test.go index e3b7a220bf619..615c1852643a8 100644 --- a/lib/srv/discovery/config_test.go +++ b/lib/srv/discovery/config_test.go @@ -112,14 +112,6 @@ func TestConfigCheckAndSetDefaults(t *testing.T) { }, postCheckAndSetDefaultsFunc: func(t *testing.T, c *Config) {}, }, - { - name: "missing public proxy address", - errAssertFunc: require.Error, - cfgChange: func(c *Config) { - c.PublicProxyAddress = "" - }, - postCheckAndSetDefaultsFunc: func(t *testing.T, c *Config) {}, - }, { name: "missing cluster features", errAssertFunc: require.Error, @@ -145,8 +137,7 @@ func TestConfigCheckAndSetDefaults(t *testing.T) { ClusterFeatures: func() proto.Features { return proto.Features{} }, - DiscoveryGroup: "test", - PublicProxyAddress: "proxy.example.com", + DiscoveryGroup: "test", } tt.cfgChange(cfg) err := cfg.CheckAndSetDefaults() diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index bbe09dc936e5f..0d911fe2526ba 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -118,11 +118,6 @@ type Config struct { // CloudClients is an interface for retrieving cloud clients. CloudClients cloud.Clients - // PublicProxyAddress is the public address of the proxy. - // Used to configure installation scripts for Server auto discovery. - // Example: proxy.example.com:443 or proxy.example.com - PublicProxyAddress string - // AWSFetchersClients gets the AWS clients for the given region for the fetchers. AWSFetchersClients fetchers.AWSClientGetter @@ -239,10 +234,6 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("no AccessPoint configured for discovery") } - if c.PublicProxyAddress == "" { - return trace.BadParameter("no PublicProxyAddress configured for discovery") - } - if len(c.Matchers.Kubernetes) > 0 && c.DiscoveryGroup == "" { return trace.BadParameter(`the DiscoveryGroup name should be set for discovery server if kubernetes matchers are present.`) @@ -569,6 +560,25 @@ func (s *Server) startDynamicMatchersWatcher(ctx context.Context) error { return nil } +// publicProxyAddress returns the public proxy address to use for installation scripts. +// This is only used if the matcher does not specify a ProxyAddress. +// Example: proxy.example.com:3080 or proxy.example.com +func (s *Server) publicProxyAddress(ctx context.Context) (string, error) { + proxies, err := s.AccessPoint.GetProxies() + if err != nil { + return "", trace.Wrap(err) + } + for _, proxy := range proxies { + for _, proxyAddr := range proxy.GetPublicAddrs() { + if proxyAddr != "" { + return proxyAddr, nil + } + } + } + + return "", trace.NotFound("could not find the public proxy address for server discovery") +} + // initAWSWatchers starts AWS resource watchers based on types provided. func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { var err error @@ -578,9 +588,9 @@ func (s *Server) initAWSWatchers(matchers []types.AWSMatcher) error { }) s.staticServerAWSFetchers, err = server.MatchersToEC2InstanceFetchers(s.ctx, server.MatcherToEC2FetcherParams{ - Matchers: ec2Matchers, - EC2ClientGetter: s.GetEC2Client, - PublicProxyAddr: s.PublicProxyAddress, + Matchers: ec2Matchers, + EC2ClientGetter: s.GetEC2Client, + PublicProxyAddrGetter: s.publicProxyAddress, }) if err != nil { return trace.Wrap(err) @@ -701,10 +711,10 @@ func (s *Server) awsServerFetchersFromMatchers(ctx context.Context, matchers []t }) fetchers, err := server.MatchersToEC2InstanceFetchers(ctx, server.MatcherToEC2FetcherParams{ - Matchers: serverMatchers, - EC2ClientGetter: s.GetEC2Client, - DiscoveryConfigName: discoveryConfigName, - PublicProxyAddr: s.PublicProxyAddress, + Matchers: serverMatchers, + EC2ClientGetter: s.GetEC2Client, + DiscoveryConfigName: discoveryConfigName, + PublicProxyAddrGetter: s.publicProxyAddress, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/discovery/discovery_eks_test.go b/lib/srv/discovery/discovery_eks_test.go index 54569e55749e7..0e90ffa096683 100644 --- a/lib/srv/discovery/discovery_eks_test.go +++ b/lib/srv/discovery/discovery_eks_test.go @@ -273,13 +273,12 @@ func TestDiscoveryServerEKS(t *testing.T) { AWSConfigProvider: fakeConfigProvider, eksClusters: tt.eksClusters, }, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - AccessPoint: mockAccessPoint, - Matchers: Matchers{}, - Emitter: tt.emitter, - DiscoveryGroup: defaultDiscoveryGroup, - Log: logtest.NewLogger(), - PublicProxyAddress: "proxy.example.com", + ClusterFeatures: func() proto.Features { return proto.Features{} }, + AccessPoint: mockAccessPoint, + Matchers: Matchers{}, + Emitter: tt.emitter, + DiscoveryGroup: defaultDiscoveryGroup, + Log: logtest.NewLogger(), }) require.NoError(t, err) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index b0c4038593fdd..e7ccd859cf9f8 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -245,6 +245,18 @@ func TestDiscoveryServer(t *testing.T) { }, }}, } + staticMatcherUsingManagedSSMDoc := Matchers{ + AWS: []types.AWSMatcher{{ + Types: []string{"ec2"}, + Regions: []string{"eu-central-1"}, + Tags: map[string]utils.Strings{"teleport": {"yes"}}, + SSM: &types.AWSSSM{DocumentName: "AWS-RunShellScript"}, + Params: &types.InstallerParams{ + InstallTeleport: true, + EnrollMode: types.InstallParamEnrollMode_INSTALL_PARAM_ENROLL_MODE_SCRIPT, + }, + }}, + } defaultDiscoveryConfig, err := discoveryconfig.NewDiscoveryConfig( header.Metadata{Name: uuid.NewString()}, @@ -319,7 +331,8 @@ func TestDiscoveryServer(t *testing.T) { require.NoError(t, err) tcs := []struct { - name string + name string + requiresProxy bool // presentInstances is a list of servers already present in teleport. presentInstances []types.Server foundEC2Instances []ec2types.Instance @@ -378,6 +391,88 @@ func TestDiscoveryServer(t *testing.T) { staticMatchers: defaultStaticMatcher, wantInstalledInstances: []string{"instance-id-1"}, }, + { + name: "no nodes present, 1 found, using the pre-defined SSM-RunShellScript", + requiresProxy: true, + presentInstances: []types.Server{}, + foundEC2Instances: []ec2types.Instance{ + { + InstanceId: aws.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: aws.String("env"), + Value: aws.String("dev"), + }}, + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, + }, + }, + }, + ssm: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), + }, + }, + invokeOutput: &ssm.GetCommandInvocationOutput{ + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + }, + }, + emitter: &mockEmitter{ + eventHandler: func(t *testing.T, ae events.AuditEvent, server *Server) { + t.Helper() + require.Equal(t, &events.SSMRun{ + Metadata: events.Metadata{ + Type: libevents.SSMRunEvent, + Code: libevents.SSMRunSuccessCode, + }, + CommandID: "command-id-1", + AccountID: "owner", + InstanceID: "instance-id-1", + Region: "eu-central-1", + ExitCode: 0, + Status: string(ssmtypes.CommandInvocationStatusSuccess), + }, ae) + }, + }, + staticMatchers: staticMatcherUsingManagedSSMDoc, + wantInstalledInstances: []string{"instance-id-1"}, + }, + { + name: "fails if proxy address is not available when using AWS-RunShellScript", + requiresProxy: false, + presentInstances: []types.Server{}, + foundEC2Instances: []ec2types.Instance{ + { + InstanceId: aws.String("instance-id-1"), + Tags: []ec2types.Tag{{ + Key: aws.String("env"), + Value: aws.String("dev"), + }}, + State: &ec2types.InstanceState{ + Name: ec2types.InstanceStateNameRunning, + }, + }, + }, + ssm: &mockSSMClient{ + commandOutput: &ssm.SendCommandOutput{ + Command: &ssmtypes.Command{ + CommandId: aws.String("command-id-1"), + }, + }, + invokeOutput: &ssm.GetCommandInvocationOutput{ + Status: ssmtypes.CommandInvocationStatusSuccess, + ResponseCode: 0, + }, + }, + emitter: &mockEmitter{ + eventHandler: func(t *testing.T, ae events.AuditEvent, server *Server) { + t.Helper() + }, + }, + staticMatchers: staticMatcherUsingManagedSSMDoc, + wantInstalledInstances: []string{}, + }, { name: "nodes present, instance filtered", presentInstances: []types.Server{ @@ -706,6 +801,19 @@ func TestDiscoveryServer(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { require.NoError(t, testAuthServer.Close()) }) + if tc.requiresProxy { + err = testAuthServer.AuthServer.UpsertProxy(ctx, &types.ServerV2{ + Kind: types.KindProxy, + Metadata: types.Metadata{ + Name: "proxy", + }, + Spec: types.ServerSpecV2{ + PublicAddrs: []string{"proxy.example.com:443"}, + }, + }) + require.NoError(t, err) + } + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC(types.Metadata{ Name: "my-integration", }, &types.AWSOIDCIntegrationSpecV1{ @@ -764,15 +872,14 @@ func TestDiscoveryServer(t *testing.T) { AWSFetchersClients: &mockFetchersClients{ AWSConfigProvider: fakeConfigProvider, }, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPointWithEKSEnroller(tlsServer.Auth(), authClient, authClient.IntegrationAWSOIDCClient()), - Matchers: tc.staticMatchers, - Emitter: tc.emitter, - Log: logger, - DiscoveryGroup: defaultDiscoveryGroup, - clock: fakeClock, - PublicProxyAddress: "proxy.example.com", + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPointWithEKSEnroller(tlsServer.Auth(), authClient, authClient.IntegrationAWSOIDCClient()), + Matchers: tc.staticMatchers, + Emitter: tc.emitter, + Log: logger, + DiscoveryGroup: defaultDiscoveryGroup, + clock: fakeClock, }) require.NoError(t, err) server.ec2Installer = installer @@ -945,31 +1052,29 @@ func TestDiscoveryServerConcurrency(t *testing.T) { // Create Server1 server1, err := New(authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, - GetEC2Client: getEC2Client, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: staticMatcher, - Emitter: emitter, - Log: logger, - DiscoveryGroup: defaultDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", + CloudClients: testCloudClients, + GetEC2Client: getEC2Client, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: staticMatcher, + Emitter: emitter, + Log: logger, + DiscoveryGroup: defaultDiscoveryGroup, }) require.NoError(t, err) // Create Server2 server2, err := New(authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, - GetEC2Client: getEC2Client, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: staticMatcher, - Emitter: emitter, - Log: logger, - DiscoveryGroup: defaultDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", + CloudClients: testCloudClients, + GetEC2Client: getEC2Client, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: staticMatcher, + Emitter: emitter, + Log: logger, + DiscoveryGroup: defaultDiscoveryGroup, }) require.NoError(t, err) @@ -1171,10 +1276,9 @@ func TestDiscoveryKubeServices(t *testing.T) { Matchers: Matchers{ Kubernetes: tt.kubernetesMatchers, }, - PublicProxyAddress: "proxy.example.com", - Emitter: authClient, - DiscoveryGroup: mainDiscoveryGroup, - protocolChecker: &noopProtocolChecker{}, + Emitter: authClient, + DiscoveryGroup: mainDiscoveryGroup, + protocolChecker: &noopProtocolChecker{}, }) require.NoError(t, err) @@ -1510,7 +1614,6 @@ func TestDiscoveryInCloudKube(t *testing.T) { ClusterFeatures: func() proto.Features { return proto.Features{} }, KubernetesClient: fake.NewSimpleClientset(), AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - PublicProxyAddress: "proxy.example.com", Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, @@ -1664,7 +1767,6 @@ func TestDiscoveryServer_New(t *testing.T) { Matchers: tt.matchers, Emitter: &mockEmitter{}, protocolChecker: &noopProtocolChecker{}, - PublicProxyAddress: "proxy.example.com", }) tt.errAssertion(t, err) @@ -2459,7 +2561,6 @@ func TestDiscoveryDatabase(t *testing.T) { AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), AWSDatabaseFetcherFactory: dbFetcherFactory, AWSConfigProvider: fakeConfigProvider, - PublicProxyAddress: "proxy.example.com", Matchers: Matchers{ AWS: tc.awsMatchers, Azure: tc.azureMatchers, @@ -2595,7 +2696,6 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { Matchers: Matchers{}, Emitter: authClient, DiscoveryGroup: mainDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", clock: clock, }) @@ -3023,15 +3123,14 @@ func TestAzureVMDiscovery(t *testing.T) { } tlsServer.Auth().SetUsageReporter(reporter) server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - CloudClients: testCloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: tc.staticMatchers, - Emitter: emitter, - Log: logger, - DiscoveryGroup: defaultDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", + CloudClients: testCloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: tc.staticMatchers, + Emitter: emitter, + Log: logger, + DiscoveryGroup: defaultDiscoveryGroup, }) require.NoError(t, err) @@ -3330,15 +3429,14 @@ func TestGCPVMDiscovery(t *testing.T) { } tlsServer.Auth().SetUsageReporter(reporter) server, err := New(authz.ContextWithUser(context.Background(), identity.I), &Config{ - CloudClients: testCloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: tc.staticMatchers, - Emitter: emitter, - Log: logger, - DiscoveryGroup: defaultDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", + CloudClients: testCloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: tc.staticMatchers, + Emitter: emitter, + Log: logger, + DiscoveryGroup: defaultDiscoveryGroup, }) require.NoError(t, err) @@ -3530,8 +3628,7 @@ func TestEmitUsageEvents(t *testing.T) { ResourceTags: types.Labels{"teleport": {"yes"}}, }}, }, - PublicProxyAddress: "proxy.example.com", - Emitter: &mockEmitter{}, + Emitter: &mockEmitter{}, }) require.NoError(t, err) diff --git a/lib/srv/discovery/kube_integration_watcher_test.go b/lib/srv/discovery/kube_integration_watcher_test.go index c6b25ca74f95f..04789cc6a9b8d 100644 --- a/lib/srv/discovery/kube_integration_watcher_test.go +++ b/lib/srv/discovery/kube_integration_watcher_test.go @@ -408,10 +408,9 @@ func TestDiscoveryKubeIntegrationEKS(t *testing.T) { Matchers: Matchers{ AWS: tc.awsMatchers, }, - Emitter: authClient, - Log: logtest.NewLogger(), - DiscoveryGroup: mainDiscoveryGroup, - PublicProxyAddress: "proxy.example.com", + Emitter: authClient, + Log: logtest.NewLogger(), + DiscoveryGroup: mainDiscoveryGroup, }) require.NoError(t, err) diff --git a/lib/srv/server/ec2_watcher.go b/lib/srv/server/ec2_watcher.go index 3133620e790a7..a780e356f3b6e 100644 --- a/lib/srv/server/ec2_watcher.go +++ b/lib/srv/server/ec2_watcher.go @@ -19,7 +19,6 @@ package server import ( - "cmp" "context" "fmt" "log/slog" @@ -211,10 +210,10 @@ type MatcherToEC2FetcherParams struct { // DiscoveryConfigName is the name of the DiscoveryConfig that contains the matchers. // Empty if using static matchers (coming from the `teleport.yaml`). DiscoveryConfigName string - // PublicProxyAddr is the public proxy address to use for installation scripts. + // PublicProxyAddrGetter returns the public proxy address to use for installation scripts. // This is only used if the matcher does not specify a ProxyAddress. // Example: proxy.example.com:3080 or proxy.example.com - PublicProxyAddr string + PublicProxyAddrGetter func(context.Context) (string, error) } // MatchersToEC2InstanceFetchers converts a list of AWS EC2 Matchers into a list of AWS EC2 Fetchers. @@ -253,17 +252,17 @@ func matchersToEC2InstanceFetchers(matcherParams MatcherToEC2FetcherParams, getE for _, matcher := range matcherParams.Matchers { for _, region := range matcher.Regions { fetcher := newEC2InstanceFetcher(ec2FetcherConfig{ - ProxyPublicAddr: matcherParams.PublicProxyAddr, - Matcher: matcher, - Region: region, - Document: matcher.SSM.DocumentName, - InstallSuffix: matcher.Params.Suffix, - UpdateGroup: matcher.Params.UpdateGroup, - EC2ClientGetter: getEC2Client.withMatcher(&matcher), - Labels: matcher.Tags, - Integration: matcher.Integration, - DiscoveryConfigName: matcherParams.DiscoveryConfigName, - EnrollMode: matcher.Params.EnrollMode, + ProxyPublicAddrGetter: matcherParams.PublicProxyAddrGetter, + Matcher: matcher, + Region: region, + Document: matcher.SSM.DocumentName, + InstallSuffix: matcher.Params.Suffix, + UpdateGroup: matcher.Params.UpdateGroup, + EC2ClientGetter: getEC2Client.withMatcher(&matcher), + Labels: matcher.Tags, + Integration: matcher.Integration, + DiscoveryConfigName: matcherParams.DiscoveryConfigName, + EnrollMode: matcher.Params.EnrollMode, }) ret = append(ret, fetcher) } @@ -282,9 +281,10 @@ type ec2FetcherConfig struct { Integration string DiscoveryConfigName string EnrollMode types.InstallParamEnrollMode - // ProxyPublicAddr is the public proxy address to use for installation scripts. - // Format: "proxy.example.com:443" or "proxy.example.com" - ProxyPublicAddr string + // PublicProxyAddrGetter returns the public proxy address to use for installation scripts. + // This is only used if the matcher does not specify a ProxyAddress. + // Example: proxy.example.com:3080 or proxy.example.com + ProxyPublicAddrGetter func(ctx context.Context) (string, error) } type ec2InstanceFetcher struct { @@ -402,7 +402,7 @@ func ssmRunCommandParametersForCustomDocuments(cfg ec2FetcherConfig, envVars []s return parameters } -func ssmRunCommandParameters(cfg ec2FetcherConfig) map[string]string { +func ssmRunCommandParameters(ctx context.Context, cfg ec2FetcherConfig) (map[string]string, error) { var envVars []string // InstallSuffix and UpdateGroup only contains alphanumeric characters and hyphens. @@ -421,10 +421,17 @@ func ssmRunCommandParameters(cfg ec2FetcherConfig) map[string]string { // However, when using the pre-defined SSM Document AWS-RunShellScript, we need to construct // the "commands" parameter. if cfg.Document != types.AWSSSMDocumentRunShellScript { - return ssmRunCommandParametersForCustomDocuments(cfg, envVars) + return ssmRunCommandParametersForCustomDocuments(cfg, envVars), nil } - publicProxyAddr := cmp.Or(cfg.Matcher.Params.PublicProxyAddr, cfg.ProxyPublicAddr) + publicProxyAddr := cfg.Matcher.Params.PublicProxyAddr + if publicProxyAddr == "" { + proxyAddr, err := cfg.ProxyPublicAddrGetter(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + publicProxyAddr = proxyAddr + } safeProxyAddr := shsprintf.EscapeDefaultContext(publicProxyAddr) safeScriptName := shsprintf.EscapeDefaultContext(cfg.Matcher.Params.ScriptName) @@ -442,15 +449,20 @@ func ssmRunCommandParameters(cfg ec2FetcherConfig) map[string]string { return map[string]string{ ParamCommands: command, - } + }, nil } // GetMatchingInstances returns a list of EC2 instances from a list of matching Teleport nodes func (f *ec2InstanceFetcher) GetMatchingInstances(nodes []types.Server, rotation bool) ([]Instances, error) { + ssmRunParams, err := ssmRunCommandParameters(context.Background(), f.ec2FetcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + insts := EC2Instances{ Region: f.Region, DocumentName: f.Document, - Parameters: ssmRunCommandParameters(f.ec2FetcherConfig), + Parameters: ssmRunParams, Rotation: rotation, Integration: f.Integration, DiscoveryConfigName: f.DiscoveryConfigName, @@ -514,6 +526,11 @@ func chunkInstances(insts EC2Instances) []Instances { // GetInstances fetches all EC2 instances matching configured filters. func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([]Instances, error) { + ssmRunParams, err := ssmRunCommandParameters(context.Background(), f.ec2FetcherConfig) + if err != nil { + return nil, trace.Wrap(err) + } + ec2Client, err := f.EC2ClientGetter(ctx, f.Region) if err != nil { return nil, trace.Wrap(err) @@ -540,7 +557,7 @@ func (f *ec2InstanceFetcher) GetInstances(ctx context.Context, rotation bool) ([ Region: f.Region, DocumentName: f.Document, Instances: ToEC2Instances(res.Instances[i:end]), - Parameters: ssmRunCommandParameters(f.ec2FetcherConfig), + Parameters: ssmRunParams, Rotation: rotation, Integration: f.Integration, AssumeRoleARN: f.Matcher.AssumeRole.RoleARN, diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index fbcc440ffc729..c2ffb5f33ebcf 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -27,6 +27,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" @@ -278,8 +279,10 @@ func TestEC2Watcher(t *testing.T) { fetchersFn := func() []Fetcher { fetchers, err := matchersToEC2InstanceFetchers(MatcherToEC2FetcherParams{ - Matchers: matchers, - PublicProxyAddr: "proxy.example.com:3080", + Matchers: matchers, + PublicProxyAddrGetter: func(ctx context.Context) (string, error) { + return "proxy.example.com:3080", nil + }, }, getClient) require.NoError(t, err) @@ -516,6 +519,7 @@ func TestSSMRunCommandParameters(t *testing.T) { for _, tt := range []struct { name string cfg ec2FetcherConfig + errCheck require.ErrorAssertionFunc expectedParams map[string]string }{ { @@ -530,6 +534,7 @@ func TestSSMRunCommandParameters(t *testing.T) { }, Document: "TeleportDiscoveryInstaller", }, + errCheck: require.NoError, expectedParams: map[string]string{ "token": "my-token", "scriptName": "default-installer", @@ -548,6 +553,7 @@ func TestSSMRunCommandParameters(t *testing.T) { }, Document: "TeleportDiscoveryInstaller", }, + errCheck: require.NoError, expectedParams: map[string]string{ "token": "my-token", "scriptName": "default-agentless-installer", @@ -564,9 +570,12 @@ func TestSSMRunCommandParameters(t *testing.T) { ScriptName: "default-installer", }, }, - Document: "AWS-RunShellScript", - ProxyPublicAddr: "proxy.example.com", + Document: "AWS-RunShellScript", + ProxyPublicAddrGetter: func(ctx context.Context) (string, error) { + return "proxy.example.com", nil + }, }, + errCheck: require.NoError, expectedParams: map[string]string{ "commands": "curl -s -L https://proxy.example.com/v1/webapi/scripts/installer/default-installer | bash -s my-token", }, @@ -581,18 +590,42 @@ func TestSSMRunCommandParameters(t *testing.T) { ScriptName: "default-installer", }, }, - Document: "AWS-RunShellScript", - ProxyPublicAddr: "proxy.example.com", - InstallSuffix: "cluster-green", + Document: "AWS-RunShellScript", + ProxyPublicAddrGetter: func(ctx context.Context) (string, error) { + return "proxy.example.com", nil + }, + InstallSuffix: "cluster-green", }, + errCheck: require.NoError, expectedParams: map[string]string{ "commands": "export TELEPORT_INSTALL_SUFFIX=cluster-green; curl -s -L https://proxy.example.com/v1/webapi/scripts/installer/default-installer | bash -s my-token", }, }, + { + name: "error if using AWS-RunShellScript but proxy addr is not yet available", + cfg: ec2FetcherConfig{ + Matcher: types.AWSMatcher{ + Params: &types.InstallerParams{ + InstallTeleport: true, + JoinToken: "my-token", + ScriptName: "default-installer", + }, + }, + Document: "AWS-RunShellScript", + ProxyPublicAddrGetter: func(ctx context.Context) (string, error) { + return "", trace.NotFound("proxy is not yet available") + }, + InstallSuffix: "cluster-green", + }, + errCheck: require.Error, + }, } { t.Run(tt.name, func(t *testing.T) { - got := ssmRunCommandParameters(tt.cfg) - require.Equal(t, tt.expectedParams, got) + got, err := ssmRunCommandParameters(t.Context(), tt.cfg) + tt.errCheck(t, err) + if tt.expectedParams != nil { + require.Equal(t, tt.expectedParams, got) + } }) } }