diff --git a/integration/appaccess/pack.go b/integration/appaccess/pack.go index 51fc1db0db012..0af7ddf7127a6 100644 --- a/integration/appaccess/pack.go +++ b/integration/appaccess/pack.go @@ -45,7 +45,7 @@ import ( "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/httplib/csrf" "github.com/gravitational/teleport/lib/httplib/reverseproxy" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -753,7 +753,7 @@ func (p *Pack) startRootAppServers(t *testing.T, count int, opts AppTestOptions) return servers } -func waitForAppServer(t *testing.T, tunnel reversetunnel.Server, name string, hostUUID string, apps []servicecfg.App) { +func waitForAppServer(t *testing.T, tunnel reversetunnelclient.Server, name string, hostUUID string, apps []servicecfg.App) { // Make sure that the app server is ready to accept connections. // The remote site cache needs to be filled with new registered application services. waitForAppRegInRemoteSiteCache(t, tunnel, name, apps, hostUUID) @@ -891,7 +891,7 @@ func (p *Pack) startLeafAppServers(t *testing.T, count int, opts AppTestOptions) return servers } -func waitForAppRegInRemoteSiteCache(t *testing.T, tunnel reversetunnel.Server, clusterName string, cfgApps []servicecfg.App, hostUUID string) { +func waitForAppRegInRemoteSiteCache(t *testing.T, tunnel reversetunnelclient.Server, clusterName string, cfgApps []servicecfg.App, hostUUID string) { require.Eventually(t, func() bool { site, err := tunnel.GetSite(clusterName) require.NoError(t, err) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index 1df4164aa70d5..3a3d9dd8f9f18 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -56,6 +56,7 @@ import ( "github.com/gravitational/teleport/lib/httplib/csrf" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -258,7 +259,7 @@ type TeleInstance struct { // Internal stuff... Process *service.TeleportProcess Config *servicecfg.Config - Tunnel reversetunnel.Server + Tunnel reversetunnelclient.Server RemoteClusterWatcher *reversetunnel.RemoteClusterTunnelManager // Nodes is a list of additional nodes @@ -1033,7 +1034,7 @@ type ProxyConfig struct { } // StartProxy starts another Proxy Server and connects it to the cluster. -func (i *TeleInstance) StartProxy(cfg ProxyConfig, opts ...Option) (reversetunnel.Server, *service.TeleportProcess, error) { +func (i *TeleInstance) StartProxy(cfg ProxyConfig, opts ...Option) (reversetunnelclient.Server, *service.TeleportProcess, error) { dataDir, err := os.MkdirTemp("", "cluster-"+i.Secrets.SiteName+"-"+cfg.Name) if err != nil { return nil, nil, trace.Wrap(err) @@ -1107,12 +1108,12 @@ func (i *TeleInstance) StartProxy(cfg ProxyConfig, opts ...Option) (reversetunne log.Debugf("Teleport proxy (in instance %v) started: %v/%v events received.", i.Secrets.SiteName, len(expectedEvents), len(receivedEvents)) - // Extract and set reversetunnel.Server and reversetunnel.AgentPool upon + // Extract and set reversetunnelclient.Server and reversetunnel.AgentPool upon // receipt of a ProxyReverseTunnelReady event for _, re := range receivedEvents { switch re.Name { case service.ProxyReverseTunnelReady: - ts, ok := re.Payload.(reversetunnel.Server) + ts, ok := re.Payload.(reversetunnelclient.Server) if ok { return ts, process, nil } @@ -1220,12 +1221,12 @@ func (i *TeleInstance) Start() error { return trace.Wrap(err) } - // Extract and set reversetunnel.Server and reversetunnel.AgentPool upon + // Extract and set reversetunnelclient.Server and reversetunnel.AgentPool upon // receipt of a ProxyReverseTunnelReady and ProxyAgentPoolReady respectively. for _, re := range receivedEvents { switch re.Name { case service.ProxyReverseTunnelReady: - ts, ok := re.Payload.(reversetunnel.Server) + ts, ok := re.Payload.(reversetunnelclient.Server) if ok { i.Tunnel = ts } diff --git a/integration/helpers/trustedclusters.go b/integration/helpers/trustedclusters.go index cdcd1ceb0112b..42781025a64ca 100644 --- a/integration/helpers/trustedclusters.go +++ b/integration/helpers/trustedclusters.go @@ -28,7 +28,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" ) // WaitForTunnelConnections waits for remote tunnels connections @@ -76,7 +76,7 @@ func TryCreateTrustedCluster(t *testing.T, authServer *auth.Server, trustedClust require.FailNow(t, "Timeout creating trusted cluster") } -func WaitForClusters(tun reversetunnel.Server, expected int) func() bool { +func WaitForClusters(tun reversetunnelclient.Server, expected int) func() bool { return func() bool { clusters, err := tun.GetSites() if err != nil { @@ -129,7 +129,7 @@ func WaitForNodeCount(ctx context.Context, t *TeleInstance, clusterName string, } // WaitForActiveTunnelConnections waits for remote cluster to report a minimum number of active connections -func WaitForActiveTunnelConnections(t *testing.T, tunnel reversetunnel.Server, clusterName string, expectedCount int) { +func WaitForActiveTunnelConnections(t *testing.T, tunnel reversetunnelclient.Server, clusterName string, expectedCount int) { require.Eventually(t, func() bool { cluster, err := tunnel.GetSite(clusterName) if err != nil { diff --git a/integration/integration_test.go b/integration/integration_test.go index 45718636d6a01..39db70cfcabad 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -84,7 +84,7 @@ import ( "github.com/gravitational/teleport/lib/events/filesessions" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/pam" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -3810,7 +3810,7 @@ func testDiscoveryRecovers(t *testing.T, suite *integrationTestSuite) { var reverseTunnelAddr string // Helper function for adding a new proxy to "main". - addNewMainProxy := func(name string) (reversetunnel.Server, helpers.ProxyConfig) { + addNewMainProxy := func(name string) (reversetunnelclient.Server, helpers.ProxyConfig) { t.Logf("adding main proxy %q...", name) newConfig := helpers.ProxyConfig{ Name: name, diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index ee42ca1bf82f3..cac224aefda46 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -55,7 +55,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/alpnproxy" @@ -602,7 +602,7 @@ func makeNodeConfig(nodeName, proxyAddr string) *servicecfg.Config { } // waitForActivePeerProxyConnections waits for remote cluster to report a minimum number of active proxy peer connections -func waitForActivePeerProxyConnections(t *testing.T, tunnel reversetunnel.Server, expectedCount int) { //nolint:unused // Only used by skipped test TestProxyTunnelStrategyProxyPeering +func waitForActivePeerProxyConnections(t *testing.T, tunnel reversetunnelclient.Server, expectedCount int) { //nolint:unused // Only used by skipped test TestProxyTunnelStrategyProxyPeering require.Eventually(t, func() bool { return tunnel.GetProxyPeerClient().GetConnectionsCount() >= expectedCount }, diff --git a/lib/benchmark/web.go b/lib/benchmark/web.go index 6be3bfc559891..0896269996d6b 100644 --- a/lib/benchmark/web.go +++ b/lib/benchmark/web.go @@ -77,7 +77,7 @@ func (s WebSSHBenchmark) BenchBuilder(ctx context.Context, tc *client.TeleportCl return nil, trace.BadParameter("random ssh bench commands must use the format @all ") } - servers, err := s.getServers(ctx, tc) + servers, err := getServers(ctx, tc) if err != nil { return nil, trace.Wrap(err) } @@ -92,10 +92,39 @@ func (s WebSSHBenchmark) BenchBuilder(ctx context.Context, tc *client.TeleportCl }, nil } +type webSession struct { + mu sync.Mutex + webSession types.WebSession + clt *client.WebClient +} + +func (s *webSession) renew(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Until(s.expires().Add(-3 * time.Minute))): + resp, err := s.clt.PostJSON(ctx, s.clt.Endpoint("webapi", "sessions", "renew"), nil) + if err != nil { + continue + } + + session, err := client.GetSessionFromResponse(resp) + if err != nil { + continue + } + + s.mu.Lock() + s.webSession = session + s.mu.Unlock() + } + } +} + // runCommand starts a non-interactive SSH session and executes the provided // command before terminating the session. func (s WebSSHBenchmark) runCommand(ctx context.Context, tc *client.TeleportClient, webSess *webSession, host, command string) error { - stream, err := s.connectToHost(ctx, tc, webSess, host) + stream, err := connectToHost(ctx, tc, webSess, host) if err != nil { return trace.Wrap(err) } @@ -112,29 +141,8 @@ func (s WebSSHBenchmark) runCommand(ctx context.Context, tc *client.TeleportClie return nil } -// getServers returns all [types.Server] that the authenticated user has -// access to. -func (s WebSSHBenchmark) getServers(ctx context.Context, tc *client.TeleportClient) ([]types.Server, error) { - clt, err := tc.ConnectToCluster(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - defer clt.Close() - - resources, err := apiclient.GetAllResources[types.Server](ctx, clt.AuthClient, tc.ResourceFilter(types.KindNode)) - if err != nil { - return nil, trace.Wrap(err) - } - - if len(resources) == 0 { - return nil, trace.BadParameter("no target hosts available") - } - - return resources, nil -} - // connectToHost opens an SSH session to the target host via the Proxy web api. -func (s WebSSHBenchmark) connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *webSession, host string) (*web.TerminalStream, error) { +func connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *webSession, host string) (io.ReadWriteCloser, error) { req := web.TerminalRequest{ Server: host, Login: tc.HostLogin, @@ -185,33 +193,25 @@ func (s WebSSHBenchmark) connectToHost(ctx context.Context, tc *client.TeleportC return stream, trace.Wrap(err) } -type webSession struct { - mu sync.Mutex - webSession types.WebSession - clt *client.WebClient -} - -func (s *webSession) renew(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Until(s.expires().Add(-3 * time.Minute))): - resp, err := s.clt.PostJSON(ctx, s.clt.Endpoint("webapi", "sessions", "renew"), nil) - if err != nil { - continue - } +// getServers returns all [types.Server] that the authenticated user has +// access to. +func getServers(ctx context.Context, tc *client.TeleportClient) ([]types.Server, error) { + clt, err := tc.ConnectToCluster(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer clt.Close() - session, err := client.GetSessionFromResponse(resp) - if err != nil { - continue - } + resources, err := apiclient.GetAllResources[types.Server](ctx, clt.AuthClient, tc.ResourceFilter(types.KindNode)) + if err != nil { + return nil, trace.Wrap(err) + } - s.mu.Lock() - s.webSession = session - s.mu.Unlock() - } + if len(resources) == 0 { + return nil, trace.BadParameter("no target hosts available") } + + return resources, nil } func (s *webSession) expires() time.Time { diff --git a/lib/client/conntest/database/mysql.go b/lib/client/conntest/database/mysql.go index d813fe4692c4f..5c77a6317b135 100644 --- a/lib/client/conntest/database/mysql.go +++ b/lib/client/conntest/database/mysql.go @@ -29,12 +29,27 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/srv/db/common" ) // MySQLPinger implements the DatabasePinger interface for the MySQL protocol. type MySQLPinger struct{} +// convertError converts the error from MySQL client since it can be wrapped in an [errors.Causer]. +// The MySQL engine in the agent already does this, but we need it here because +// the error is from the MySQL client. +func convertError(err error) error { + // causer defines an interface for errors wrapped by the [errors] package. + type causer interface { + Cause() error + } + + if causer, ok := err.(causer); ok { + return trace.Wrap(causer.Cause()) + } + + return trace.Wrap(err) +} + // Ping connects to the database and issues a basic select statement to validate the connection. func (p *MySQLPinger) Ping(ctx context.Context, params PingParams) error { if err := params.CheckAndSetDefaults(defaults.ProtocolMySQL); err != nil { @@ -50,10 +65,7 @@ func (p *MySQLPinger) Ping(ctx context.Context, params PingParams) error { nd.DialContext, ) if err != nil { - // convert the error from MySQL client since it can be wrapped in a "Causer". - // The MySQL engine in the agent already does this, but we need it here because - // the error is from the MySQL client. - return trace.Wrap(common.ConvertError(err)) + return convertError(err) } defer func() { @@ -63,7 +75,7 @@ func (p *MySQLPinger) Ping(ctx context.Context, params PingParams) error { }() if err := conn.Ping(); err != nil { - return trace.Wrap(common.ConvertError(err)) + return convertError(err) } return nil diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index bc965e381b0c8..fbd31366c7b3b 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -86,7 +86,6 @@ import ( "github.com/gravitational/teleport/lib/kube/proxy/streamproto" kubeutils "github.com/gravitational/teleport/lib/kube/utils" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -115,7 +114,7 @@ const ( // ForwarderConfig specifies configuration for proxy forwarder type ForwarderConfig struct { // ReverseTunnelSrv is the teleport reverse tunnel server - ReverseTunnelSrv reversetunnel.Server + ReverseTunnelSrv reversetunnelclient.Server // ClusterName is a local cluster name ClusterName string // Keygen points to a key generator implementation diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index 8c1c3ae7f27a3..92a0a1b3a0890 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -57,7 +57,6 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" testingkubemock "github.com/gravitational/teleport/lib/kube/proxy/testing/kube_server" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" @@ -162,7 +161,7 @@ func TestAuthenticate(t *testing.T) { require.NoError(t, err) tun := mockRevTunnel{ - sites: map[string]reversetunnel.RemoteSite{ + sites: map[string]reversetunnelclient.RemoteSite{ "remote": mockRemoteSite{name: "remote"}, "local": mockRemoteSite{name: "local"}, }, @@ -202,7 +201,7 @@ func TestAuthenticate(t *testing.T) { routeToCluster string kubernetesCluster string haveKubeCreds bool - tunnel reversetunnel.Server + tunnel reversetunnelclient.Server kubeServers []types.KubeServer activeRequests []string @@ -1184,11 +1183,11 @@ func (c *mockCSRClient) ProcessKubeCSR(csr auth.KubeCSR) (*auth.KubeCSRResponse, }, nil } -// mockRemoteSite is a reversetunnel.RemoteSite implementation with hardcoded +// mockRemoteSite is a reversetunnelclient.RemoteSite implementation with hardcoded // name, because there's no easy way to construct a real -// reversetunnel.RemoteSite. +// reversetunnelclient.RemoteSite. type mockRemoteSite struct { - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite name string } @@ -1233,12 +1232,12 @@ func (ap mockAccessPoint) GetCertAuthority(ctx context.Context, id types.CertAut } type mockRevTunnel struct { - reversetunnel.Server + reversetunnelclient.Server - sites map[string]reversetunnel.RemoteSite + sites map[string]reversetunnelclient.RemoteSite } -func (t mockRevTunnel) GetSite(name string) (reversetunnel.RemoteSite, error) { +func (t mockRevTunnel) GetSite(name string) (reversetunnelclient.RemoteSite, error) { s, ok := t.sites[name] if !ok { return nil, trace.NotFound("remote site %q not found", name) @@ -1246,8 +1245,8 @@ func (t mockRevTunnel) GetSite(name string) (reversetunnel.RemoteSite, error) { return s, nil } -func (t mockRevTunnel) GetSites() ([]reversetunnel.RemoteSite, error) { - var sites []reversetunnel.RemoteSite +func (t mockRevTunnel) GetSites() ([]reversetunnelclient.RemoteSite, error) { + var sites []reversetunnelclient.RemoteSite for _, s := range t.sites { sites = append(sites, s) } diff --git a/lib/kube/proxy/transport.go b/lib/kube/proxy/transport.go index e9ec4dcedbf30..33253fe51aa8c 100644 --- a/lib/kube/proxy/transport.go +++ b/lib/kube/proxy/transport.go @@ -39,7 +39,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -384,7 +383,7 @@ func (f *Forwarder) remoteClusterDialer(clusterName string) dialContextFunc { return nil, trace.Wrap(err) } - return targetCluster.DialTCP(reversetunnel.DialParams{ + return targetCluster.DialTCP(reversetunnelclient.DialParams{ // Send a sentinel value to the remote cluster because this connection // will be used to forward multiple requests to the remote cluster from // different users. @@ -471,7 +470,7 @@ func (f *Forwarder) localClusterDialer(kubeClusterName string, opts ...contextDi // It is a combination of the server's hostname and the cluster name. // . serverID := fmt.Sprintf("%s.%s", s.GetHostID(), f.cfg.ClusterName) - conn, err := localCluster.DialTCP(reversetunnel.DialParams{ + conn, err := localCluster.DialTCP(reversetunnelclient.DialParams{ // Send a sentinel value to the remote cluster because this connection // will be used to forward multiple requests to the remote cluster from // different users. diff --git a/lib/kube/proxy/transport_test.go b/lib/kube/proxy/transport_test.go index d6bffde188167..12a34a17aaf35 100644 --- a/lib/kube/proxy/transport_test.go +++ b/lib/kube/proxy/transport_test.go @@ -26,7 +26,6 @@ import ( "go.opentelemetry.io/otel" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" ) @@ -158,14 +157,14 @@ func TestForwarderClusterDialer(t *testing.T) { tests := []struct { name string dialerCreator func(kubeClusterName string) dialContextFunc - want reversetunnel.DialParams + want reversetunnelclient.DialParams }{ { name: "local site", dialerCreator: func(kubeClusterName string) dialContextFunc { return f.localClusterDialer(kubeClusterName) }, - want: reversetunnel.DialParams{ + want: reversetunnelclient.DialParams{ From: &utils.NetAddr{ Addr: "0.0.0.0:0", AddrNetwork: "tcp", @@ -182,7 +181,7 @@ func TestForwarderClusterDialer(t *testing.T) { { name: "remote site", dialerCreator: f.remoteClusterDialer, - want: reversetunnel.DialParams{ + want: reversetunnelclient.DialParams{ From: &utils.NetAddr{ Addr: "0.0.0.0:0", AddrNetwork: "tcp", @@ -207,12 +206,12 @@ func TestForwarderClusterDialer(t *testing.T) { } type fakeReverseTunnel struct { - reversetunnel.Server - want reversetunnel.DialParams + reversetunnelclient.Server + want reversetunnelclient.DialParams t *testing.T } -func (f *fakeReverseTunnel) GetSite(_ string) (reversetunnel.RemoteSite, error) { +func (f *fakeReverseTunnel) GetSite(_ string) (reversetunnelclient.RemoteSite, error) { return &fakeRemoteSiteTunnel{ want: f.want, t: f.t, @@ -220,12 +219,12 @@ func (f *fakeReverseTunnel) GetSite(_ string) (reversetunnel.RemoteSite, error) } type fakeRemoteSiteTunnel struct { - reversetunnel.RemoteSite - want reversetunnel.DialParams + reversetunnelclient.RemoteSite + want reversetunnelclient.DialParams t *testing.T } -func (f *fakeRemoteSiteTunnel) DialTCP(p reversetunnel.DialParams) (net.Conn, error) { +func (f *fakeRemoteSiteTunnel) DialTCP(p reversetunnelclient.DialParams) (net.Conn, error) { require.Equal(f.t, f.want, p) return nil, nil } diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go index d103b662f178e..7ca7fb89e313d 100644 --- a/lib/kube/proxy/utils_testing.go +++ b/lib/kube/proxy/utils_testing.go @@ -53,7 +53,7 @@ import ( "github.com/gravitational/teleport/lib/kube/proxy/streamproto" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" sessPkg "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/tlsca" @@ -290,10 +290,10 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // Create kubernetes service server. testCtx.KubeProxy, err = NewTLSServer(TLSServerConfig{ ForwarderConfig: ForwarderConfig{ - ReverseTunnelSrv: &reversetunnel.FakeServer{ - Sites: []reversetunnel.RemoteSite{ + ReverseTunnelSrv: &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{ &fakeRemoteSite{ - FakeRemoteSite: reversetunnel.NewFakeRemoteSite(testCtx.ClusterName, client), + FakeRemoteSite: reversetunnelclient.NewFakeRemoteSite(testCtx.ClusterName, client), idToAddr: map[string]string{ testCtx.HostID: testCtx.kubeServerListener.Addr().String(), }, @@ -587,11 +587,11 @@ func (f *fakeClient) CreateSessionTracker(ctx context.Context, st types.SessionT // fakeRemoteSite is a fake remote site that uses a map to map server IDs to // addresses to simulate reverse tunneling. type fakeRemoteSite struct { - *reversetunnel.FakeRemoteSite + *reversetunnelclient.FakeRemoteSite idToAddr map[string]string } -func (f *fakeRemoteSite) DialTCP(p reversetunnel.DialParams) (conn net.Conn, err error) { +func (f *fakeRemoteSite) DialTCP(p reversetunnelclient.DialParams) (conn net.Conn, err error) { // The server ID is the first part of the address. addr, ok := f.idToAddr[strings.Split(p.ServerID, ".")[0]] if !ok { diff --git a/lib/proxy/clusterdial/dial.go b/lib/proxy/clusterdial/dial.go index 59a01e9a97dd8..124f8d919c3f0 100644 --- a/lib/proxy/clusterdial/dial.go +++ b/lib/proxy/clusterdial/dial.go @@ -20,7 +20,7 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/lib/proxy/peer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" ) // ClusterDialerFunc is a function that implements a peer.ClusterDialer. @@ -32,14 +32,14 @@ func (f ClusterDialerFunc) Dial(clusterName string, request peer.DialParams) (ne } // NewClusterDialer implements proxy.ClusterDialer for a reverse tunnel server. -func NewClusterDialer(server reversetunnel.Server) ClusterDialerFunc { +func NewClusterDialer(server reversetunnelclient.Server) ClusterDialerFunc { return func(clusterName string, request peer.DialParams) (net.Conn, error) { site, err := server.GetSite(clusterName) if err != nil { return nil, trace.Wrap(err) } - dialParams := reversetunnel.DialParams{ + dialParams := reversetunnelclient.DialParams{ ServerID: request.ServerID, ConnType: request.ConnType, From: request.From, diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 33809cd8a7f84..787b4167059ed 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -37,7 +37,6 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" @@ -102,7 +101,7 @@ type serverResolverFn = func(ctx context.Context, host, port string, site site) // SiteGetter provides access to connected local or remote sites type SiteGetter interface { // GetSite returns the site matching the provided clusterName - GetSite(clusterName string) (reversetunnel.RemoteSite, error) + GetSite(clusterName string) (reversetunnelclient.RemoteSite, error) } // RemoteClusterGetter provides access to remote cluster resources @@ -164,7 +163,7 @@ type Router struct { clusterName string log *logrus.Entry clusterGetter RemoteClusterGetter - localSite reversetunnel.RemoteSite + localSite reversetunnelclient.RemoteSite siteGetter SiteGetter tracer oteltrace.Tracer serverResolver serverResolverFn @@ -290,7 +289,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. agentGetter = nil } - conn, err := site.Dial(reversetunnel.DialParams{ + conn, err := site.Dial(reversetunnelclient.DialParams{ From: clientSrcAddr, To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, OriginalClientDstAddr: clientDstAddr, @@ -313,7 +312,7 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. // getRemoteCluster looks up the provided clusterName to determine if a remote site exists with // that name and determines if the user has access to it. -func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnel.RemoteSite, error) { +func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, checker services.AccessChecker) (reversetunnelclient.RemoteSite, error) { _, span := r.tracer.Start( ctx, "router/getRemoteCluster", @@ -341,16 +340,16 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check } // site is the minimum interface needed to match servers -// for a reversetunnel.RemoteSite. It makes testing easier. +// for a reversetunnelclient.RemoteSite. It makes testing easier. type site interface { GetNodes(ctx context.Context, fn func(n services.Node) bool) ([]types.Server, error) GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error) } // remoteSite is a site implementation that wraps -// a reversetunnel.RemoteSite +// a reversetunnelclient.RemoteSite type remoteSite struct { - site reversetunnel.RemoteSite + site reversetunnelclient.RemoteSite } // GetNodes uses the wrapped sites NodeWatcher to filter nodes @@ -472,7 +471,7 @@ func (r *Router) DialSite(ctx context.Context, clusterName string, clientSrcAddr // dial the local auth server if clusterName == r.clusterName { - conn, err := r.localSite.DialAuthServer(reversetunnel.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) + conn, err := r.localSite.DialAuthServer(reversetunnelclient.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) return conn, trace.Wrap(err) } @@ -482,7 +481,7 @@ func (r *Router) DialSite(ctx context.Context, clusterName string, clientSrcAddr return nil, trace.Wrap(err) } - conn, err := site.DialAuthServer(reversetunnel.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) + conn, err := site.DialAuthServer(reversetunnelclient.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index 8ad851f287eee..6af6a006f7f73 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -30,7 +30,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" @@ -298,31 +298,31 @@ func serverResolver(srv types.Server, err error) serverResolverFn { } type tunnel struct { - reversetunnel.Tunnel + reversetunnelclient.Tunnel - site reversetunnel.RemoteSite + site reversetunnelclient.RemoteSite err error } -func (t tunnel) GetSite(cluster string) (reversetunnel.RemoteSite, error) { +func (t tunnel) GetSite(cluster string) (reversetunnelclient.RemoteSite, error) { return t.site, t.err } type testRemoteSite struct { - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite - params reversetunnel.DialParams + params reversetunnelclient.DialParams conn net.Conn err error } -func (r *testRemoteSite) Dial(params reversetunnel.DialParams) (net.Conn, error) { +func (r *testRemoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { r.params = params return r.conn, r.err } -func (r testRemoteSite) DialAuthServer(reversetunnel.DialParams) (net.Conn, error) { +func (r testRemoteSite) DialAuthServer(reversetunnelclient.DialParams) (net.Conn, error) { return r.conn, r.err } @@ -331,10 +331,10 @@ func (r testRemoteSite) GetClient() (auth.ClientI, error) { } type testSiteGetter struct { - site reversetunnel.RemoteSite + site reversetunnelclient.RemoteSite } -func (s testSiteGetter) GetSite(clusterName string) (reversetunnel.RemoteSite, error) { +func (s testSiteGetter) GetSite(clusterName string) (reversetunnelclient.RemoteSite, error) { return s.site, nil } @@ -385,7 +385,7 @@ func TestRouter_DialHost(t *testing.T) { cases := []struct { name string router Router - assertion func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) + assertion func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) }{ { name: "failure resolving node", @@ -395,7 +395,7 @@ func TestRouter_DialHost(t *testing.T) { tracer: tracing.NoopTracer("test"), serverResolver: serverResolver(nil, trace.NotFound(teleport.NodeIsAmbiguous)), }, - assertion: func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) { + assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.Error(t, err) require.Nil(t, conn) }, @@ -408,7 +408,7 @@ func TestRouter_DialHost(t *testing.T) { log: logger, tracer: tracing.NoopTracer("test"), }, - assertion: func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) { + assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.Error(t, err) require.True(t, trace.IsNotFound(err)) require.Nil(t, conn) @@ -423,7 +423,7 @@ func TestRouter_DialHost(t *testing.T) { tracer: tracing.NoopTracer("test"), serverResolver: serverResolver(srv, nil), }, - assertion: func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) { + assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.Error(t, err) require.True(t, trace.IsConnectionProblem(err)) require.Nil(t, conn) @@ -438,7 +438,7 @@ func TestRouter_DialHost(t *testing.T) { tracer: tracing.NoopTracer("test"), serverResolver: serverResolver(srv, nil), }, - assertion: func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) { + assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.NoError(t, err) require.Equal(t, srv, params.TargetServer) require.NotNil(t, params.GetUserAgent) @@ -456,7 +456,7 @@ func TestRouter_DialHost(t *testing.T) { tracer: tracing.NoopTracer("test"), serverResolver: serverResolver(agentlessSrv, nil), }, - assertion: func(t *testing.T, params reversetunnel.DialParams, conn net.Conn, err error) { + assertion: func(t *testing.T, params reversetunnelclient.DialParams, conn net.Conn, err error) { require.NoError(t, err) require.Equal(t, agentlessSrv, params.TargetServer) require.Nil(t, params.GetUserAgent) @@ -472,7 +472,7 @@ func TestRouter_DialHost(t *testing.T) { t.Run(tt.name, func(t *testing.T) { conn, _, err := tt.router.DialHost(ctx, &utils.NetAddr{}, &utils.NetAddr{}, "host", "0", "test", nil, agentGetter, createSigner) - var params reversetunnel.DialParams + var params reversetunnelclient.DialParams if tt.router.localSite != nil { params = tt.router.localSite.(*testRemoteSite).params } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index e1ad1181894b9..9cb33b1b2e511 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -119,7 +119,7 @@ type AgentPoolConfig struct { // either be proxy (trusted clusters) or node (dial back). Component string // ReverseTunnelServer holds all reverse tunnel connections. - ReverseTunnelServer Server + ReverseTunnelServer reversetunnelclient.Server // Resolver retrieves the reverse tunnel address Resolver reversetunnelclient.Resolver // Cluster is a cluster name of the proxy. diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go index 9f4eb48d5ca47..8b08716f8bd85 100644 --- a/lib/reversetunnel/api.go +++ b/lib/reversetunnel/api.go @@ -1,5 +1,5 @@ /* -Copyright 2016 Gravitational, Inc. +Copyright 2020 Gravitational, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,192 +16,8 @@ limitations under the License. package reversetunnel -import ( - "context" - "fmt" - "io" - "net" - "time" +import "github.com/gravitational/teleport/lib/reversetunnelclient" - "golang.org/x/crypto/ssh" - - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/proxy/peer" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/teleagent" -) - -// DialParams is a list of parameters used to Dial to a node within a cluster. -type DialParams struct { - // From is the source address. - From net.Addr - - // To is the destination address. - To net.Addr - - // GetUserAgent gets an SSH agent for use in connecting to the remote host. Used by the - // forwarding proxy. - GetUserAgent teleagent.Getter - - // AgentlessSigner is used for authenticating to the remote host when it is an - // agentless node. - AgentlessSigner ssh.Signer - - // Address is used by the forwarding proxy to generate a host certificate for - // the target node. This is needed because while dialing occurs via IP - // address, tsh thinks it's connecting via DNS name and that's how it - // validates the host certificate. - Address string - - // Principals are additional principals that need to be added to the host - // certificate. Used by the recording proxy to correctly generate a host - // certificate. - Principals []string - - // ServerID the hostUUID.clusterName of a Teleport node. Used with nodes - // that are connected over a reverse tunnel. - ServerID string - - // ProxyIDs is a list of proxy ids the node is connected to. - ProxyIDs []string - - // ConnType is the type of connection requested, either node or application. - // Only used when connecting through a tunnel. - ConnType types.TunnelType - - // TargetServer is the host that the connection is being established for. - // It **MUST** only be populated when the target is a teleport ssh server - // or an agentless server. - TargetServer types.Server - - // FromPeerProxy indicates that the dial request is being tunneled from - // a peer proxy. - FromPeerProxy bool - - // TeleportVersion shows version of the target node, if we know that it's teleport node. - TeleportVersion string - - // OriginalClientDstAddr is used in PROXY headers to show where client originally contacted Teleport infrastructure - OriginalClientDstAddr net.Addr -} - -func (params DialParams) String() string { - to := params.To.String() - if to == "" { - to = params.ServerID - } - return fmt.Sprintf("from: %q to: %q", params.From, to) -} - -func stringOrEmpty(addr net.Addr) string { - if addr == nil { - return "" - } - return addr.String() -} - -// shouldDialAndForward returns whether a connection should be proxied -// and forwarded or not. -func shouldDialAndForward(params DialParams, recConfig types.SessionRecordingConfig) bool { - // connection is already being tunneled, do not forward - if params.FromPeerProxy { - return false - } - // the node is an agentless node, the connection must be forwarded - if params.TargetServer != nil && params.TargetServer.GetSubKind() == types.SubKindOpenSSHNode { - return true - } - // proxy session recording mode is being used and an SSH session - // is being requested, the connection must be forwarded - if params.ConnType == types.NodeTunnel && services.IsRecordAtProxy(recConfig.GetMode()) { - return true - } - return false -} - -// RemoteSite represents remote teleport site that can be accessed via -// teleport tunnel or directly by proxy -// -// There are two implementations of this interface: local and remote sites. -type RemoteSite interface { - // DialAuthServer returns a net.Conn to the Auth Server of a site. - DialAuthServer(DialParams) (conn net.Conn, err error) - // Dial dials any address within the site network, in terminating - // mode it uses local instance of forwarding server to terminate - // and record the connection. - Dial(DialParams) (conn net.Conn, err error) - // DialTCP dials any address within the site network and - // ignores recording mode, used in components that need direct dialer. - DialTCP(DialParams) (conn net.Conn, err error) - // GetLastConnected returns last time the remote site was seen connected - GetLastConnected() time.Time - // GetName returns site name (identified by authority domain's name) - GetName() string - // GetStatus returns status of this site (either offline or connected) - GetStatus() string - // GetClient returns client connected to remote auth server - GetClient() (auth.ClientI, error) - // CachingAccessPoint returns access point that is lightweight - // but is resilient to auth server crashes - CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) - // NodeWatcher returns the node watcher that maintains the node set for the site - NodeWatcher() (*services.NodeWatcher, error) - // GetTunnelsCount returns the amount of active inbound tunnels - // from the remote cluster - GetTunnelsCount() int - // IsClosed reports whether this RemoteSite has been closed and should no - // longer be used. - IsClosed() bool - // Closer allows the site to be closed - io.Closer -} - -// Tunnel provides access to connected local or remote clusters -// using unified interface. -type Tunnel interface { - // GetSites returns a list of connected remote sites - GetSites() ([]RemoteSite, error) - // GetSite returns remote site this node belongs to - GetSite(domainName string) (RemoteSite, error) -} - -// Server is a TCP/IP SSH server which listens on an SSH endpoint and remote/local -// sites connect and register with it. -type Server interface { - Tunnel - // Start starts server - Start() error - // Close closes server's operations immediately - Close() error - // DrainConnections closes listeners and begins draining connections without - // closing open connections. - DrainConnections(context.Context) error - // Shutdown performs graceful server shutdown closing open connections. - Shutdown(context.Context) error - // Wait waits for server to close all outstanding operations - Wait(ctx context.Context) - // GetProxyPeerClient returns the proxy peer client - GetProxyPeerClient() *peer.Client -} - -const ( - // NoApplicationTunnel is the error message returned when application - // reverse tunnel cannot be found. - // - // It usually happens when an app agent has shut down (or crashed) but - // hasn't expired from the backend yet. - NoApplicationTunnel = "could not find reverse tunnel, check that Application Service agent proxying this application is up and running" - // NoDatabaseTunnel is the error message returned when database reverse - // tunnel cannot be found. - // - // It usually happens when a database agent has shut down (or crashed) but - // hasn't expired from the backend yet. - NoDatabaseTunnel = "could not find reverse tunnel, check that Database Service agent proxying this database is up and running" - // NoOktaTunnel is the error message returned when an Okta - // reverse tunnel cannot be found. - // - // It usually happens when an Okta service has shut down (or crashed) but - // hasn't expired from the backend yet. - NoOktaTunnel = "could not find reverse tunnel, check that Okta Service agent proxying this application is up and running" -) +// RemoteSite is an alias used to prevent breaking e. +// TODO(tross): remove once e has been updated to use reversetunnelclient +type RemoteSite = reversetunnelclient.RemoteSite diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 6001a6c057720..80024870d90ef 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -220,7 +220,7 @@ func (s *localSite) GetLastConnected() time.Time { return s.clock.Now() } -func (s *localSite) DialAuthServer(params DialParams) (net.Conn, error) { +func (s *localSite) DialAuthServer(params reversetunnelclient.DialParams) (net.Conn, error) { if len(s.authServers) == 0 { return nil, trace.ConnectionProblem(nil, "no auth servers available") } @@ -238,7 +238,26 @@ func (s *localSite) DialAuthServer(params DialParams) (net.Conn, error) { return conn, nil } -func (s *localSite) Dial(params DialParams) (net.Conn, error) { +// shouldDialAndForward returns whether a connection should be proxied +// and forwarded or not. +func shouldDialAndForward(params reversetunnelclient.DialParams, recConfig types.SessionRecordingConfig) bool { + // connection is already being tunneled, do not forward + if params.FromPeerProxy { + return false + } + // the node is an agentless node, the connection must be forwarded + if params.TargetServer != nil && params.TargetServer.GetSubKind() == types.SubKindOpenSSHNode { + return true + } + // proxy session recording mode is being used and an SSH session + // is being requested, the connection must be forwarded + if params.ConnType == types.NodeTunnel && services.IsRecordAtProxy(recConfig.GetMode()) { + return true + } + return false +} + +func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { recConfig, err := s.accessPoint.GetSessionRecordingConfig(s.srv.Context) if err != nil { return nil, trace.Wrap(err) @@ -263,7 +282,7 @@ func shouldSendSignedPROXYHeader(signer multiplexer.PROXYHeaderSigner, version s dstAddr == nil) } -func (s *localSite) maybeSendSignedPROXYHeader(params DialParams, conn net.Conn, useTunnel, checkVersion bool) error { +func (s *localSite) maybeSendSignedPROXYHeader(params reversetunnelclient.DialParams, conn net.Conn, useTunnel, checkVersion bool) error { if !shouldSendSignedPROXYHeader(s.srv.proxySigner, params.TeleportVersion, useTunnel, checkVersion, params.From, params.OriginalClientDstAddr) { return nil } @@ -281,7 +300,7 @@ func (s *localSite) maybeSendSignedPROXYHeader(params DialParams, conn net.Conn, } // TODO(awly): unit test this -func (s *localSite) DialTCP(params DialParams) (net.Conn, error) { +func (s *localSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, error) { s.log.Debugf("Dialing %v.", params) conn, useTunnel, err := s.getConn(params) @@ -336,7 +355,7 @@ func (s *localSite) adviseReconnect(ctx context.Context) { } } -func (s *localSite) dialAndForward(params DialParams) (_ net.Conn, retErr error) { +func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { if params.GetUserAgent == nil && params.AgentlessSigner == nil { return nil, trace.BadParameter("user agent getter and agentless signer both missing") } @@ -439,7 +458,7 @@ func (s *localSite) dialTunnel(dreq *sshutils.DialReq) (net.Conn, error) { // tryProxyPeering determines whether the node should try to be reached over // a peer proxy. -func (s *localSite) tryProxyPeering(params DialParams) bool { +func (s *localSite) tryProxyPeering(params reversetunnelclient.DialParams) bool { if s.peerClient == nil { return false } @@ -454,7 +473,7 @@ func (s *localSite) tryProxyPeering(params DialParams) bool { } // skipDirectDial determines if a direct dial attempt should be made. -func (s *localSite) skipDirectDial(params DialParams) (bool, error) { +func (s *localSite) skipDirectDial(params reversetunnelclient.DialParams) (bool, error) { // Connections to application and database servers should never occur // over a direct dial. switch params.ConnType { @@ -480,7 +499,7 @@ func (s *localSite) skipDirectDial(params DialParams) (bool, error) { return false, nil } -func getTunnelErrorMessage(params DialParams, connStr string, err error) string { +func getTunnelErrorMessage(params reversetunnelclient.DialParams, connStr string, err error) string { errorMessageTemplate := `Teleport proxy failed to connect to %q agent %q over %s: %v @@ -497,7 +516,14 @@ with the cluster.` return fmt.Sprintf(errorMessageTemplate, params.ConnType, toAddr, connStr, err) } -func (s *localSite) getConn(params DialParams) (conn net.Conn, useTunnel bool, err error) { +func stringOrEmpty(addr net.Addr) string { + if addr == nil { + return "" + } + return addr.String() +} + +func (s *localSite) getConn(params reversetunnelclient.DialParams) (conn net.Conn, useTunnel bool, err error) { dreq := &sshutils.DialReq{ ServerID: params.ServerID, ConnType: params.ConnType, diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index d59117a2977f2..45eb8082b4712 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" ) @@ -127,18 +128,18 @@ func (p *clusterPeers) GetLastConnected() time.Time { return peer.GetLastConnected() } -func (p *clusterPeers) DialAuthServer(DialParams) (net.Conn, error) { +func (p *clusterPeers) DialAuthServer(reversetunnelclient.DialParams) (net.Conn, error) { return nil, trace.ConnectionProblem(nil, "unable to dial to auth server, this proxy has not been discovered yet, try again later") } // Dial is used to connect a requesting client (say, tsh) to an SSH server // located in a remote connected site, the connection goes through the // reverse proxy tunnel. -func (p *clusterPeers) Dial(params DialParams) (conn net.Conn, err error) { +func (p *clusterPeers) Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error) { return p.DialTCP(params) } -func (p *clusterPeers) DialTCP(params DialParams) (conn net.Conn, err error) { +func (p *clusterPeers) DialTCP(params reversetunnelclient.DialParams) (conn net.Conn, err error) { return nil, trace.ConnectionProblem(nil, "unable to dial, this proxy has not been discovered yet, try again later") } @@ -234,7 +235,7 @@ func (s *clusterPeer) GetLastConnected() time.Time { // Dial is used to connect a requesting client (say, tsh) to an SSH server // located in a remote connected site, the connection goes through the // reverse proxy tunnel. -func (s *clusterPeer) Dial(params DialParams) (conn net.Conn, err error) { +func (s *clusterPeer) Dial(params reversetunnelclient.DialParams) (conn net.Conn, err error) { return nil, trace.ConnectionProblem(nil, "unable to dial, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/rc_manager.go b/lib/reversetunnel/rc_manager.go index 0549451570657..a6b22d9d1f1ab 100644 --- a/lib/reversetunnel/rc_manager.go +++ b/lib/reversetunnel/rc_manager.go @@ -72,7 +72,7 @@ type RemoteClusterTunnelManagerConfig struct { LocalCluster string // Local ReverseTunnelServer to reach other cluster members connecting to // this proxy over a tunnel. - ReverseTunnelServer Server + ReverseTunnelServer reversetunnelclient.Server // Clock is a mock-able clock. Clock clockwork.Clock // KubeDialAddr is an optional address of a local kubernetes proxy. diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index c952f12d72f7a..c21b5b06ac734 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/forward" "github.com/gravitational/teleport/lib/teleagent" @@ -67,7 +68,7 @@ type remoteSite struct { certificateCache *certificateCache // localClient provides access to the Auth Server API of the cluster - // within which reversetunnel.Server is running. + // within which reversetunnelclient.Server is running. localClient auth.ClientI // remoteClient provides access to the Auth Server API of the remote cluster that // this site belongs to. @@ -136,7 +137,7 @@ func (s *remoteSite) getRemoteClient() (auth.ClientI, bool, error) { } func (s *remoteSite) authServerContextDialer(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := s.DialAuthServer(DialParams{}) + conn, err := s.DialAuthServer(reversetunnelclient.DialParams{}) return conn, err } @@ -742,7 +743,7 @@ func (s *remoteSite) watchLocks() error { } } -func (s *remoteSite) DialAuthServer(params DialParams) (net.Conn, error) { +func (s *remoteSite) DialAuthServer(params reversetunnelclient.DialParams) (net.Conn, error) { conn, err := s.connThroughTunnel(&sshutils.DialReq{ Address: constants.RemoteAuthServer, ClientSrcAddr: stringOrEmpty(params.From), @@ -754,7 +755,7 @@ func (s *remoteSite) DialAuthServer(params DialParams) (net.Conn, error) { // Dial is used to connect a requesting client (say, tsh) to an SSH server // located in a remote connected site, the connection goes through the // reverse proxy tunnel. -func (s *remoteSite) Dial(params DialParams) (net.Conn, error) { +func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { recConfig, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx) if err != nil { return nil, trace.Wrap(err) @@ -771,7 +772,7 @@ func (s *remoteSite) Dial(params DialParams) (net.Conn, error) { return s.DialTCP(params) } -func (s *remoteSite) DialTCP(params DialParams) (net.Conn, error) { +func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, error) { s.logger.Debugf("Dialing from %v to %v.", params.From, params.To) conn, err := s.connThroughTunnel(&sshutils.DialReq{ @@ -789,7 +790,7 @@ func (s *remoteSite) DialTCP(params DialParams) (net.Conn, error) { return conn, nil } -func (s *remoteSite) dialAndForward(params DialParams) (_ net.Conn, retErr error) { +func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { if params.GetUserAgent == nil && params.AgentlessSigner == nil { return nil, trace.BadParameter("user agent getter and agentless signer both missing") } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 24ee20083c5b3..8a588ca129e5d 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/proxy/peer" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/sshca" @@ -281,7 +282,7 @@ func (cfg *Config) CheckAndSetDefaults() error { // NewServer creates and returns a reverse tunnel server which is fully // initialized but hasn't been started yet -func NewServer(cfg Config) (Server, error) { +func NewServer(cfg Config) (reversetunnelclient.Server, error) { err := metrics.RegisterPrometheusCollectors(prometheusCollectors...) if err != nil { return nil, trace.Wrap(err) @@ -965,10 +966,10 @@ func (s *server) upsertRemoteCluster(conn net.Conn, sshConn *ssh.ServerConn) (*r return site, remoteConn, nil } -func (s *server) GetSites() ([]RemoteSite, error) { +func (s *server) GetSites() ([]reversetunnelclient.RemoteSite, error) { s.RLock() defer s.RUnlock() - out := make([]RemoteSite, 0, len(s.remoteSites)+len(s.clusterPeers)+1) + out := make([]reversetunnelclient.RemoteSite, 0, len(s.remoteSites)+len(s.clusterPeers)+1) out = append(out, s.localSite) haveLocalConnection := make(map[string]bool) @@ -1002,7 +1003,7 @@ func (s *server) getRemoteClusters() []*remoteSite { // with a cluster peer your best bet is to wait until the agent has discovered // all proxies behind a load balancer. Note, the cluster peer is a // services.TunnelConnection that was created by another proxy. -func (s *server) GetSite(name string) (RemoteSite, error) { +func (s *server) GetSite(name string) (reversetunnelclient.RemoteSite, error) { s.RLock() defer s.RUnlock() if s.localSite.GetName() == name { @@ -1029,7 +1030,7 @@ func (s *server) GetProxyPeerClient() *peer.Client { // alwaysClose forces onSiteTunnelClose to remove and close // the site by always returning false from HasValidConnections. type alwaysClose struct { - RemoteSite + reversetunnelclient.RemoteSite } func (a *alwaysClose) HasValidConnections() bool { @@ -1122,7 +1123,7 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } // configure access to the full Auth Server API and the cached subset for - // the local cluster within which reversetunnel.Server is running. + // the local cluster within which reversetunnelclient.Server is running. remoteSite.localClient = srv.localAuthClient remoteSite.localAccessPoint = srv.localAccessPoint diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 9e3b2117e6baf..76214cd3d35eb 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -76,7 +76,7 @@ type transport struct { sconn sshutils.Conn // reverseTunnelServer holds all reverse tunnel connections. - reverseTunnelServer Server + reverseTunnelServer reversetunnelclient.Server // server is either an SSH or application server. It can handle a connection // (perform handshake and handle request). @@ -377,11 +377,11 @@ func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, e // a direct dial, return right away. switch r.ConnType { case types.AppTunnel: - return nil, false, trace.ConnectionProblem(err, NoApplicationTunnel) + return nil, false, trace.ConnectionProblem(err, reversetunnelclient.NoApplicationTunnel) case types.OktaTunnel: - return nil, false, trace.ConnectionProblem(err, NoOktaTunnel) + return nil, false, trace.ConnectionProblem(err, reversetunnelclient.NoOktaTunnel) case types.DatabaseTunnel: - return nil, false, trace.ConnectionProblem(err, NoDatabaseTunnel) + return nil, false, trace.ConnectionProblem(err, reversetunnelclient.NoDatabaseTunnel) } errTun := err diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go new file mode 100644 index 0000000000000..3a6f9ab9df255 --- /dev/null +++ b/lib/reversetunnelclient/api.go @@ -0,0 +1,181 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package reversetunnelclient + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/proxy/peer" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/teleagent" +) + +// DialParams is a list of parameters used to Dial to a node within a cluster. +type DialParams struct { + // From is the source address. + From net.Addr + + // To is the destination address. + To net.Addr + + // GetUserAgent gets an SSH agent for use in connecting to the remote host. Used by the + // forwarding proxy. + GetUserAgent teleagent.Getter + + // AgentlessSigner is used for authenticating to the remote host when it is an + // agentless node. + AgentlessSigner ssh.Signer + + // Address is used by the forwarding proxy to generate a host certificate for + // the target node. This is needed because while dialing occurs via IP + // address, tsh thinks it's connecting via DNS name and that's how it + // validates the host certificate. + Address string + + // Principals are additional principals that need to be added to the host + // certificate. Used by the recording proxy to correctly generate a host + // certificate. + Principals []string + + // ServerID the hostUUID.clusterName of a Teleport node. Used with nodes + // that are connected over a reverse tunnel. + ServerID string + + // ProxyIDs is a list of proxy ids the node is connected to. + ProxyIDs []string + + // ConnType is the type of connection requested, either node or application. + // Only used when connecting through a tunnel. + ConnType types.TunnelType + + // TargetServer is the host that the connection is being established for. + // It **MUST** only be populated when the target is a teleport ssh server + // or an agentless server. + TargetServer types.Server + + // FromPeerProxy indicates that the dial request is being tunneled from + // a peer proxy. + FromPeerProxy bool + + // TeleportVersion shows version of the target node, if we know that it's teleport node. + TeleportVersion string + + // OriginalClientDstAddr is used in PROXY headers to show where client originally contacted Teleport infrastructure + OriginalClientDstAddr net.Addr +} + +func (params DialParams) String() string { + to := params.To.String() + if to == "" { + to = params.ServerID + } + return fmt.Sprintf("from: %q to: %q", params.From, to) +} + +// RemoteSite represents remote teleport site that can be accessed via +// teleport tunnel or directly by proxy +// +// There are two implementations of this interface: local and remote sites. +type RemoteSite interface { + // DialAuthServer returns a net.Conn to the Auth Server of a site. + DialAuthServer(DialParams) (conn net.Conn, err error) + // Dial dials any address within the site network, in terminating + // mode it uses local instance of forwarding server to terminate + // and record the connection. + Dial(DialParams) (conn net.Conn, err error) + // DialTCP dials any address within the site network and + // ignores recording mode, used in components that need direct dialer. + DialTCP(DialParams) (conn net.Conn, err error) + // GetLastConnected returns last time the remote site was seen connected + GetLastConnected() time.Time + // GetName returns site name (identified by authority domain's name) + GetName() string + // GetStatus returns status of this site (either offline or connected) + GetStatus() string + // GetClient returns client connected to remote auth server + GetClient() (auth.ClientI, error) + // CachingAccessPoint returns access point that is lightweight + // but is resilient to auth server crashes + CachingAccessPoint() (auth.RemoteProxyAccessPoint, error) + // NodeWatcher returns the node watcher that maintains the node set for the site + NodeWatcher() (*services.NodeWatcher, error) + // GetTunnelsCount returns the amount of active inbound tunnels + // from the remote cluster + GetTunnelsCount() int + // IsClosed reports whether this RemoteSite has been closed and should no + // longer be used. + IsClosed() bool + // Closer allows the site to be closed + io.Closer +} + +// Tunnel provides access to connected local or remote clusters +// using unified interface. +type Tunnel interface { + // GetSites returns a list of connected remote sites + GetSites() ([]RemoteSite, error) + // GetSite returns remote site this node belongs to + GetSite(domainName string) (RemoteSite, error) +} + +// Server is a TCP/IP SSH server which listens on an SSH endpoint and remote/local +// sites connect and register with it. +type Server interface { + Tunnel + // Start starts server + Start() error + // Close closes server's operations immediately + Close() error + // DrainConnections closes listeners and begins draining connections without + // closing open connections. + DrainConnections(context.Context) error + // Shutdown performs graceful server shutdown closing open connections. + Shutdown(context.Context) error + // Wait waits for server to close all outstanding operations + Wait(ctx context.Context) + // GetProxyPeerClient returns the proxy peer client + GetProxyPeerClient() *peer.Client +} + +const ( + // NoApplicationTunnel is the error message returned when application + // reverse tunnel cannot be found. + // + // It usually happens when an app agent has shut down (or crashed) but + // hasn't expired from the backend yet. + NoApplicationTunnel = "could not find reverse tunnel, check that Application Service agent proxying this application is up and running" + // NoDatabaseTunnel is the error message returned when database reverse + // tunnel cannot be found. + // + // It usually happens when a database agent has shut down (or crashed) but + // hasn't expired from the backend yet. + NoDatabaseTunnel = "could not find reverse tunnel, check that Database Service agent proxying this database is up and running" + // NoOktaTunnel is the error message returned when an Okta + // reverse tunnel cannot be found. + // + // It usually happens when an Okta service has shut down (or crashed) but + // hasn't expired from the backend yet. + NoOktaTunnel = "could not find reverse tunnel, check that Okta Service agent proxying this application is up and running" +) diff --git a/lib/reversetunnel/api_with_roles.go b/lib/reversetunnelclient/api_with_roles.go similarity index 88% rename from lib/reversetunnel/api_with_roles.go rename to lib/reversetunnelclient/api_with_roles.go index 598df413f7808..b2bd2f4ce2b95 100644 --- a/lib/reversetunnel/api_with_roles.go +++ b/lib/reversetunnelclient/api_with_roles.go @@ -1,5 +1,5 @@ /* -Copyright 2020 Gravitational, Inc. +Copyright 2023 Gravitational, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package reversetunnel +package reversetunnelclient import ( "github.com/gravitational/trace" @@ -32,9 +32,10 @@ type ClusterGetter interface { } // NewTunnelWithRoles returns new authorizing tunnel -func NewTunnelWithRoles(tunnel Tunnel, accessChecker services.AccessChecker, access ClusterGetter) *TunnelWithRoles { +func NewTunnelWithRoles(tunnel Tunnel, localCluster string, accessChecker services.AccessChecker, access ClusterGetter) *TunnelWithRoles { return &TunnelWithRoles{ tunnel: tunnel, + localCluster: localCluster, accessChecker: accessChecker, access: access, } @@ -44,6 +45,8 @@ func NewTunnelWithRoles(tunnel Tunnel, accessChecker services.AccessChecker, acc type TunnelWithRoles struct { tunnel Tunnel + localCluster string + // accessChecker is used to check RBAC permissions. accessChecker services.AccessChecker @@ -58,7 +61,7 @@ func (t *TunnelWithRoles) GetSites() ([]RemoteSite, error) { } out := make([]RemoteSite, 0, len(clusters)) for _, cluster := range clusters { - if _, ok := cluster.(*localSite); ok { + if t.localCluster == cluster.GetName() { out = append(out, cluster) continue } @@ -87,7 +90,7 @@ func (t *TunnelWithRoles) GetSite(clusterName string) (RemoteSite, error) { if err != nil { return nil, trace.Wrap(err) } - if _, ok := cluster.(*localSite); ok { + if t.localCluster == cluster.GetName() { return cluster, nil } rc, err := t.access.GetRemoteCluster(clusterName) diff --git a/lib/reversetunnel/fake.go b/lib/reversetunnelclient/fake.go similarity index 94% rename from lib/reversetunnel/fake.go rename to lib/reversetunnelclient/fake.go index ea9439115ba19..6cc39242b89f1 100644 --- a/lib/reversetunnel/fake.go +++ b/lib/reversetunnelclient/fake.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package reversetunnel +package reversetunnelclient import ( "net" @@ -26,7 +26,7 @@ import ( "github.com/gravitational/teleport/lib/auth" ) -// FakeServer is a fake reversetunnel.Server implementation used in tests. +// FakeServer is a fake Server implementation used in tests. type FakeServer struct { Server // Sites is a list of sites registered via this fake reverse tunnel. @@ -48,7 +48,7 @@ func (s *FakeServer) GetSite(name string) (RemoteSite, error) { return nil, trace.NotFound("site %q not found", name) } -// FakeRemoteSite is a fake reversetunnel.RemoteSite implementation used in tests. +// FakeRemoteSite is a fake reversetunnelclient.RemoteSite implementation used in tests. type FakeRemoteSite struct { RemoteSite // Name is the remote site name. diff --git a/lib/service/acme.go b/lib/service/acme.go index 3651e905d8b44..f2b3febab6c3c 100644 --- a/lib/service/acme.go +++ b/lib/service/acme.go @@ -24,7 +24,7 @@ import ( "github.com/gravitational/trace" "golang.org/x/exp/slices" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" ) @@ -35,7 +35,7 @@ type hostPolicyCheckerConfig struct { // clt is used to get the list of registered applications clt app.Getter // tun is a reverse tunnel - tun reversetunnel.Tunnel + tun reversetunnelclient.Tunnel // clusterName is a name of this cluster clusterName string } diff --git a/lib/service/service.go b/lib/service/service.go index 80dec5d1ae51f..1c287699935f0 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3666,7 +3666,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { // register SSH reverse tunnel server that accepts connections // from remote teleport nodes - var tsrv reversetunnel.Server + var tsrv reversetunnelclient.Server var peerClient *peer.Client if !process.Config.Proxy.DisableReverseTunnel { @@ -3843,8 +3843,12 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { ProxyKubeAddr: proxyKubeAddr, TraceClient: traceClt, Router: proxyRouter, - SessionControl: sessionController, - PROXYSigner: proxySigner, + SessionControl: web.SessionControllerFunc(func(ctx context.Context, sctx *web.SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + controller := srv.WebSessionController(sessionController) + ctx, err := controller(ctx, sctx, login, localAddr, remoteAddr) + return ctx, trace.Wrap(err) + }), + PROXYSigner: proxySigner, } webHandler, err := web.NewHandler(webConfig) if err != nil { @@ -4599,7 +4603,7 @@ func kubeDialAddr(config servicecfg.ProxyConfig, mode types.ProxyListenerMode) u return config.Kube.ListenAddr } -func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv reversetunnel.Server, accessPoint auth.ReadProxyAccessPoint, clusterName string) (*tls.Config, error) { +func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv reversetunnelclient.Server, accessPoint auth.ReadProxyAccessPoint, clusterName string) (*tls.Config, error) { cfg := process.Config var tlsConfig *tls.Config acmeCfg := process.Config.Proxy.ACME diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 3b31487f13848..d0aca3ad6ea66 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -56,7 +56,6 @@ import ( "github.com/gravitational/teleport/lib/events/athena" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -617,7 +616,7 @@ type mockAccessPoint struct { } type mockReverseTunnelServer struct { - reversetunnel.Server + reversetunnelclient.Server } func TestSetupProxyTLSConfig(t *testing.T) { diff --git a/lib/srv/alpnproxy/auth/auth_proxy.go b/lib/srv/alpnproxy/auth/auth_proxy.go index 5a3cdefde0008..1a6c7af95b4d8 100644 --- a/lib/srv/alpnproxy/auth/auth_proxy.go +++ b/lib/srv/alpnproxy/auth/auth_proxy.go @@ -31,14 +31,14 @@ import ( "github.com/gravitational/teleport/api/defaults" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/utils" ) type sitesGetter interface { - GetSites() ([]reversetunnel.RemoteSite, error) + GetSites() ([]reversetunnelclient.RemoteSite, error) } // NewAuthProxyDialerService create new instance of AuthProxyDialerService. @@ -169,7 +169,7 @@ func (s *AuthProxyDialerService) dialRemoteAuthServer(ctx context.Context, clust if site.GetName() != clusterName { continue } - conn, err := site.DialAuthServer(reversetunnel.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) + conn, err := site.DialAuthServer(reversetunnelclient.DialParams{From: clientSrcAddr, OriginalClientDstAddr: clientDstAddr}) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index b0220b825ebb4..b564c847cf40f 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -61,7 +61,7 @@ import ( "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/alpnproxy" @@ -1254,7 +1254,7 @@ type testContext struct { mux *multiplexer.Mux mysqlListener net.Listener webListener *multiplexer.WebListener - fakeRemoteSite *reversetunnel.FakeRemoteSite + fakeRemoteSite *reversetunnelclient.FakeRemoteSite server *Server emitter *eventstest.ChannelEmitter databaseCA types.CertAuthority @@ -1990,10 +1990,10 @@ func setupTestContext(ctx context.Context, t *testing.T, withDatabases ...withDa } // Establish fake reversetunnel b/w database proxy and database service. - testCtx.fakeRemoteSite = reversetunnel.NewFakeRemoteSite(testCtx.clusterName, proxyAuthClient) + testCtx.fakeRemoteSite = reversetunnelclient.NewFakeRemoteSite(testCtx.clusterName, proxyAuthClient) t.Cleanup(func() { require.NoError(t, testCtx.fakeRemoteSite.Close()) }) - tunnel := &reversetunnel.FakeServer{ - Sites: []reversetunnel.RemoteSite{ + tunnel := &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{ testCtx.fakeRemoteSite, }, } diff --git a/lib/srv/db/common/interfaces.go b/lib/srv/db/common/interfaces.go index ea5d81fd66af4..aed2363093d62 100644 --- a/lib/srv/db/common/interfaces.go +++ b/lib/srv/db/common/interfaces.go @@ -22,7 +22,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -59,7 +59,7 @@ type ProxyContext struct { // Identity is the authorized client Identity. Identity tlsca.Identity // Cluster is the remote Cluster running the database server. - Cluster reversetunnel.RemoteSite + Cluster reversetunnelclient.RemoteSite // Servers is a list of database Servers that proxy the requested database. Servers []types.DatabaseServer // AuthContext is a context of authenticated user. diff --git a/lib/srv/db/proxyserver.go b/lib/srv/db/proxyserver.go index 4dcfa45c37e6c..d3f4763eff3f3 100644 --- a/lib/srv/db/proxyserver.go +++ b/lib/srv/db/proxyserver.go @@ -41,7 +41,6 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/enterprise" @@ -83,7 +82,7 @@ type ProxyServerConfig struct { // Authorizer is responsible for authorizing user identities. Authorizer authz.Authorizer // Tunnel is the reverse tunnel server. - Tunnel reversetunnel.Server + Tunnel reversetunnelclient.Server // TLSConfig is the proxy server TLS configuration. TLSConfig *tls.Config // Limiter is the connection/rate limiter. @@ -446,7 +445,7 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext return nil, trace.Wrap(err) } - serviceConn, err := proxyCtx.Cluster.Dial(reversetunnel.DialParams{ + serviceConn, err := proxyCtx.Cluster.Dial(reversetunnelclient.DialParams{ From: clientSrcAddr, To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode}, OriginalClientDstAddr: clientDstAddr, @@ -475,7 +474,7 @@ func (s *ProxyServer) Connect(ctx context.Context, proxyCtx *common.ProxyContext // the reverse tunnel connection is down e.g. because the agent is down. func isReverseTunnelDownError(err error) bool { return trace.IsConnectionProblem(err) || - strings.Contains(err.Error(), reversetunnel.NoDatabaseTunnel) + strings.Contains(err.Error(), reversetunnelclient.NoDatabaseTunnel) } // Proxy starts proxying all traffic received from database client between @@ -531,7 +530,7 @@ func (s *ProxyServer) Authorize(ctx context.Context, tlsConn utils.TLSConn, para // getDatabaseServers finds database servers that proxy the database instance // encoded in the provided identity. -func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnel.RemoteSite, []types.DatabaseServer, error) { +func (s *ProxyServer) getDatabaseServers(ctx context.Context, identity tlsca.Identity) (reversetunnelclient.RemoteSite, []types.DatabaseServer, error) { cluster, err := s.cfg.Tunnel.GetSite(identity.RouteToCluster) if err != nil { return nil, nil, trace.Wrap(err) diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index eae50d6a151d5..9c71af4c4e39a 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -177,8 +177,12 @@ func newProxySubsys(ctx *srv.ServerContext, srv *Server, req proxySubsysRequest) req.clusterName = ctx.Identity.RouteToCluster } if req.clusterName != "" && srv.proxyTun != nil { - _, err := srv.tunnelWithAccessChecker(ctx).GetSite(req.clusterName) + checker, err := srv.tunnelWithAccessChecker(ctx) if err != nil { + return nil, trace.Wrap(err) + } + + if _, err := checker.GetSite(req.clusterName); err != nil { return nil, trace.BadParameter("invalid format for proxy request: unknown cluster %q", req.clusterName) } } diff --git a/lib/srv/regular/sites.go b/lib/srv/regular/sites.go index d4ad0cbfff110..04e3800e89cb5 100644 --- a/lib/srv/regular/sites.go +++ b/lib/srv/regular/sites.go @@ -52,7 +52,12 @@ func (t *proxySitesSubsys) Wait() error { // service.Site structures, and writes it serialized as JSON back to the SSH client func (t *proxySitesSubsys) Start(ctx context.Context, sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Request, serverContext *srv.ServerContext) error { log.Debugf("proxysites.start(%v)", serverContext) - remoteSites, err := t.srv.tunnelWithAccessChecker(serverContext).GetSites() + checker, err := t.srv.tunnelWithAccessChecker(serverContext) + if err != nil { + return trace.Wrap(err) + } + + remoteSites, err := checker.GetSites() if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 6c4f90b249813..982aecbbf7448 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -56,6 +56,7 @@ import ( "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" @@ -101,7 +102,7 @@ type Server struct { cloudLabels labels.Importer proxyMode bool - proxyTun reversetunnel.Tunnel + proxyTun reversetunnelclient.Tunnel proxyAccessPoint auth.ReadProxyAccessPoint peerAddr string @@ -438,7 +439,7 @@ func SetShell(shell string) ServerOption { } // SetProxyMode starts this server in SSH proxying mode -func SetProxyMode(peerAddr string, tsrv reversetunnel.Tunnel, ap auth.ReadProxyAccessPoint, router *proxy.Router) ServerOption { +func SetProxyMode(peerAddr string, tsrv reversetunnelclient.Tunnel, ap auth.ReadProxyAccessPoint, router *proxy.Router) ServerOption { return func(s *Server) error { // always set proxy mode to true, // because in some tests reverse tunnel is disabled, @@ -903,8 +904,13 @@ func (s *Server) getNamespace() string { return types.ProcessNamespace(s.namespace) } -func (s *Server) tunnelWithAccessChecker(ctx *srv.ServerContext) reversetunnel.Tunnel { - return reversetunnel.NewTunnelWithRoles(s.proxyTun, ctx.Identity.AccessChecker, s.proxyAccessPoint) +func (s *Server) tunnelWithAccessChecker(ctx *srv.ServerContext) (reversetunnelclient.Tunnel, error) { + clusterName, err := s.GetAccessPoint().GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + + return reversetunnelclient.NewTunnelWithRoles(s.proxyTun, clusterName.GetClusterName(), ctx.Identity.AccessChecker, s.proxyAccessPoint), nil } // Context returns server shutdown context diff --git a/lib/srv/session_control.go b/lib/srv/session_control.go index 067fee75a74dc..4e38190d94653 100644 --- a/lib/srv/session_control.go +++ b/lib/srv/session_control.go @@ -24,6 +24,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/constants" @@ -140,6 +141,53 @@ func NewSessionController(cfg SessionControllerConfig) (*SessionController, erro return &SessionController{cfg: cfg}, nil } +// WebSessionContext contains information associated with a session +// established via the web ui. +type WebSessionContext interface { + GetUserAccessChecker() (services.AccessChecker, error) + GetSSHCertificate() (*ssh.Certificate, error) + GetUser() string +} + +// WebSessionController is a wrapper around [SessionController] which can be +// used to create an [IdentityContext] and apply session controls for a web session. +// This allows `lib/web` to not depend on `lib/srv`. +func WebSessionController(controller *SessionController) func(ctx context.Context, sctx WebSessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + return func(ctx context.Context, sctx WebSessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + accessChecker, err := sctx.GetUserAccessChecker() + if err != nil { + return ctx, trace.Wrap(err) + } + + sshCert, err := sctx.GetSSHCertificate() + if err != nil { + return ctx, trace.Wrap(err) + } + + unmappedRoles, err := services.ExtractRolesFromCert(sshCert) + if err != nil { + return ctx, trace.Wrap(err) + } + + accessRequestIDs, err := ParseAccessRequestIDs(sshCert.Extensions[teleport.CertExtensionTeleportActiveRequests]) + if err != nil { + return ctx, trace.Wrap(err) + } + + identity := IdentityContext{ + AccessChecker: accessChecker, + TeleportUser: sctx.GetUser(), + Login: login, + Certificate: sshCert, + UnmappedRoles: unmappedRoles, + ActiveRequests: accessRequestIDs, + Impersonator: sshCert.Extensions[teleport.CertExtensionImpersonator], + } + ctx, err = controller.AcquireSessionContext(ctx, identity, localAddr, remoteAddr) + return ctx, trace.Wrap(err) + } +} + // AcquireSessionContext attempts to create a context for the session. If the session is // not allowed due to session control an error is returned. The returned // context is scoped to the session and will be canceled in the event the semaphore lock diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 7dff6a24d48d0..6c6ace76a978a 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -80,11 +80,10 @@ import ( "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/proxy" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" websession "github.com/gravitational/teleport/lib/web/session" @@ -173,7 +172,7 @@ type Config struct { PluginRegistry plugin.Registry // Proxy is a reverse tunnel proxy that handles connections // to local cluster or remote clusters using unified interface - Proxy reversetunnel.Tunnel + Proxy reversetunnelclient.Tunnel // AuthServers is a list of auth servers this proxy talks to AuthServers utils.NetAddr // DomainName is a domain name served by web handler @@ -243,7 +242,7 @@ type Config struct { // SessionControl is used to determine if users are // allowed to spawn new sessions - SessionControl *srv.SessionController + SessionControl SessionController // PROXYSigner is used to sign PROXY header and securely propagate client IP information PROXYSigner multiplexer.PROXYHeaderSigner @@ -876,7 +875,7 @@ func (h *Handler) handleGetUserOrResetToken(w http.ResponseWriter, r *http.Reque // getUserContext returns user context // // GET /webapi/sites/:site/context -func (h *Handler) getUserContext(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) getUserContext(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext, site reversetunnelclient.RemoteSite) (any, error) { cn, err := h.cfg.AccessPoint.GetClusterName() if err != nil { return nil, trace.Wrap(err) @@ -2417,7 +2416,7 @@ type getSiteNamespacesResponse struct { // Successful response: // // {"namespaces": [{..namespace resource...}]} -func (h *Handler) getSiteNamespaces(w http.ResponseWriter, r *http.Request, _ httprouter.Params, c *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) getSiteNamespaces(w http.ResponseWriter, r *http.Request, _ httprouter.Params, c *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := site.GetClient() if err != nil { return nil, trace.Wrap(err) @@ -2432,7 +2431,7 @@ func (h *Handler) getSiteNamespaces(w http.ResponseWriter, r *http.Request, _ ht } // clusterNodesGet returns a list of nodes for a given cluster site. -func (h *Handler) clusterNodesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterNodesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { // Get a client to the Auth Server with the logged in user's identity. The // identity of the logged in user is used to fetch the list of nodes. clt, err := sctx.GetUserClient(r.Context(), site) @@ -2472,7 +2471,7 @@ type getLoginAlertsResponse struct { } // clusterLoginAlertsGet returns a list of on-login alerts for the user. -func (h *Handler) clusterLoginAlertsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterLoginAlertsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { // Get a client to the Auth Server with the logged in user's identity. The // identity of the logged in user is used to fetch the list of alerts. clt, err := sctx.GetUserClient(r.Context(), site) @@ -2494,46 +2493,12 @@ func (h *Handler) clusterLoginAlertsGet(w http.ResponseWriter, r *http.Request, }, nil } -// createIdentityContext creates a srv.IdentityContext from the ssh cert of the user -// stored within the SessionContext. -func createIdentityContext(login string, sessionCtx *SessionContext) (srv.IdentityContext, error) { - accessChecker, err := sessionCtx.GetUserAccessChecker() - if err != nil { - return srv.IdentityContext{}, trace.Wrap(err) - } - - sshCert, err := sessionCtx.GetSSHCertificate() - if err != nil { - return srv.IdentityContext{}, trace.Wrap(err) - } - - unmappedRoles, err := services.ExtractRolesFromCert(sshCert) - if err != nil { - return srv.IdentityContext{}, trace.Wrap(err) - } - - accessRequestIDs, err := srv.ParseAccessRequestIDs(sshCert.Extensions[teleport.CertExtensionTeleportActiveRequests]) - if err != nil { - return srv.IdentityContext{}, trace.Wrap(err) - } - - return srv.IdentityContext{ - AccessChecker: accessChecker, - TeleportUser: sessionCtx.GetUser(), - Login: login, - Certificate: sshCert, - UnmappedRoles: unmappedRoles, - ActiveRequests: accessRequestIDs, - Impersonator: sshCert.Extensions[teleport.CertExtensionImpersonator], - }, nil -} - func (h *Handler) getClusterLocks( w http.ResponseWriter, r *http.Request, p httprouter.Params, sessionCtx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { ctx := r.Context() clt, err := sessionCtx.GetUserClient(ctx, site) @@ -2560,7 +2525,7 @@ func (h *Handler) createClusterLock( r *http.Request, p httprouter.Params, sessionCtx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { var req *createLockReq if err := httplib.ReadJSON(r, &req); err != nil { @@ -2608,7 +2573,7 @@ func (h *Handler) deleteClusterLock( r *http.Request, p httprouter.Params, sessionCtx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { ctx := r.Context() clt, err := sessionCtx.GetUserClient(ctx, site) @@ -2624,6 +2589,28 @@ func (h *Handler) deleteClusterLock( return OK(), nil } +// SessionController restricts creation of sessions based on +// cluster session control configuration(locks, connection limits, etc). +type SessionController interface { + // AcquireSessionContext attempts to create a context for the session. If the session is + // not allowed due to session control an error is returned. The returned + // context is scoped to the session and will be canceled in the event that session + // controls terminate the session early. + AcquireSessionContext(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) +} + +// SessionControllerFunc type is an adapter to allow the use of +// ordinary functions a [SessionController]. If f is a function +// with the appropriate signature, SessionControllerFunc(f) is a +// SessionController that calls f. +type SessionControllerFunc func(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) + +// AcquireSessionContext calls f(ctx, sctx, localAddr, remoteAddr). +func (f SessionControllerFunc) AcquireSessionContext(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + ctx, err := f(ctx, sctx, login, localAddr, remoteAddr) + return ctx, trace.Wrap(err) +} + // siteNodeConnect connect to the site node // // GET /v1/webapi/sites/:site/namespaces/:namespace/connect?access_token=bearer_token¶ms= @@ -2639,7 +2626,7 @@ func (h *Handler) siteNodeConnect( r *http.Request, p httprouter.Params, sessionCtx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -2656,12 +2643,7 @@ func (h *Handler) siteNodeConnect( return nil, trace.Wrap(err) } - identity, err := createIdentityContext(req.Login, sessionCtx) - if err != nil { - return nil, trace.Wrap(err) - } - - ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), identity, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) + ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), sessionCtx, req.Login, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) if err != nil { return nil, trace.Wrap(err) } @@ -3007,7 +2989,7 @@ func trackerToLegacySession(tracker types.SessionTracker, clusterName string) se // clusterActiveAndPendingSessionsGet gets the list of active and pending sessions for a site. // // GET /v1/webapi/sites/:site/sessions -func (h *Handler) clusterActiveAndPendingSessionsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterActiveAndPendingSessionsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -3073,7 +3055,7 @@ func toFieldsSlice(rawEvents []apievents.AuditEvent) ([]events.EventFields, erro // "order": optional ordering of events. Can be either "asc" or "desc" // for ascending and descending respectively. // If no order is provided it defaults to descending. -func (h *Handler) clusterSearchEvents(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterSearchEvents(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { values := r.URL.Query() var eventTypes []string @@ -3108,7 +3090,7 @@ func (h *Handler) clusterSearchEvents(w http.ResponseWriter, r *http.Request, p // "order": optional ordering of events. Can be either "asc" or "desc" // for ascending and descending respectively. // If no order is provided it defaults to descending. -func (h *Handler) clusterSearchSessionEvents(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterSearchSessionEvents(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { searchSessionEvents := func(clt auth.ClientI, from, to time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { return clt.SearchSessionEvents(r.Context(), events.SearchSessionEventsRequest{ From: from, @@ -3123,7 +3105,7 @@ func (h *Handler) clusterSearchSessionEvents(w http.ResponseWriter, r *http.Requ // clusterEventsList returns a list of audit events obtained using the provided // searchEvents method. -func clusterEventsList(ctx context.Context, sctx *SessionContext, site reversetunnel.RemoteSite, values url.Values, searchEvents func(clt auth.ClientI, from, to time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error)) (interface{}, error) { +func clusterEventsList(ctx context.Context, sctx *SessionContext, site reversetunnelclient.RemoteSite, values url.Values, searchEvents func(clt auth.ClientI, from, to time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error)) (interface{}, error) { from, err := queryTime(values, "from", time.Now().UTC().AddDate(0, -1, 0)) if err != nil { return nil, trace.Wrap(err) @@ -3326,7 +3308,7 @@ type eventsListGetResponse struct { // Response body (each event is an arbitrary JSON structure) // // {"events": [{...}, {...}, ...} -func (h *Handler) siteSessionEventsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) siteSessionEventsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { sessionID, err := session.ParseID(p.ByName("sid")) if err != nil { return nil, trace.BadParameter("invalid session ID %q", p.ByName("sid")) @@ -3505,7 +3487,7 @@ const currentSiteShortcut = "-current-" type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext) (interface{}, error) // ClusterHandler is a authenticated handler that is called for some existing remote cluster -type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) +type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster @@ -3525,7 +3507,7 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. -func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnel.RemoteSite, error) { +func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) if err != nil { return nil, nil, trace.Wrap(err) @@ -3541,7 +3523,7 @@ func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http. // getSiteByParams gets the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. -func (h *Handler) getSiteByParams(sctx *SessionContext, p httprouter.Params) (reversetunnel.RemoteSite, error) { +func (h *Handler) getSiteByParams(sctx *SessionContext, p httprouter.Params) (reversetunnelclient.RemoteSite, error) { clusterName := p.ByName("site") if clusterName == currentSiteShortcut { res, err := h.cfg.ProxyClient.GetClusterName() @@ -3560,7 +3542,7 @@ func (h *Handler) getSiteByParams(sctx *SessionContext, p httprouter.Params) (re return site, nil } -func (h *Handler) getSiteByClusterName(ctx *SessionContext, clusterName string) (reversetunnel.RemoteSite, error) { +func (h *Handler) getSiteByClusterName(ctx *SessionContext, clusterName string) (reversetunnelclient.RemoteSite, error) { proxy, err := h.ProxyWithRoles(ctx) if err != nil { h.log.WithError(err).Warn("Failed to get proxy with roles.") @@ -3626,7 +3608,7 @@ func (h *Handler) WithClusterClientProvider(fn ClusterClientHandler) httprouter. } // ProvisionTokenHandler is a authenticated handler that is called for some existing Token -type ProvisionTokenHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, site reversetunnel.RemoteSite, token types.ProvisionToken) (interface{}, error) +type ProvisionTokenHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, site reversetunnelclient.RemoteSite, token types.ProvisionToken) (interface{}, error) // WithProvisionTokenAuth ensures that request is authenticated with a provision token. // Provision tokens, when used like this are invalidated as soon as used. @@ -3854,13 +3836,18 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions // of the given user. -func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnel.Tunnel, error) { +func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) { accessChecker, err := ctx.GetUserAccessChecker() if err != nil { h.log.WithError(err).Warn("Failed to get client roles.") return nil, trace.Wrap(err) } - return reversetunnel.NewTunnelWithRoles(h.cfg.Proxy, accessChecker, h.cfg.AccessPoint), nil + + cn, err := h.cfg.AccessPoint.GetClusterName() + if err != nil { + return nil, trace.Wrap(err) + } + return reversetunnelclient.NewTunnelWithRoles(h.cfg.Proxy, cn.GetClusterName(), accessChecker, h.cfg.AccessPoint), nil } // ProxyHostPort returns the address of the proxy server using --proxy diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 39eb2cc7e0eea..9e756f9294634 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -110,6 +110,7 @@ import ( "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -133,7 +134,7 @@ type WebSuite struct { node *regular.Server proxy *regular.Server - proxyTunnel reversetunnel.Server + proxyTunnel reversetunnelclient.Server srvID string user string @@ -453,11 +454,15 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { StaticFS: fs, CachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: &mockProxySettings{}, - SessionControl: proxySessionController, - Router: router, - HealthCheckAppServer: cfg.HealthCheckAppServer, - UI: cfg.uiConfig, - OpenAIConfig: cfg.OpenAIConfig, + SessionControl: SessionControllerFunc(func(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + controller := srv.WebSessionController(proxySessionController) + ctx, err := controller(ctx, sctx, login, localAddr, remoteAddr) + return ctx, trace.Wrap(err) + }), + Router: router, + HealthCheckAppServer: cfg.HealthCheckAppServer, + UI: cfg.uiConfig, + OpenAIConfig: cfg.OpenAIConfig, } if handlerConfig.HealthCheckAppServer == nil { @@ -7692,19 +7697,23 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula fs, err := newDebugFileSystem() require.NoError(t, err) handler, err := NewHandler(Config{ - Proxy: revTunServer, - AuthServers: utils.FromAddr(authServer.Addr()), - DomainName: authServer.ClusterName(), - ProxyClient: client, - ProxyPublicAddrs: utils.MustParseAddrList("proxy-1.example.com", "proxy-2.example.com"), - CipherSuites: utils.DefaultCipherSuites(), - AccessPoint: client, - Context: ctx, - HostUUID: proxyID, - Emitter: client, - StaticFS: fs, - ProxySettings: &mockProxySettings{}, - SessionControl: sessionController, + Proxy: revTunServer, + AuthServers: utils.FromAddr(authServer.Addr()), + DomainName: authServer.ClusterName(), + ProxyClient: client, + ProxyPublicAddrs: utils.MustParseAddrList("proxy-1.example.com", "proxy-2.example.com"), + CipherSuites: utils.DefaultCipherSuites(), + AccessPoint: client, + Context: ctx, + HostUUID: proxyID, + Emitter: client, + StaticFS: fs, + ProxySettings: &mockProxySettings{}, + SessionControl: SessionControllerFunc(func(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { + controller := srv.WebSessionController(sessionController) + ctx, err := controller(ctx, sctx, login, localAddr, remoteAddr) + return ctx, trace.Wrap(err) + }), Router: router, HealthCheckAppServer: func(context.Context, string, string) error { return nil }, MinimalReverseTunnelRoutesOnly: cfg.minimalHandler, @@ -7766,7 +7775,7 @@ type testProxy struct { clock clockwork.FakeClock client auth.ClientI auth *auth.TestTLSServer - revTun reversetunnel.Server + revTun reversetunnelclient.Server node *regular.Server proxy *regular.Server handler *APIHandler @@ -8325,7 +8334,7 @@ func newKubeConfigFile(ctx context.Context, t *testing.T, clusters ...kubeCluste type startKubeOptions struct { clusters []kubeClusterConfig authServer *auth.TestTLSServer - revTunnel reversetunnel.Server + revTunnel reversetunnelclient.Server serviceType kubeproxy.KubeServiceType } diff --git a/lib/web/app/handler.go b/lib/web/app/handler.go index c74d0b1626bfa..cbcf4e7f14524 100644 --- a/lib/web/app/handler.go +++ b/lib/web/app/handler.go @@ -39,7 +39,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/httplib/reverseproxy" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -53,7 +53,7 @@ type HandlerConfig struct { // AccessPoint is caching client to auth. AccessPoint auth.ProxyAccessPoint // ProxyClient holds connections to leaf clusters. - ProxyClient reversetunnel.Tunnel + ProxyClient reversetunnelclient.Tunnel // ProxyPublicAddrs contains web proxy public addresses. ProxyPublicAddrs []utils.NetAddr // CipherSuites is the list of TLS cipher suites that have been configured diff --git a/lib/web/app/handler_test.go b/lib/web/app/handler_test.go index 02d06acac7d93..c50dc701a01bb 100644 --- a/lib/web/app/handler_test.go +++ b/lib/web/app/handler_test.go @@ -45,7 +45,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" @@ -349,9 +349,9 @@ func TestMatchApplicationServers(t *testing.T) { } // Create a fake remote site and tunnel. - fakeRemoteSite := reversetunnel.NewFakeRemoteSite(clusterName, authClient) - tunnel := &reversetunnel.FakeServer{ - Sites: []reversetunnel.RemoteSite{ + fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient) + tunnel := &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{ fakeRemoteSite, }, } @@ -405,14 +405,14 @@ func TestHealthCheckAppServer(t *testing.T) { for _, tc := range []struct { desc string publicAddr string - appServersFunc func(t *testing.T, remoteSite *reversetunnel.FakeRemoteSite) []types.AppServer + appServersFunc func(t *testing.T, remoteSite *reversetunnelclient.FakeRemoteSite) []types.AppServer expectedTunnelCalls int expectErr require.ErrorAssertionFunc }{ { desc: "match and online services", publicAddr: "valid.example.com", - appServersFunc: func(t *testing.T, _ *reversetunnel.FakeRemoteSite) []types.AppServer { + appServersFunc: func(t *testing.T, _ *reversetunnelclient.FakeRemoteSite) []types.AppServer { return []types.AppServer{createAppServer(t, "valid.example.com")} }, expectedTunnelCalls: 1, @@ -421,7 +421,7 @@ func TestHealthCheckAppServer(t *testing.T) { { desc: "match and but no online services", publicAddr: "valid.example.com", - appServersFunc: func(t *testing.T, tunnel *reversetunnel.FakeRemoteSite) []types.AppServer { + appServersFunc: func(t *testing.T, tunnel *reversetunnelclient.FakeRemoteSite) []types.AppServer { appServer := createAppServer(t, "valid.example.com") tunnel.OfflineTunnels = map[string]struct{}{ fmt.Sprintf("%s.%s", appServer.GetHostID(), clusterName): {}, @@ -434,7 +434,7 @@ func TestHealthCheckAppServer(t *testing.T) { { desc: "no match", publicAddr: "valid.example.com", - appServersFunc: func(t *testing.T, tunnel *reversetunnel.FakeRemoteSite) []types.AppServer { + appServersFunc: func(t *testing.T, tunnel *reversetunnelclient.FakeRemoteSite) []types.AppServer { return []types.AppServer{} }, expectedTunnelCalls: 0, @@ -458,7 +458,7 @@ func TestHealthCheckAppServer(t *testing.T) { caCert: cert, } - fakeRemoteSite := reversetunnel.NewFakeRemoteSite(clusterName, authClient) + fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient) authClient.appServers = tc.appServersFunc(t, fakeRemoteSite) // Create a httptest server to serve the application requests. It must serve @@ -476,8 +476,8 @@ func TestHealthCheckAppServer(t *testing.T) { } server.StartTLS() - tunnel := &reversetunnel.FakeServer{ - Sites: []reversetunnel.RemoteSite{fakeRemoteSite}, + tunnel := &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{fakeRemoteSite}, } appHandler, err := NewHandler(ctx, &HandlerConfig{ @@ -500,7 +500,7 @@ type testServer struct { serverURL *url.URL } -func setup(t *testing.T, clock clockwork.FakeClock, authClient auth.ClientI, proxyClient reversetunnel.Tunnel, proxyPublicAddrs []utils.NetAddr) *testServer { +func setup(t *testing.T, clock clockwork.FakeClock, authClient auth.ClientI, proxyClient reversetunnelclient.Tunnel, proxyPublicAddrs []utils.NetAddr) *testServer { appHandler, err := NewHandler(context.Background(), &HandlerConfig{ Clock: clock, AuthClient: authClient, @@ -655,7 +655,7 @@ func (c *mockAuthClient) GetCertAuthority(ctx context.Context, id types.CertAuth // fakeRemoteListener Implements a `net.Listener` that return `net.Conn` from // the `FakeRemoteSite`. type fakeRemoteListener struct { - fakeRemote *reversetunnel.FakeRemoteSite + fakeRemote *reversetunnelclient.FakeRemoteSite } func (r *fakeRemoteListener) Accept() (net.Conn, error) { diff --git a/lib/web/app/match.go b/lib/web/app/match.go index a8697d95ad690..dd6105035376c 100644 --- a/lib/web/app/match.go +++ b/lib/web/app/match.go @@ -26,7 +26,7 @@ import ( "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" ) @@ -99,7 +99,7 @@ func MatchName(name string) Matcher { // MatchHealthy tries to establish a connection with the server using the // `dialAppServer` function. The app server is matched if the function call // doesn't return any error. -func MatchHealthy(proxyClient reversetunnel.Tunnel, clusterName string) Matcher { +func MatchHealthy(proxyClient reversetunnelclient.Tunnel, clusterName string) Matcher { return func(ctx context.Context, appServer types.AppServer) bool { conn, err := dialAppServer(ctx, proxyClient, clusterName, appServer) if err != nil { @@ -132,7 +132,7 @@ func MatchAll(matchers ...Matcher) Matcher { // cluster, this method will always return "acme" running within the root // cluster. Always supply public address and cluster name to deterministically // resolve an application. -func ResolveFQDN(ctx context.Context, clt Getter, tunnel reversetunnel.Tunnel, proxyDNSNames []string, fqdn string) (types.AppServer, string, error) { +func ResolveFQDN(ctx context.Context, clt Getter, tunnel reversetunnelclient.Tunnel, proxyDNSNames []string, fqdn string) (types.AppServer, string, error) { // Try and match FQDN to public address of application within cluster. servers, err := Match(ctx, clt, MatchPublicAddr(fqdn)) if err == nil && len(servers) > 0 { diff --git a/lib/web/app/match_test.go b/lib/web/app/match_test.go index f583f06804df2..0a000a9e7a01c 100644 --- a/lib/web/app/match_test.go +++ b/lib/web/app/match_test.go @@ -26,7 +26,7 @@ import ( "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" ) func TestMatchAll(t *testing.T) { @@ -79,20 +79,20 @@ func TestMatchHealthy(t *testing.T) { } type mockProxyClient struct { - reversetunnel.Tunnel + reversetunnelclient.Tunnel remoteSite *mockRemoteSite } -func (p *mockProxyClient) GetSite(_ string) (reversetunnel.RemoteSite, error) { +func (p *mockProxyClient) GetSite(_ string) (reversetunnelclient.RemoteSite, error) { return p.remoteSite, nil } type mockRemoteSite struct { - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite dialErr error } -func (r *mockRemoteSite) Dial(_ reversetunnel.DialParams) (net.Conn, error) { +func (r *mockRemoteSite) Dial(_ reversetunnelclient.DialParams) (net.Conn, error) { if r.dialErr != nil { return nil, r.dialErr } diff --git a/lib/web/app/session.go b/lib/web/app/session.go index 64f2ae973269f..8dd9a7f548e50 100644 --- a/lib/web/app/session.go +++ b/lib/web/app/session.go @@ -28,7 +28,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib/reverseproxy" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" ) @@ -120,7 +120,7 @@ func (h *Handler) newSession(ctx context.Context, ws types.WebSession) (*session // appServerMatcher returns a Matcher function used to find which AppServer can // handle the application requests. -func appServerMatcher(proxyClient reversetunnel.Tunnel, publicAddr string, clusterName string) Matcher { +func appServerMatcher(proxyClient reversetunnelclient.Tunnel, publicAddr string, clusterName string) Matcher { // Match healthy and PublicAddr servers. Having a list of only healthy // servers helps the transport fail before the request is forwarded to a // server (in cases where there are no healthy servers). This process might diff --git a/lib/web/app/transport.go b/lib/web/app/transport.go index 1d5ec60f7e19f..69aca98ab6018 100644 --- a/lib/web/app/transport.go +++ b/lib/web/app/transport.go @@ -34,7 +34,6 @@ import ( apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tlsca" @@ -43,7 +42,7 @@ import ( // transportConfig is configuration for a rewriting transport. type transportConfig struct { - proxyClient reversetunnel.Tunnel + proxyClient reversetunnelclient.Tunnel accessPoint auth.ReadProxyAccessPoint cipherSuites []uint16 identity *tlsca.Identity @@ -285,7 +284,7 @@ func (t *transport) DialWebsocket(network, address string) (net.Conn, error) { // dialAppServer dial and connect to the application service over the reverse // tunnel subsystem. -func dialAppServer(ctx context.Context, proxyClient reversetunnel.Tunnel, clusterName string, server types.AppServer) (net.Conn, error) { +func dialAppServer(ctx context.Context, proxyClient reversetunnelclient.Tunnel, clusterName string, server types.AppServer) (net.Conn, error) { clusterClient, err := proxyClient.GetSite(clusterName) if err != nil { return nil, trace.Wrap(err) @@ -298,7 +297,7 @@ func dialAppServer(ctx context.Context, proxyClient reversetunnel.Tunnel, cluste from = clientSrcAddr } - conn, err := clusterClient.Dial(reversetunnel.DialParams{ + conn, err := clusterClient.Dial(reversetunnelclient.DialParams{ From: from, To: &utils.NetAddr{AddrNetwork: "tcp", Addr: reversetunnelclient.LocalNode}, OriginalClientDstAddr: originalDst, @@ -349,5 +348,5 @@ func configureTLS(c *transportConfig) (*tls.Config, error) { // the reverse tunnel connection is down e.g. because the agent is down. func isReverseTunnelDownError(err error) bool { return trace.IsConnectionProblem(err) || - strings.Contains(err.Error(), reversetunnel.NoApplicationTunnel) + strings.Contains(err.Error(), reversetunnelclient.NoApplicationTunnel) } diff --git a/lib/web/apps.go b/lib/web/apps.go index a8e1ce6acc257..1a5b7a63ea71e 100644 --- a/lib/web/apps.go +++ b/lib/web/apps.go @@ -35,7 +35,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" @@ -43,7 +43,7 @@ import ( ) // clusterAppsGet returns a list of applications in a form the UI can present. -func (h *Handler) clusterAppsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterAppsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { identity, err := sctx.GetIdentity() if err != nil { return nil, trace.Wrap(err) @@ -315,7 +315,7 @@ type resolveAppResult struct { App types.Application } -func (h *Handler) resolveApp(ctx context.Context, clt app.Getter, proxy reversetunnel.Tunnel, params resolveAppParams) (*resolveAppResult, error) { +func (h *Handler) resolveApp(ctx context.Context, clt app.Getter, proxy reversetunnelclient.Tunnel, params resolveAppParams) (*resolveAppResult, error) { var ( server types.AppServer appClusterName string @@ -349,7 +349,7 @@ func (h *Handler) resolveApp(ctx context.Context, clt app.Getter, proxy reverset // resolveDirect takes a public address and cluster name and exactly resolves // the application and the server on which it is running. -func (h *Handler) resolveDirect(ctx context.Context, proxy reversetunnel.Tunnel, publicAddr string, clusterName string) (types.AppServer, string, error) { +func (h *Handler) resolveDirect(ctx context.Context, proxy reversetunnelclient.Tunnel, publicAddr string, clusterName string) (types.AppServer, string, error) { clusterClient, err := proxy.GetSite(clusterName) if err != nil { return nil, "", trace.Wrap(err) @@ -374,7 +374,7 @@ func (h *Handler) resolveDirect(ctx context.Context, proxy reversetunnel.Tunnel, // resolveFQDN makes a best effort attempt to resolve FQDN to an application // running within a root or leaf cluster. -func (h *Handler) resolveFQDN(ctx context.Context, clt app.Getter, proxy reversetunnel.Tunnel, fqdn string) (types.AppServer, string, error) { +func (h *Handler) resolveFQDN(ctx context.Context, clt app.Getter, proxy reversetunnelclient.Tunnel, fqdn string) (types.AppServer, string, error) { return app.ResolveFQDN(ctx, clt, proxy, h.proxyDNSNames(), fqdn) } diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 2001d9caa6de5..6845dcb9f9c40 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -36,7 +36,7 @@ import ( "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" ) // createAssistantConversationResponse is a response for POST /webapi/assistant/conversations. @@ -287,7 +287,7 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, } func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnel.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ) (any, error) { if err := runAssistant(h, w, r, sctx, site); err != nil { h.log.Warn(trace.DebugReport(err)) @@ -311,7 +311,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnel.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -328,12 +328,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - identity, err := createIdentityContext(sctx.GetUser(), sctx) - if err != nil { - return trace.Wrap(err) - } - - ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), identity, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) + ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), sctx, sctx.GetUser(), h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) if err != nil { return trace.Wrap(err) } diff --git a/lib/web/command.go b/lib/web/command.go index 2bcdd223f8382..b230ff8759cef 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -48,10 +48,9 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/proxy" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/teleagent" ) @@ -118,7 +117,7 @@ func (h *Handler) executeCommand( r *http.Request, _ httprouter.Params, sessionCtx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -143,12 +142,7 @@ func (h *Handler) executeCommand( return nil, trace.Wrap(err) } - identity, err := createIdentityContext(req.Login, sessionCtx) - if err != nil { - return nil, trace.Wrap(err) - } - - ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), identity, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) + ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), sessionCtx, req.Login, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) if err != nil { return nil, trace.Wrap(err) } @@ -273,7 +267,7 @@ func (h *Handler) executeCommand( err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ ConversationId: req.ConversationID, - Username: identity.TeleportUser, + Username: sessionCtx.GetUser(), Message: &assist.AssistantMessage{ Type: string(assistlib.MessageKindCommandResult), CreatedTime: timestamppb.New(time.Now().UTC()), @@ -292,7 +286,7 @@ func (h *Handler) executeCommand( hosts: hosts, output: output, authClient: clt, - identity: identity, + username: sessionCtx.GetUser(), executionID: req.ExecutionID, conversationID: req.ConversationID, command: req.Command, @@ -310,7 +304,7 @@ type summaryRequest struct { hosts []hostInfo output map[string][]byte authClient auth.ClientI - identity srv.IdentityContext + username string executionID string conversationID string command string @@ -326,7 +320,7 @@ func (h *Handler) computeAndSendSummary( history, err := req.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ ConversationId: req.conversationID, - Username: req.identity.TeleportUser, + Username: req.username, }) if err != nil { return trace.Wrap(err) @@ -354,7 +348,7 @@ func (h *Handler) computeAndSendSummary( } summaryMessage := &assist.CreateAssistantMessageRequest{ ConversationId: req.conversationID, - Username: req.identity.TeleportUser, + Username: req.username, Message: &assist.AssistantMessage{ Type: string(assistlib.MessageKindCommandResultSummary), CreatedTime: timestamppb.New(time.Now().UTC()), diff --git a/lib/web/connection_diagnostic.go b/lib/web/connection_diagnostic.go index d2b71731d5e80..49f90aaa83f46 100644 --- a/lib/web/connection_diagnostic.go +++ b/lib/web/connection_diagnostic.go @@ -24,12 +24,12 @@ import ( "github.com/gravitational/teleport/lib/client/conntest" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/ui" ) // getConnectionDiagnostic returns a connection diagnostic connection diagnostics. -func (h *Handler) getConnectionDiagnostic(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) getConnectionDiagnostic(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -50,7 +50,7 @@ func (h *Handler) getConnectionDiagnostic(w http.ResponseWriter, r *http.Request } // diagnoseConnection executes and returns a connection diagnostic. -func (h *Handler) diagnoseConnection(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) diagnoseConnection(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { req := conntest.TestConnectionRequest{} if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/databases.go b/lib/web/databases.go index e904f405b5461..5c79edbbeb1c5 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -33,7 +33,7 @@ import ( "github.com/gravitational/teleport/api/utils/tlsutils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/web/scripts" "github.com/gravitational/teleport/lib/web/ui" @@ -80,7 +80,7 @@ func (r *createDatabaseRequest) checkAndSetDefaults() error { } // handleDatabaseCreate creates a database's metadata. -func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) handleDatabaseCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { var req *createDatabaseRequest if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -156,7 +156,7 @@ func (r *updateDatabaseRequest) checkAndSetDefaults() error { } // handleDatabaseUpdate updates the database -func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) handleDatabaseUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { databaseName := p.ByName("database") if databaseName == "" { return nil, trace.BadParameter("a database name is required") @@ -245,7 +245,7 @@ type databaseIAMPolicyAWS struct { } // handleDatabaseGetIAMPolicy returns the required IAM policy for database. -func (h *Handler) handleDatabaseGetIAMPolicy(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) handleDatabaseGetIAMPolicy(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { databaseName := p.ByName("database") if databaseName == "" { return nil, trace.BadParameter("missing database name") diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 3f47aa6cbe55d..830ddfde78fba 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -40,6 +40,7 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" @@ -48,8 +49,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" - "github.com/gravitational/teleport/lib/srv/desktop" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/scripts" @@ -61,7 +61,7 @@ func (h *Handler) desktopConnectHandle( r *http.Request, p httprouter.Params, sctx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -95,7 +95,7 @@ func (h *Handler) createDesktopConnection( clusterName string, log *logrus.Entry, sctx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) error { upgrader := websocket.Upgrader{ ReadBufferSize: 1024, @@ -244,6 +244,16 @@ func proxyClient(ctx context.Context, sessCtx *SessionContext, addr, windowsUser return pc, nil } +const ( + // SNISuffix is the server name suffix used during SNI to specify the + // target desktop to connect to. The client (proxy_service) will use SNI + // like "${UUID}.desktop.teleport.cluster.local" to pass the UUID of the + // desktop. + // This is a copy of the same constant in `lib/srv/desktop/desktop.go` to + // prevent depending on `lib/srv` in `lib/web`. + SNISuffix = ".desktop." + constants.APIDomain +) + func desktopTLSConfig(ctx context.Context, ws *websocket.Conn, pc *client.ProxyClient, sessCtx *SessionContext, desktopName, username, siteName string) (*tls.Config, error) { pk, err := keys.ParsePrivateKey(sessCtx.cfg.Session.GetPriv()) if err != nil { @@ -308,14 +318,14 @@ func desktopTLSConfig(ctx context.Context, ws *websocket.Conn, pc *client.ProxyC } tlsConfig.Certificates = []tls.Certificate{certConf} // Pass target desktop name via SNI. - tlsConfig.ServerName = desktopName + desktop.SNISuffix + tlsConfig.ServerName = desktopName + SNISuffix return tlsConfig, nil } type connector struct { log *logrus.Entry clt auth.ClientI - site reversetunnel.RemoteSite + site reversetunnelclient.RemoteSite clientSrcAddr net.Addr clientDstAddr net.Addr } @@ -350,7 +360,7 @@ func (c *connector) tryConnect(clusterName, desktopServiceID string) (net.Conn, *c.log = *c.log.WithField("windows-service-uuid", service.GetName()) *c.log = *c.log.WithField("windows-service-addr", service.GetAddr()) - return c.site.DialTCP(reversetunnel.DialParams{ + return c.site.DialTCP(reversetunnelclient.DialParams{ From: c.clientSrcAddr, To: &utils.NetAddr{AddrNetwork: "tcp", Addr: service.GetAddr()}, ConnType: types.WindowsDesktopTunnel, diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index 27bb0e245712b..be2035288e580 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -23,7 +23,7 @@ import ( "github.com/julienschmidt/httprouter" "golang.org/x/net/websocket" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/desktop" ) @@ -32,7 +32,7 @@ func (h *Handler) desktopPlaybackHandle( r *http.Request, p httprouter.Params, ctx *SessionContext, - site reversetunnel.RemoteSite, + site reversetunnelclient.RemoteSite, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { diff --git a/lib/web/files.go b/lib/web/files.go index 1c5e455690067..87c3fd5ad62f9 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -33,7 +33,7 @@ import ( wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/sshutils/sftp" ) @@ -60,7 +60,7 @@ type fileTransferRequest struct { moderatedSessionID string } -func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { query := r.URL.Query() req := fileTransferRequest{ cluster: site.GetName(), diff --git a/lib/web/integrations.go b/lib/web/integrations.go index 2bc3af9ea2e04..9fdf5135c1934 100644 --- a/lib/web/integrations.go +++ b/lib/web/integrations.go @@ -25,12 +25,12 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/ui" ) // integrationsCreate creates an Integration -func (h *Handler) integrationsCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) integrationsCreate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { var req *ui.Integration if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -77,7 +77,7 @@ func (h *Handler) integrationsCreate(w http.ResponseWriter, r *http.Request, p h } // integrationsUpdate updates the Integration based on its name -func (h *Handler) integrationsUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) integrationsUpdate(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { integrationName := p.ByName("name") if integrationName == "" { return nil, trace.BadParameter("an integration name is required") @@ -118,7 +118,7 @@ func (h *Handler) integrationsUpdate(w http.ResponseWriter, r *http.Request, p h } // integrationsDelete removes an Integration based on its name -func (h *Handler) integrationsDelete(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) integrationsDelete(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { integrationName := p.ByName("name") if integrationName == "" { return nil, trace.BadParameter("an integration name is required") @@ -137,7 +137,7 @@ func (h *Handler) integrationsDelete(w http.ResponseWriter, r *http.Request, p h } // integrationsGet returns an Integration based on its name -func (h *Handler) integrationsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) integrationsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { integrationName := p.ByName("name") if integrationName == "" { return nil, trace.BadParameter("an integration name is required") @@ -157,7 +157,7 @@ func (h *Handler) integrationsGet(w http.ResponseWriter, r *http.Request, p http } // integrationsList returns a page of Integrations -func (h *Handler) integrationsList(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) integrationsList(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 7fe849979bcdb..ad1a4a241488b 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -24,12 +24,12 @@ import ( "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/integrations/awsoidc" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/ui" ) // awsOIDCListDatabases returns a list of databases using the ListDatabases action of the AWS OIDC Integration. -func (h *Handler) awsOIDCListDatabases(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) awsOIDCListDatabases(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { ctx := r.Context() var req ui.AWSOIDCListDatabasesRequest @@ -67,7 +67,7 @@ func (h *Handler) awsOIDCListDatabases(w http.ResponseWriter, r *http.Request, p } // awsOIDClientRequest receives a request to execute an action for the AWS OIDC integrations. -func (h *Handler) awsOIDCClientRequest(ctx context.Context, region string, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (*awsoidc.AWSClientRequest, error) { +func (h *Handler) awsOIDCClientRequest(ctx context.Context, region string, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (*awsoidc.AWSClientRequest, error) { integrationName := p.ByName("name") if integrationName == "" { return nil, trace.BadParameter("an integration name is required") @@ -113,7 +113,7 @@ func (h *Handler) awsOIDCClientRequest(ctx context.Context, region string, p htt } // awsOIDCDeployService deploys a Discovery Service and a Database Service in Amazon ECS. -func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { ctx := r.Context() var req ui.AWSOIDCDeployServiceRequest diff --git a/lib/web/mfa.go b/lib/web/mfa.go index e6bb1cc2fcca4..ac5e831915619 100644 --- a/lib/web/mfa.go +++ b/lib/web/mfa.go @@ -27,7 +27,7 @@ import ( "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/ui" ) @@ -347,7 +347,7 @@ type isMfaRequiredResponse struct { Required bool `json:"required"` } -func (h *Handler) isMFARequired(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) isMFARequired(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { var httpReq *isMFARequiredRequest if err := httplib.ReadJSON(r, &httpReq); err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/resources.go b/lib/web/resources.go index 41264e5237582..fef9289b0bba7 100644 --- a/lib/web/resources.go +++ b/lib/web/resources.go @@ -32,13 +32,13 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/web/ui" ) // checkAccessToRegisteredResource checks if calling user has access to at least one registered resource. -func (h *Handler) checkAccessToRegisteredResource(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) checkAccessToRegisteredResource(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { // Get a client to the Auth Server with the logged in user's identity. The // identity of the logged in user is used to fetch the list of resources. clt, err := c.GetUserClient(r.Context(), site) diff --git a/lib/web/servers.go b/lib/web/servers.go index b9b0ea968e34a..53640c625cb09 100644 --- a/lib/web/servers.go +++ b/lib/web/servers.go @@ -24,13 +24,13 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/web/ui" ) // clusterKubesGet returns a list of kube clusters in a form the UI can present. -func (h *Handler) clusterKubesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterKubesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -60,7 +60,7 @@ func (h *Handler) clusterKubesGet(w http.ResponseWriter, r *http.Request, p http // clusterKubePodsGet returns a list of Kubernetes Pods in a form the // UI can present. -func (h *Handler) clusterKubePodsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterKubePodsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.NewKubernetesServiceClient(r.Context(), h.cfg.ProxyWebAddr.Addr) if err != nil { return nil, trace.Wrap(err) @@ -79,7 +79,7 @@ func (h *Handler) clusterKubePodsGet(w http.ResponseWriter, r *http.Request, p h } // clusterDatabasesGet returns a list of db servers in a form the UI can present. -func (h *Handler) clusterDatabasesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterDatabasesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -119,7 +119,7 @@ func (h *Handler) clusterDatabasesGet(w http.ResponseWriter, r *http.Request, p } // clusterDatabaseGet returns a list of db servers in a form the UI can present. -func (h *Handler) clusterDatabaseGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterDatabaseGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { databaseName := p.ByName("database") if databaseName == "" { return nil, trace.BadParameter("database name is required") @@ -149,7 +149,7 @@ func (h *Handler) clusterDatabaseGet(w http.ResponseWriter, r *http.Request, p h } // clusterDatabaseServicesList returns a list of DatabaseServices (database agents) in a form the UI can present. -func (h *Handler) clusterDatabaseServicesList(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterDatabaseServicesList(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := ctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -173,7 +173,7 @@ func (h *Handler) clusterDatabaseServicesList(w http.ResponseWriter, r *http.Req } // clusterDesktopsGet returns a list of desktops in a form the UI can present. -func (h *Handler) clusterDesktopsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterDesktopsGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -207,7 +207,7 @@ func (h *Handler) clusterDesktopsGet(w http.ResponseWriter, r *http.Request, p h } // clusterDesktopServicesGet returns a list of desktop services in a form the UI can present. -func (h *Handler) clusterDesktopServicesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) clusterDesktopServicesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { // Get a client to the Auth Server with the logged in user's identity. The // identity of the logged in user is used to fetch the list of desktop services. clt, err := sctx.GetUserClient(r.Context(), site) @@ -233,7 +233,7 @@ func (h *Handler) clusterDesktopServicesGet(w http.ResponseWriter, r *http.Reque } // getDesktopHandle returns a desktop. -func (h *Handler) getDesktopHandle(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) getDesktopHandle(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) @@ -273,7 +273,7 @@ func (h *Handler) getDesktopHandle(w http.ResponseWriter, r *http.Request, p htt // Response body: // // {"active": bool} -func (h *Handler) desktopIsActive(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { +func (h *Handler) desktopIsActive(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) { desktopName := p.ByName("desktopName") trackers, err := h.auth.proxyClient.GetActiveSessionTrackersWithFilter(r.Context(), &types.SessionTrackerFilter{ Kind: string(types.WindowsDesktopSessionKind), diff --git a/lib/web/sessions.go b/lib/web/sessions.go index 2dc5c65a5c25d..f405a6d70edef 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -49,7 +49,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/multiplexer" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" @@ -108,7 +108,7 @@ type SessionContextConfig struct { Session types.WebSession // newRemoteClient is used by tests to override how remote clients are constructed to allow for fake sites - newRemoteClient func(ctx context.Context, sessionContext *SessionContext, site reversetunnel.RemoteSite) (auth.ClientI, error) + newRemoteClient func(ctx context.Context, sessionContext *SessionContext, site reversetunnelclient.RemoteSite) (auth.ClientI, error) } func (c *SessionContextConfig) CheckAndSetDefaults() error { @@ -219,7 +219,7 @@ func (c *SessionContext) GetClientConnection() *grpc.ClientConn { // the requested site. If the site is local a client with the users local role // is returned. If the site is remote a client with the users remote role is // returned. -func (c *SessionContext) GetUserClient(ctx context.Context, site reversetunnel.RemoteSite) (auth.ClientI, error) { +func (c *SessionContext) GetUserClient(ctx context.Context, site reversetunnelclient.RemoteSite) (auth.ClientI, error) { // if we're trying to access the local cluster, pass back the local client. if c.cfg.RootClusterName == site.GetName() { return c.cfg.RootClient, nil @@ -238,7 +238,7 @@ func (c *SessionContext) GetUserClient(ctx context.Context, site reversetunnel.R // // A [singleflight.Group] is leveraged to prevent duplicate requests for remote // clients at the same time to race. -func (c *SessionContext) remoteClient(ctx context.Context, site reversetunnel.RemoteSite) (auth.ClientI, error) { +func (c *SessionContext) remoteClient(ctx context.Context, site reversetunnelclient.RemoteSite) (auth.ClientI, error) { cltI, err, _ := c.remoteClientGroup.Do(site.GetName(), func() (interface{}, error) { // check if we already have a connection to this cluster if clt, ok := c.remoteClientCache.getRemoteClient(site); ok { @@ -274,7 +274,7 @@ func (c *SessionContext) remoteClient(ctx context.Context, site reversetunnel.Re } // newRemoteClient returns a client to a remote cluster with the role of current user. -func newRemoteClient(ctx context.Context, sctx *SessionContext, site reversetunnel.RemoteSite) (auth.ClientI, error) { +func newRemoteClient(ctx context.Context, sctx *SessionContext, site reversetunnelclient.RemoteSite) (auth.ClientI, error) { clt, err := sctx.newRemoteTLSClient(ctx, site) if err != nil { return nil, trace.Wrap(err) @@ -291,9 +291,9 @@ func newRemoteClient(ctx context.Context, sctx *SessionContext, site reversetunn } // clusterDialer returns DialContext function using cluster's dial function -func clusterDialer(remoteCluster reversetunnel.RemoteSite, src, dst net.Addr) apiclient.ContextDialer { +func clusterDialer(remoteCluster reversetunnelclient.RemoteSite, src, dst net.Addr) apiclient.ContextDialer { return apiclient.ContextDialerFunc(func(in context.Context, network, _ string) (net.Conn, error) { - dialParams := reversetunnel.DialParams{ + dialParams := reversetunnelclient.DialParams{ From: src, OriginalClientDstAddr: dst, } @@ -386,7 +386,7 @@ func (c *SessionContext) ClientTLSConfig(ctx context.Context, clusterName ...str return tlsConfig, nil } -func (c *SessionContext) newRemoteTLSClient(ctx context.Context, cluster reversetunnel.RemoteSite) (auth.ClientI, error) { +func (c *SessionContext) newRemoteTLSClient(ctx context.Context, cluster reversetunnelclient.RemoteSite) (auth.ClientI, error) { tlsConfig, err := c.ClientTLSConfig(ctx, cluster.GetName()) if err != nil { return nil, trace.Wrap(err) @@ -1143,17 +1143,17 @@ type remoteClientCache struct { sync.Mutex clients map[string]struct { auth.ClientI - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite } } -func (c *remoteClientCache) addRemoteClient(site reversetunnel.RemoteSite, remoteClient auth.ClientI) error { +func (c *remoteClientCache) addRemoteClient(site reversetunnelclient.RemoteSite, remoteClient auth.ClientI) error { c.Lock() defer c.Unlock() if c.clients == nil { c.clients = make(map[string]struct { auth.ClientI - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite }) } var err error @@ -1162,12 +1162,12 @@ func (c *remoteClientCache) addRemoteClient(site reversetunnel.RemoteSite, remot } c.clients[site.GetName()] = struct { auth.ClientI - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite }{remoteClient, site} return err } -func (c *remoteClientCache) getRemoteClient(site reversetunnel.RemoteSite) (auth.ClientI, bool) { +func (c *remoteClientCache) getRemoteClient(site reversetunnelclient.RemoteSite) (auth.ClientI, bool) { c.Lock() defer c.Unlock() remoteClt, ok := c.clients[site.GetName()] diff --git a/lib/web/sessions_test.go b/lib/web/sessions_test.go index 1187f17e35d26..8d47ab928aca4 100644 --- a/lib/web/sessions_test.go +++ b/lib/web/sessions_test.go @@ -25,7 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" ) func TestRemoteClientCache(t *testing.T) { @@ -57,12 +57,12 @@ func TestRemoteClientCache(t *testing.T) { require.Zero(t, openCount.Load()) } -func newMockRemoteSite(name string) reversetunnel.RemoteSite { +func newMockRemoteSite(name string) reversetunnelclient.RemoteSite { return &mockRemoteSite{name: name} } type mockRemoteSite struct { - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite name string } @@ -98,7 +98,7 @@ func TestGetUserClient(t *testing.T) { sctx := SessionContext{ cfg: SessionContextConfig{ RootClusterName: "local", - newRemoteClient: func(ctx context.Context, sessionContext *SessionContext, site reversetunnel.RemoteSite) (auth.ClientI, error) { + newRemoteClient: func(ctx context.Context, sessionContext *SessionContext, site reversetunnelclient.RemoteSite) (auth.ClientI, error) { return newMockClientI(&openCount, nil), nil }, }, diff --git a/lib/web/sign.go b/lib/web/sign.go index 8dd2e6c5144d3..9a0a122ab4574 100644 --- a/lib/web/sign.go +++ b/lib/web/sign.go @@ -30,7 +30,7 @@ import ( "github.com/gravitational/teleport/lib/client/db" "github.com/gravitational/teleport/lib/client/identityfile" "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" ) @@ -51,7 +51,7 @@ Should be equivalent to running: This endpoint returns a tar.gz compressed archive containing the required files to setup mTLS for the database. */ -func (h *Handler) signDatabaseCertificate(w http.ResponseWriter, r *http.Request, p httprouter.Params, site reversetunnel.RemoteSite, token types.ProvisionToken) (interface{}, error) { +func (h *Handler) signDatabaseCertificate(w http.ResponseWriter, r *http.Request, p httprouter.Params, site reversetunnelclient.RemoteSite, token types.ProvisionToken) (interface{}, error) { if !token.GetRoles().Include(types.RoleDatabase) { return nil, trace.AccessDenied("required '%s' role was not provided by the token", types.RoleDatabase) } diff --git a/lib/web/ui/cluster.go b/lib/web/ui/cluster.go index e5656c0538e20..9e3015b99b40d 100644 --- a/lib/web/ui/cluster.go +++ b/lib/web/ui/cluster.go @@ -25,7 +25,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" ) @@ -50,7 +50,7 @@ type Cluster struct { } // NewClusters creates a slice of Cluster's, containing data about each cluster. -func NewClusters(remoteClusters []reversetunnel.RemoteSite) ([]Cluster, error) { +func NewClusters(remoteClusters []reversetunnelclient.RemoteSite) ([]Cluster, error) { clusters := make([]Cluster, 0, len(remoteClusters)) for _, site := range remoteClusters { // Other fields such as node count, url, and proxy/auth versions are not set @@ -87,7 +87,7 @@ func NewClustersFromRemote(remoteClusters []types.RemoteCluster) ([]Cluster, err } // GetClusterDetails retrieves and sets details about a cluster -func GetClusterDetails(ctx context.Context, site reversetunnel.RemoteSite, opts ...services.MarshalOption) (*Cluster, error) { +func GetClusterDetails(ctx context.Context, site reversetunnelclient.RemoteSite, opts ...services.MarshalOption) (*Cluster, error) { clt, err := site.CachingAccessPoint() if err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/ui/perf_test.go b/lib/web/ui/perf_test.go index 04326a83f49ee..a781eb262cbd9 100644 --- a/lib/web/ui/perf_test.go +++ b/lib/web/ui/perf_test.go @@ -31,7 +31,7 @@ import ( "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/memory" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/local" ) @@ -143,7 +143,7 @@ func insertServers(ctx context.Context, b *testing.B, svc services.Presence, kin } } -func benchmarkGetClusterDetails(ctx context.Context, b *testing.B, site reversetunnel.RemoteSite, nodes int, opts ...services.MarshalOption) { +func benchmarkGetClusterDetails(ctx context.Context, b *testing.B, site reversetunnelclient.RemoteSite, nodes int, opts ...services.MarshalOption) { var cluster *Cluster var err error for i := 0; i < b.N; i++ { @@ -155,7 +155,7 @@ func benchmarkGetClusterDetails(ctx context.Context, b *testing.B, site reverset } type mockRemoteSite struct { - reversetunnel.RemoteSite + reversetunnelclient.RemoteSite accessPoint auth.ProxyAccessPoint } diff --git a/lib/web/user_groups.go b/lib/web/user_groups.go index 432b1d27eca2e..c2a1bc02e7b51 100644 --- a/lib/web/user_groups.go +++ b/lib/web/user_groups.go @@ -30,11 +30,11 @@ import ( "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/web/ui" ) -func (h *Handler) getUserGroups(_ http.ResponseWriter, r *http.Request, params httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (any, error) { +func (h *Handler) getUserGroups(_ http.ResponseWriter, r *http.Request, params httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) { // Get a client to the Auth Server with the logged in user's identity. The // identity of the logged in user is used to fetch the list of nodes. clt, err := sctx.GetUserClient(r.Context(), site) diff --git a/tool/tsh/common/app_test.go b/tool/tsh/common/app_test.go index 772c2e40515dd..96ebeae13ac41 100644 --- a/tool/tsh/common/app_test.go +++ b/tool/tsh/common/app_test.go @@ -36,7 +36,7 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/client" defaults2 "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" ) @@ -86,7 +86,7 @@ func TestAppLoginLeaf(t *testing.T) { rootAuth, rootProxy := makeTestServers(t, withClusterName(t, "root"), withBootstrap(connector, alice), withConfig(configStorage)) event, err := rootAuth.WaitForEventTimeout(time.Second, service.ProxyReverseTunnelReady) require.NoError(t, err) - tunnel, ok := event.Payload.(reversetunnel.Server) + tunnel, ok := event.Payload.(reversetunnelclient.Server) require.True(t, ok) rootAppURL := startDummyHTTPServer(t, "rootapp")