diff --git a/integration/port_forwarding_test.go b/integration/port_forwarding_test.go index e797e9bc700eb..50907b8ab9025 100644 --- a/integration/port_forwarding_test.go +++ b/integration/port_forwarding_test.go @@ -22,10 +22,12 @@ import ( "net/http" "net/http/httptest" "net/url" + "os/user" "strconv" "testing" "time" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -71,19 +73,49 @@ func waitForSessionToBeEstablished(ctx context.Context, namespace string, site a } func testPortForwarding(t *testing.T, suite *integrationTestSuite) { + invalidOSLogin := uuid.NewString()[:12] + notFound := false + for i := 0; i < 10; i++ { + if _, err := user.Lookup(invalidOSLogin); err == nil { + invalidOSLogin = uuid.NewString()[:12] + continue + } + notFound = true + break + } + require.True(t, notFound, "unable to locate invalid os user") + + // Providing our own logins to Teleport so we can verify that a user + // that exists within Teleport but does not exist on the local node + // cannot port forward. + logins := []string{ + invalidOSLogin, + suite.Me.Username, + } + testCases := []struct { desc string portForwardingAllowed bool expectSuccess bool + login string }{ { desc: "Enabled", portForwardingAllowed: true, expectSuccess: true, - }, { + login: suite.Me.Username, + }, + { desc: "Disabled", portForwardingAllowed: false, expectSuccess: false, + login: suite.Me.Username, + }, + { + desc: "Enabled with invalid user", + portForwardingAllowed: true, + expectSuccess: false, + login: invalidOSLogin, }, } @@ -106,7 +138,7 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { cfg.SSH.Enabled = true cfg.SSH.AllowTCPForwarding = tt.portForwardingAllowed - teleport := suite.NewTeleportWithConfig(t, nil, nil, cfg) + teleport := suite.NewTeleportWithConfig(t, logins, nil, cfg) defer teleport.StopAll() site := teleport.GetSiteAPI(helpers.Site) @@ -127,7 +159,7 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { nodeSSHPort := helpers.Port(t, teleport.SSH) cl, err := teleport.NewClient(helpers.ClientConfig{ - Login: suite.Me.Username, + Login: tt.login, Cluster: helpers.Site, Host: Host, Port: nodeSSHPort, diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go index a5230ae2eddc9..248f168f5f9c4 100644 --- a/lib/authz/permissions.go +++ b/lib/authz/permissions.go @@ -748,7 +748,7 @@ func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordi types.NewRule(types.KindRole, services.RO()), types.NewRule(types.KindNamespace, services.RO()), types.NewRule(types.KindLock, services.RO()), - types.NewRule(types.KindKubernetesCluster, services.RW()), + types.NewRule(types.KindKubernetesCluster, services.RO()), types.NewRule(types.KindSemaphore, services.RW()), }, }, diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index c52f5ace8fc52..3f1ac663c6274 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -1622,6 +1622,34 @@ func TestApplicationServers(t *testing.T) { }) } +// TestKubernetesServers tests that CRUD operations on kube servers are +// replicated from the backend to the cache. +func TestKubernetesServers(t *testing.T) { + t.Parallel() + + p := newTestPack(t, ForProxy) + t.Cleanup(p.Close) + + testResources(t, p, testFuncs[types.KubeServer]{ + newResource: func(name string) (types.KubeServer, error) { + app, err := types.NewKubernetesClusterV3(types.Metadata{Name: name}, types.KubernetesClusterSpecV3{}) + require.NoError(t, err) + return types.NewKubernetesServerV3FromCluster(app, "host", uuid.New().String()) + }, + create: withKeepalive(p.presenceS.UpsertKubernetesServer), + list: func(ctx context.Context) ([]types.KubeServer, error) { + return p.presenceS.GetKubernetesServers(ctx) + }, + cacheList: func(ctx context.Context) ([]types.KubeServer, error) { + return p.cache.GetKubernetesServers(ctx) + }, + update: withKeepalive(p.presenceS.UpsertKubernetesServer), + deleteAll: func(ctx context.Context) error { + return p.presenceS.DeleteAllKubernetesServers(ctx) + }, + }) +} + // TestApps tests that CRUD operations on application resources are // replicated from the backend to the cache. func TestApps(t *testing.T) { diff --git a/lib/client/api.go b/lib/client/api.go index a4af7ff464b1e..c93d0809e5ac0 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "io" "net" @@ -1587,16 +1588,27 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, nod return trace.Wrap(err) } - // If no remote command execution was requested, block on the context which - // will unblock upon error or SIGINT. + // If no remote command execution was requested block on which ever comes first: + // 1) the context which will unblock upon error or user terminating the process + // 2) ssh.Client.Wait which will unblock when the connection has shut down if tc.NoRemoteExec { - log.Debugf("Connected to node, no remote command execution was requested, blocking until context closes.") - <-ctx.Done() - - // Only return an error if the context was canceled by something other than SIGINT. - if ctx.Err() != context.Canceled { - return ctx.Err() + connClosed := make(chan error, 1) + go func() { + connClosed <- nodeClient.Client.Wait() + }() + log.Debugf("Connected to node, no remote command execution was requested, blocking indefinitely.") + select { + case <-ctx.Done(): + // Only return an error if the context was canceled by something other than SIGINT. + if err := ctx.Err(); !errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } + case err := <-connClosed: + if !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } } + return nil } diff --git a/lib/client/client.go b/lib/client/client.go index 9914b57fd7cf4..b5693ccbac160 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -48,6 +48,7 @@ import ( tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -1840,75 +1841,52 @@ func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error } type netDialer interface { - Dial(string, string) (net.Conn, error) + DialContext(context.Context, string, string) (net.Conn, error) } func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dialer netDialer) error { defer conn.Close() defer log.Debugf("Finished proxy from %v to %v.", conn.RemoteAddr(), remoteAddr) - var ( - remoteConn net.Conn - err error - ) - + var remoteConn net.Conn log.Debugf("Attempting to connect proxy from %v to %v.", conn.RemoteAddr(), remoteAddr) - for attempt := 1; attempt <= 5; attempt++ { - remoteConn, err = dialer.Dial("tcp", remoteAddr) - if err != nil { - log.Debugf("Proxy connection attempt %v: %v.", attempt, err) - - timer := time.NewTimer(time.Duration(100*attempt) * time.Millisecond) - defer timer.Stop() - - // Wait and attempt to connect again, if the context has closed, exit - // right away. - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case <-timer.C: - continue - } - } - // Connection established, break out of the loop. - break - } + + retry, err := retryutils.NewLinear(retryutils.LinearConfig{ + First: 100 * time.Millisecond, + Step: 100 * time.Millisecond, + Max: time.Second, + Jitter: retryutils.NewHalfJitter(), + }) if err != nil { - return trace.BadParameter("failed to connect to node: %v", remoteAddr) + return trace.Wrap(err) } - defer remoteConn.Close() - - // Start proxying, close the connection if a problem occurs on either leg. - errCh := make(chan error, 2) - go func() { - defer conn.Close() - defer remoteConn.Close() - _, err := io.Copy(conn, remoteConn) - errCh <- err - }() - go func() { - defer conn.Close() - defer remoteConn.Close() - - _, err := io.Copy(remoteConn, conn) - errCh <- err - }() + for attempt := 1; attempt <= 5; attempt++ { + conn, err := dialer.DialContext(ctx, "tcp", remoteAddr) + if err == nil { + // Connection established, break out of the loop. + remoteConn = conn + break + } - var errs []error - for i := 0; i < 2; i++ { + log.Debugf("Proxy connection attempt %v: %v.", attempt, err) + // Wait and attempt to connect again, if the context has closed, exit + // right away. select { - case err := <-errCh: - if err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - log.Warnf("Failed to proxy connection: %v.", err) - errs = append(errs, err) - } case <-ctx.Done(): return trace.Wrap(ctx.Err()) + case <-retry.After(): + retry.Inc() + continue } } + if remoteConn == nil { + return trace.BadParameter("failed to connect to node: %v", remoteAddr) + } + defer remoteConn.Close() - return trace.NewAggregate(errs...) + // Start proxying, close the connection if a problem occurs on either leg. + return trace.Wrap(utils.ProxyConn(ctx, remoteConn, conn)) } // acceptWithContext calls "Accept" on the listener but will unblock when the diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index a5fedc5c82bf1..daba8a21f8771 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -352,8 +352,19 @@ type Forwarder struct { sessions map[uuid.UUID]*session // upgrades connections to websockets upgrader websocket.Upgrader + // getKubernetesServersForKubeCluster is a function that returns a list of + // kubernetes servers for a given kube cluster but uses different methods + // depending on the service type. + // For example, if the service type is KubeService, it will use the + // local kubernetes clusters. If the service type is Proxy, it will + // use the heartbeat clusters. + getKubernetesServersForKubeCluster getKubeServersByNameFunc } +// getKubeServersByNameFunc is a function that returns a list of +// kubernetes servers for a given kube cluster. +type getKubeServersByNameFunc = func(ctx context.Context, name string) ([]types.KubeServer, error) + // Close signals close to all outstanding or background operations // to complete func (f *Forwarder) Close() error { @@ -396,6 +407,9 @@ type authContext struct { kubeResource *types.KubernetesResource // httpMethod is the request HTTP Method. httpMethod string + // kubeServers are the registered agents for the kubernetes cluster the request + // is targeted to. + kubeServers []types.KubeServer } func (c authContext) String() string { @@ -729,7 +743,9 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req } kubeCluster := identity.KubernetesCluster - if !isRemoteCluster { + // Only set a default kube cluster if the user is not accessing a specific cluster. + // The check for kubeCluster != "" is happens in the next code section. + if !isRemoteCluster && kubeCluster == "" { kc, err := kubeutils.CheckOrSetKubeCluster(ctx, f.cfg.CachingAuthClient, identity.KubernetesCluster, teleportClusterName) if err != nil { if !trace.IsNotFound(err) { @@ -746,14 +762,20 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req var ( kubeUsers, kubeGroups []string kubeLabels map[string]string + kubeServers []types.KubeServer + err error ) // Only check k8s principals for local clusters. // // For remote clusters, everything will be remapped to new roles on the // leaf and checked there. if !isRemoteCluster { + kubeServers, err = f.getKubernetesServersForKubeCluster(ctx, kubeCluster) + if err != nil || len(kubeServers) == 0 { + return nil, trace.NotFound("cluster %q not found", kubeCluster) + } // check signing TTL and return a list of allowed logins for local cluster based on Kubernetes service labels. - kubeAccessDetails, err := f.getKubeAccessDetails(roles, kubeCluster, sessionTTL, kubeResource) + kubeAccessDetails, err := f.getKubeAccessDetails(kubeServers, roles, kubeCluster, sessionTTL, kubeResource) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) // roles.CheckKubeGroupsAndUsers returns trace.NotFound if the user does @@ -924,7 +946,8 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req isRemote: isRemoteCluster, isRemoteClosed: isRemoteClosed, }, - httpMethod: req.Method, + httpMethod: req.Method, + kubeServers: kubeServers, }, nil } @@ -1008,16 +1031,12 @@ type kubeAccessDetails struct { // getKubeAccessDetails returns the allowed kube groups/users names and the cluster labels for a local kube cluster. func (f *Forwarder) getKubeAccessDetails( + kubeServers []types.KubeServer, roles services.AccessChecker, kubeClusterName string, sessionTTL time.Duration, kubeResource *types.KubernetesResource, ) (kubeAccessDetails, error) { - kubeServers, err := f.cfg.CachingAuthClient.GetKubernetesServers(f.ctx) - if err != nil { - return kubeAccessDetails{}, trace.Wrap(err) - } - // Find requested kubernetes cluster name and get allowed kube users/groups names. for _, s := range kubeServers { c := s.GetCluster() @@ -1123,10 +1142,7 @@ func (f *Forwarder) authorize(ctx context.Context, actx *authContext) error { f.log.WithField("auth_context", actx.String()).Debug("Skipping authorization due to unknown kubernetes cluster name") return nil } - servers, err := f.cfg.CachingAuthClient.GetKubernetesServers(ctx) - if err != nil { - return trace.Wrap(err) - } + authPref, err := f.cfg.CachingAuthClient.GetAuthPreference(ctx) if err != nil { return trace.Wrap(err) @@ -1153,7 +1169,7 @@ func (f *Forwarder) authorize(ctx context.Context, actx *authContext) error { // // We assume that users won't register two identically-named clusters with // mis-matched labels. If they do, expect weirdness. - for _, s := range servers { + for _, s := range actx.kubeServers { ks := s.GetCluster() if ks.GetName() != actx.kubeClusterName { continue @@ -2281,11 +2297,7 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx context.Context, authCtx au return sess, nil } - kubeServers, err := f.cfg.CachingAuthClient.GetKubernetesServers(f.ctx) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - + kubeServers := authCtx.kubeServers if len(kubeServers) == 0 && authCtx.kubeClusterName == authCtx.teleportCluster.name { return nil, trace.Wrap(localErr) } @@ -2314,12 +2326,8 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx context.Context, authCtx au } func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, error) { - if len(f.clusterDetails) == 0 { - return nil, trace.NotFound("this Teleport process is not configured for direct Kubernetes access; you likely need to 'tsh login' into a leaf cluster or 'tsh kube login' into a different kubernetes cluster") - } - - details, ok := f.clusterDetails[ctx.kubeClusterName] - if !ok { + details, err := f.findKubeDetailsByClusterName(ctx.kubeClusterName) + if err != nil { return nil, trace.NotFound("kubernetes cluster %q not found", ctx.kubeClusterName) } diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index c53aaeb0648b4..9aea409cf1f02 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -163,6 +163,19 @@ func TestAuthenticate(t *testing.T) { TracerProvider: otel.GetTracerProvider(), tracer: otel.Tracer(teleport.ComponentKube), }, + getKubernetesServersForKubeCluster: func(ctx context.Context, name string) ([]types.KubeServer, error) { + servers, err := ap.GetKubernetesServers(ctx) + if err != nil { + return nil, err + } + var filtered []types.KubeServer + for _, server := range servers { + if server.GetCluster().GetName() == name { + filtered = append(filtered, server) + } + } + return filtered, nil + }, } const remoteAddr = "user.example.com" @@ -220,6 +233,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -243,6 +271,30 @@ func TestAuthenticate(t *testing.T) { DynamicLabels: map[string]types.CommandLabelV2{}, }, }, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "foo", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "bar", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, ), wantCtx: &authContext{ kubeUsers: utils.StringsSet([]string{"user-a"}), @@ -257,6 +309,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -289,6 +356,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -320,6 +399,19 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -352,6 +444,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -453,6 +557,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -501,6 +617,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -559,6 +687,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "foo", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -958,6 +1101,8 @@ func TestNewClusterSessionDirect(t *testing.T) { f.cfg.CachingAuthClient = mockAccessPoint{ kubeServers: []types.KubeServer{publicKubeService, otherKubeService, tunnelKubeService, otherKubeService}, } + authCtx.kubeServers, err = f.cfg.CachingAuthClient.GetKubernetesServers(context.Background()) + require.NoError(t, err) sess, err := f.newClusterSession(ctx, authCtx) require.NoError(t, err) require.Equal(t, []kubeClusterEndpoint{publicEndpoint, tunnelEndpoint}, sess.kubeClusterEndpoints) diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index c8a530b7b5e86..50012e7a67f6c 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -151,8 +151,8 @@ type TLSServer struct { heartbeats map[string]*srv.Heartbeat closeContext context.Context closeFunc context.CancelFunc - // watcher monitors changes to kube cluster resources. - watcher *services.KubeClusterWatcher + // kubeClusterWatcher monitors changes to kube cluster resources. + kubeClusterWatcher *services.KubeClusterWatcher // reconciler reconciles proxied kube clusters with kube_clusters resources. reconciler *services.Reconciler // monitoredKubeClusters contains all kube clusters the proxied kube_clusters are @@ -229,6 +229,11 @@ func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) { } server.TLS.GetConfigForClient = server.GetConfigForClient server.closeContext, server.closeFunc = context.WithCancel(cfg.Context) + // register into the forwarder the method to get kubernetes servers for a kube cluster. + server.fwd.getKubernetesServersForKubeCluster, err = server.getKubernetesServersForKubeClusterFunc() + if err != nil { + return nil, trace.Wrap(err) + } return server, nil } @@ -282,7 +287,9 @@ func (t *TLSServer) Serve(listener net.Listener) error { // Initialize watcher that will be dynamically (un-)registering // proxied clusters based on the kube_cluster resources. - if t.watcher, err = t.startResourceWatcher(t.closeContext); err != nil { + // This watcher is only started for the kube_service if a resource watcher + // is configured. + if t.kubeClusterWatcher, err = t.startKubeClusterResourceWatcher(t.closeContext); err != nil { return trace.Wrap(err) } @@ -314,8 +321,8 @@ func (t *TLSServer) close(ctx context.Context) error { t.closeFunc() // Stop the kube_cluster resource watcher. - if t.watcher != nil { - t.watcher.Close() + if t.kubeClusterWatcher != nil { + t.kubeClusterWatcher.Close() } t.mu.Lock() listClose := t.listener.Close() @@ -350,7 +357,7 @@ func (t *TLSServer) getServerInfo(name string) (types.Resource, error) { addr = t.listener.Addr().String() } - cluster, err := t.getKubeClusterForHeartbeat(name) + cluster, err := t.getKubeClusterWithServiceLabels(name) if err != nil { return nil, trace.Wrap(err) } @@ -385,12 +392,12 @@ func (t *TLSServer) getServerInfo(name string) (types.Resource, error) { return srv, nil } -// getKubeClusterForHeartbeat finds the kube cluster by name, strips the credentials, +// getKubeClusterWithServiceLabels finds the kube cluster by name, strips the credentials, // replaces the cluster dynamic labels with their latest value available and updates // the cluster with the service dynamic and static labels. // We strip the Azure, AWS and Kubeconfig credentials so they are not leaked when // heartbeating the cluster. -func (t *TLSServer) getKubeClusterForHeartbeat(name string) (*types.KubernetesClusterV3, error) { +func (t *TLSServer) getKubeClusterWithServiceLabels(name string) (*types.KubernetesClusterV3, error) { // it is safe do read from details since the structure is never updated. // we replace the whole structure each time an update happens to a dynamic cluster. details, err := t.fwd.findKubeDetailsByClusterName(name) @@ -524,3 +531,59 @@ func (t *TLSServer) setServiceLabels(cluster types.KubeCluster) { cluster.SetDynamicLabels(dstDynLabels) } } + +// getKubernetesServersForKubeClusterFunc returns a function that returns the kubernetes servers +// for a given kube cluster depending on the type of service. +func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNameFunc, error) { + switch t.KubeServiceType { + case KubeService: + return func(_ context.Context, name string) ([]types.KubeServer, error) { + // If this is a kube_service, we can just return the local kube servers. + kube, err := t.getKubeClusterWithServiceLabels(name) + if err != nil { + return nil, trace.Wrap(err) + } + srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID) + if err != nil { + return nil, trace.Wrap(err) + } + return []types.KubeServer{srv}, nil + }, nil + case ProxyService: + return t.getAuthKubeServers, nil + case LegacyProxyService: + return func(ctx context.Context, name string) ([]types.KubeServer, error) { + kube, err := t.getKubeClusterWithServiceLabels(name) + if err != nil { + servers, err := t.getAuthKubeServers(ctx, name) + return servers, trace.Wrap(err) + } + srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID) + if err != nil { + return nil, trace.Wrap(err) + } + return []types.KubeServer{srv}, nil + }, nil + default: + return nil, trace.BadParameter("unknown kubernetes service type %q", t.KubeServiceType) + } +} + +// getAuthKubeServers returns the kubernetes servers for a given kube cluster +// using the Auth server client. +func (t *TLSServer) getAuthKubeServers(ctx context.Context, name string) ([]types.KubeServer, error) { + servers, err := t.CachingAuthClient.GetKubernetesServers(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + var returnServers []types.KubeServer + for _, server := range servers { + if server.GetCluster().GetName() == name { + returnServers = append(returnServers, server) + } + } + if len(returnServers) == 0 { + return nil, trace.NotFound("no kubernetes servers found for cluster %q", name) + } + return returnServers, nil +} diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go index 29964eda0c515..9bd78c249a744 100644 --- a/lib/kube/proxy/utils_testing.go +++ b/lib/kube/proxy/utils_testing.go @@ -172,7 +172,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // heartbeatsWaitChannel waits for clusters heartbeats to start. heartbeatsWaitChannel := make(chan struct{}, len(cfg.Clusters)+1) - + client := newAuthClientWithStreamer(testCtx) // Create kubernetes service server. testCtx.KubeServer, err = NewTLSServer(TLSServerConfig{ ForwarderConfig: ForwarderConfig{ @@ -186,12 +186,12 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo // directly to AuthClient solves the issue. // We wrap the AuthClient with an events.TeeStreamer to send non-disk // events like session.end to testCtx.emitter as well. - AuthClient: newAuthClientWithStreamer(testCtx), + AuthClient: client, // StreamEmitter is required although not used because we are using // "node-sync" as session recording mode. StreamEmitter: testCtx.Emitter, DataDir: t.TempDir(), - CachingAuthClient: testCtx.AuthClient, + CachingAuthClient: client, HostID: testCtx.HostID, Context: testCtx.Context, KubeconfigPath: kubeConfigLocation, @@ -206,7 +206,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo }, DynamicLabels: nil, TLS: tlsConfig, - AccessPoint: testCtx.AuthClient, + AccessPoint: client, LimiterConfig: limiter.Config{ MaxConnections: 1000, MaxNumberOfUsers: 1000, diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go index 44a093bb9950a..04373f33d02be 100644 --- a/lib/kube/proxy/watcher.go +++ b/lib/kube/proxy/watcher.go @@ -82,9 +82,9 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) { return nil } -// startResourceWatcher starts watching changes to Kube Clusters resources and +// startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and // registers/unregisters the proxied Kube Cluster accordingly. -func (s *TLSServer) startResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) { +func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) { if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService { s.log.Debug("Not initializing Kube Cluster resource watcher.") return nil, nil diff --git a/lib/services/kubernetes.go b/lib/services/kubernetes.go index b8571220ec2a9..cd719f6d12d6f 100644 --- a/lib/services/kubernetes.go +++ b/lib/services/kubernetes.go @@ -34,8 +34,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// KubernetesGetter defines interface for fetching kubernetes cluster resources. -type KubernetesGetter interface { +// KubernetesClusterGetter defines interface for fetching kubernetes cluster resources. +type KubernetesClusterGetter interface { // GetKubernetesClusters returns all kubernetes cluster resources. GetKubernetesClusters(context.Context) ([]types.KubeCluster, error) // GetKubernetesCluster returns the specified kubernetes cluster resource. @@ -45,7 +45,7 @@ type KubernetesGetter interface { // Kubernetes defines an interface for managing kubernetes clusters resources. type Kubernetes interface { // KubernetesGetter provides methods for fetching kubernetes resources. - KubernetesGetter + KubernetesClusterGetter // CreateKubernetesCluster creates a new kubernetes cluster resource. CreateKubernetesCluster(context.Context, types.KubeCluster) error // UpdateKubernetesCluster updates an existing kubernetes cluster resource. diff --git a/lib/services/watcher.go b/lib/services/watcher.go index c89b4ff7762eb..015d9dc05a595 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -1016,7 +1016,7 @@ type KubeClusterWatcherConfig struct { // ResourceWatcherConfig is the resource watcher configuration. ResourceWatcherConfig // KubernetesGetter is responsible for fetching kube_cluster resources. - KubernetesGetter + KubernetesClusterGetter // KubeClustersC receives up-to-date list of all kube_cluster resources. KubeClustersC chan types.KubeClusters } @@ -1026,12 +1026,12 @@ func (cfg *KubeClusterWatcherConfig) CheckAndSetDefaults() error { if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } - if cfg.KubernetesGetter == nil { - getter, ok := cfg.Client.(KubernetesGetter) + if cfg.KubernetesClusterGetter == nil { + getter, ok := cfg.Client.(KubernetesClusterGetter) if !ok { return trace.BadParameter("missing parameter KubernetesGetter and Client not usable as KubernetesGetter") } - cfg.KubernetesGetter = getter + cfg.KubernetesClusterGetter = getter } if cfg.KubeClustersC == nil { cfg.KubeClustersC = make(chan types.KubeClusters) @@ -1087,7 +1087,7 @@ func (k *kubeCollector) resourceKind() string { // getResourcesAndUpdateCurrent refreshes the list of current resources. func (k *kubeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error { - clusters, err := k.KubernetesGetter.GetKubernetesClusters(ctx) + clusters, err := k.KubernetesClusterGetter.GetKubernetesClusters(ctx) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 7fd0a11cada84..375414aaf2710 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -18,6 +18,7 @@ package srv import ( "context" + "encoding/json" "fmt" "io" "net" @@ -182,6 +183,48 @@ type Server interface { TargetMetadata() apievents.ServerMetadata } +// childProcessError is used to provide an underlying error +// from a re-executed Teleport child process to its parent. +type childProcessError struct { + Code int `json:"code"` + RawError []byte `json:"rawError"` +} + +// writeChildError encodes the provided error +// as json and writes it to w. Special care +// is taken to preserve the error type by +// including the error code and raw message +// so that [DecodeChildError] will return +// the matching error type and message. +func writeChildError(w io.Writer, err error) { + if w == nil || err == nil { + return + } + + data, jerr := json.Marshal(err) + if jerr != nil { + return + } + + _ = json.NewEncoder(w).Encode(childProcessError{ + Code: trace.ErrorToCode(err), + RawError: data, + }) + +} + +// DecodeChildError consumes the output from a child +// process decoding it from its raw form back into +// a concrete error. +func DecodeChildError(r io.Reader) error { + var c childProcessError + if err := json.NewDecoder(r).Decode(&c); err != nil { + return nil + } + + return trace.ReadError(c.Code, c.RawError) +} + // IdentityContext holds all identity information associated with the user // logged on the connection. type IdentityContext struct { @@ -374,6 +417,12 @@ type ServerContext struct { x11rdyr *os.File x11rdyw *os.File + // err{r,w} is used to propagate errors from the child process to the + // parent process so the parent can get more information about why the child + // process failed and act accordingly. + errr *os.File + errw *os.File + // x11Config holds the xauth and XServer listener config for this session. x11Config *X11Config @@ -523,6 +572,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s child.AddCloser(child.x11rdyr) child.AddCloser(child.x11rdyw) + // Create pipe used to get errors from the child process. + child.errr, child.errw, err = os.Pipe() + if err != nil { + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) + } + child.AddCloser(child.errr) + child.AddCloser(child.errw) + return ctx, child, nil } @@ -833,6 +891,11 @@ func (c *ServerContext) x11Ready() (bool, error) { return true, nil } +// GetChildError returns the error from the child process +func (c *ServerContext) GetChildError() error { + return DecodeChildError(c.errr) +} + // takeClosers returns all resources that should be closed and sets the properties to null // we do this to avoid calling Close() under lock to avoid potential deadlocks func (c *ServerContext) takeClosers() []io.Closer { diff --git a/lib/srv/ctx_test.go b/lib/srv/ctx_test.go index 809fa5d2a439f..0143dd3cf952b 100644 --- a/lib/srv/ctx_test.go +++ b/lib/srv/ctx_test.go @@ -17,10 +17,13 @@ limitations under the License. package srv import ( + "bytes" + "os/user" "testing" "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/protobuf/testing/protocmp" @@ -31,6 +34,19 @@ import ( "github.com/gravitational/teleport/lib/services" ) +// TestDecodeChildError ensures that child error message marshaling +// and unmarshaling returns the original values. +func TestDecodeChildError(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, DecodeChildError(&buf)) + + targetErr := trace.NotFound(user.UnknownUserError("test").Error()) + + writeChildError(&buf, targetErr) + + require.ErrorIs(t, DecodeChildError(&buf), targetErr) +} + func TestCheckSFTPAllowed(t *testing.T) { srv := newMockServer(t) ctx := newTestServerContext(t, srv, nil) diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index 205871addf9fe..8488cde31d81b 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -68,6 +68,9 @@ const ( // X11File is used to communicate to the parent process that the child // process has set up X11 forwarding. X11File + // ErrorFile is used to communicate any errors terminating the child process + // to the parent process + ErrorFile // PTYFile is a PTY the parent process passes to the child process. PTYFile // TTYFile is a TTY the parent process passes to the child process. @@ -75,9 +78,13 @@ const ( // FirstExtraFile is the first file descriptor that will be valid when // extra files are passed to child processes without a terminal. - FirstExtraFile = X11File + 1 + FirstExtraFile FileFD = ErrorFile + 1 ) +func fdName(f FileFD) string { + return fmt.Sprintf("/proc/self/fd/%d", f) +} + // ExecCommand contains the payload to "teleport exec" which will be used to // construct and execute a shell. type ExecCommand struct { @@ -191,29 +198,23 @@ func RunCommand() (errw io.Writer, code int, err error) { errorWriter := os.Stdout // Parent sends the command payload in the third file descriptor. - cmdfd := os.NewFile(CommandFile, fmt.Sprintf("/proc/self/fd/%d", CommandFile)) + cmdfd := os.NewFile(CommandFile, fdName(CommandFile)) if cmdfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("command pipe not found") } - contfd := os.NewFile(ContinueFile, fmt.Sprintf("/proc/self/fd/%d", ContinueFile)) + contfd := os.NewFile(ContinueFile, fdName(ContinueFile)) if contfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found") } - termiantefd := os.NewFile(TerminateFile, fmt.Sprintf("/proc/self/fd/%d", TerminateFile)) + termiantefd := os.NewFile(TerminateFile, fdName(TerminateFile)) if termiantefd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("terminate pipe not found") } // Read in the command payload. - var b bytes.Buffer - _, err = b.ReadFrom(cmdfd) - if err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) - } var c ExecCommand - err = json.Unmarshal(b.Bytes(), &c) - if err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) + if err := json.NewDecoder(cmdfd).Decode(&c); err != nil { + return io.Discard, teleport.RemoteCommandFailure, trace.Wrap(err) } auditdMsg := auditd.Message{ @@ -251,8 +252,8 @@ func RunCommand() (errw io.Writer, code int, err error) { // PTY and TTY. Extract them and set the controlling TTY. Otherwise, connect // std{in,out,err} directly. if c.Terminal { - pty = os.NewFile(PTYFile, fmt.Sprintf("/proc/self/fd/%d", PTYFile)) - tty = os.NewFile(TTYFile, fmt.Sprintf("/proc/self/fd/%d", TTYFile)) + pty = os.NewFile(PTYFile, fdName(PTYFile)) + tty = os.NewFile(TTYFile, fdName(TTYFile)) if pty == nil || tty == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("pty and tty not found") } @@ -391,7 +392,7 @@ func RunCommand() (errw io.Writer, code int, err error) { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", x11.DisplayEnv, c.X11Config.XAuthEntry.Display.String())) // Open x11rdy fd to signal parent process once X11 forwarding is set up. - x11rdyfd := os.NewFile(X11File, fmt.Sprintf("/proc/self/fd/%d", X11File)) + x11rdyfd := os.NewFile(X11File, fdName(X11File)) if x11rdyfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found") } @@ -568,20 +569,24 @@ func RunForward() (errw io.Writer, code int, err error) { errorWriter := os.Stderr // Parent sends the command payload in the third file descriptor. - cmdfd := os.NewFile(CommandFile, fmt.Sprintf("/proc/self/fd/%d", CommandFile)) + cmdfd := os.NewFile(CommandFile, fdName(CommandFile)) if cmdfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("command pipe not found") } - // Read in the command payload. - var b bytes.Buffer - _, err = b.ReadFrom(cmdfd) - if err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) + // Parent receives any errors on the sixth file descriptor. + errfd := os.NewFile(ErrorFile, fdName(ErrorFile)) + if errfd == nil { + return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("error pipe not found") } + + defer func() { + writeChildError(errfd, err) + }() + + // Read in the command payload. var c ExecCommand - err = json.Unmarshal(b.Bytes(), &c) - if err != nil { + if err := json.NewDecoder(cmdfd).Decode(&c); err != nil { return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) } @@ -607,6 +612,10 @@ func RunForward() (errw io.Writer, code int, err error) { defer pamContext.Close() } + if _, err := user.Lookup(c.Login); err != nil { + return errorWriter, teleport.RemoteCommandFailure, trace.NotFound(err.Error()) + } + // Connect to the target host. conn, err := net.Dial("tcp", c.DestinationAddress) if err != nil { @@ -614,33 +623,12 @@ func RunForward() (errw io.Writer, code int, err error) { } defer conn.Close() - // Start copy routines that copy from channel to stdin pipe and from stdout - // pipe to channel. - errorCh := make(chan error, 2) - go func() { - defer conn.Close() - defer os.Stdout.Close() - defer os.Stdin.Close() - - _, err := io.Copy(os.Stdout, conn) - errorCh <- err - }() - go func() { - defer conn.Close() - defer os.Stdout.Close() - defer os.Stdin.Close() - - _, err := io.Copy(conn, os.Stdin) - errorCh <- err - }() - - // Block until copy is complete in either direction. The other direction - // will get cleaned up automatically. - if err = <-errorCh; err != nil && err != io.EOF { + err = utils.ProxyConn(context.Background(), utils.CombineReadWriteCloser(os.Stdin, os.Stdout), conn) + if err != nil && !errors.Is(err, io.EOF) { return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) } - return io.Discard, teleport.RemoteCommandSuccess, nil + return errorWriter, teleport.RemoteCommandSuccess, nil } // runCheckHomeDir check's if the active user's $HOME dir exists. @@ -877,11 +865,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er cmdmsg.ExtraFilesLen = len(extraFiles) } - cmdbytes, err := json.Marshal(cmdmsg) - if err != nil { - return nil, trace.Wrap(err) - } - go copyCommand(ctx, cmdbytes) + go copyCommand(ctx, cmdmsg) // Find the Teleport executable and its directory on disk. executable, err := os.Executable() @@ -911,6 +895,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er ctx.contr, ctx.killShellr, ctx.x11rdyw, + ctx.errw, }, } // Add extra files if applicable. @@ -926,7 +911,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er // copyCommand will copy the provided command to the child process over the // pipe attached to the context. -func copyCommand(ctx *ServerContext, cmdbytes []byte) { +func copyCommand(ctx *ServerContext, cmdmsg *ExecCommand) { defer func() { err := ctx.cmdw.Close() if err != nil { @@ -939,8 +924,7 @@ func copyCommand(ctx *ServerContext, cmdbytes []byte) { // Write command bytes to pipe. The child process will read the command // to execute from this pipe. - _, err := io.Copy(ctx.cmdw, bytes.NewReader(cmdbytes)) - if err != nil { + if err := json.NewEncoder(ctx.cmdw).Encode(cmdmsg); err != nil { log.Errorf("Failed to copy command over pipe: %v.", err) return } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index d16568b8b5ca2..cd7db4264d2a5 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -21,6 +21,7 @@ package regular import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -1345,8 +1346,8 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con defer scx.Debugf("Closing direct-tcpip channel from %v to %v.", scx.SrcAddr, scx.DstAddr) // Create command to re-exec Teleport which will perform a net.Dial. The - // reason it's not done directly is because the PAM stack needs to be called - // from another process. + // reason it's not done directly because the PAM stack needs to be called + // from the child process. cmd, err := srv.ConfigureCommand(scx) if err != nil { writeStderr(channel, err.Error()) @@ -1378,63 +1379,48 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con return } - // Start copy routines that copy from channel to stdin pipe and from stdout - // pipe to channel. - errorCh := make(chan error, 2) - go func() { - defer channel.Close() - defer pw.Close() - defer pr.Close() - - _, err := io.Copy(pw, channel) - errorCh <- err - }() - go func() { - defer channel.Close() - defer pw.Close() - defer pr.Close() - - _, err := io.Copy(channel, pr) - errorCh <- err - }() - - // Block until copy is complete and the child process is done executing. -Loop: - for i := 0; i < 2; i++ { - select { - case err := <-errorCh: - if err != nil && err != io.EOF { - s.Logger.Warnf("Connection problem in \"direct-tcpip\" channel: %v %T.", trace.DebugReport(err), err) - } - case <-ctx.Done(): - break Loop - case <-s.ctx.Done(): - break Loop + if err := utils.ProxyConn(ctx, utils.CombineReadWriteCloser(pr, pw), channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { + s.Logger.Warnf("Connection problem in direct-tcpip channel: %v %T.", trace.DebugReport(err), err) + } + + // Emit a port forwarding event if the command exited successfully. + if err := cmd.Wait(); err == nil { + if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ + Metadata: apievents.Metadata{ + Type: events.PortForwardEvent, + Code: events.PortForwardCode, + }, + UserMetadata: scx.Identity.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + LocalAddr: scx.ServerConn.LocalAddr().String(), + RemoteAddr: scx.ServerConn.RemoteAddr().String(), + }, + Addr: scx.DstAddr, + Status: apievents.Status{ + Success: true, + }, + }); err != nil { + s.Logger.WithError(err).Warn("Failed to emit port forward event.") } - } - err = cmd.Wait() - if err != nil { - writeStderr(channel, err.Error()) return } - // Emit a port forwarding event. - if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ - Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, - Code: events.PortForwardCode, - }, - UserMetadata: scx.Identity.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - LocalAddr: scx.ServerConn.LocalAddr().String(), - RemoteAddr: scx.ServerConn.RemoteAddr().String(), - }, - Addr: scx.DstAddr, - Status: apievents.Status{ - Success: true, - }, - }); err != nil { - s.Logger.WithError(err).Warn("Failed to emit port forward event.") + // Get the error to see why the child process failed and + // determine the correct course of action. + err = scx.GetChildError() + switch { + case err == nil: + s.Logger.Warn("Forwarding data via direct-tcpip channel failed for unknown reason") + return + // The user does not exist for the provided login. Terminate the connection. + case errors.Is(err, trace.NotFound(user.UnknownUserError(scx.Identity.Login).Error())), + errors.Is(err, trace.BadParameter("unknown user")): + s.Logger.Warnf("Forwarding data via direct-tcpip channel failed. Terminating connection because user %q does not exist", scx.Identity.Login) + if err := ccx.ServerConn.Close(); err != nil { + s.Logger.Warnf("Unable to terminate connection: %v", err) + } + default: + s.Logger.WithError(err).Error("Forwarding data via direct-tcpip channel failed") } } diff --git a/lib/utils/proxyconn.go b/lib/utils/proxyconn.go index 3856f9d998ce7..2493ad222edef 100644 --- a/lib/utils/proxyconn.go +++ b/lib/utils/proxyconn.go @@ -23,6 +23,36 @@ import ( "github.com/gravitational/trace" ) +// CombinedReadWriteCloser wraps an [io.ReadCloser] and an [io.WriteCloser] to +// implement [io.ReadWriteCloser]. Reads are performed on the [io.ReadCloser] and +// writes are performed on the [io.WriteCloser]. Closing will return the +// aggregated errors of both. +type CombinedReadWriteCloser struct { + r io.ReadCloser + w io.WriteCloser +} + +func (o CombinedReadWriteCloser) Read(p []byte) (int, error) { + return o.r.Read(p) +} + +func (o CombinedReadWriteCloser) Write(p []byte) (int, error) { + return o.w.Write(p) +} + +func (o CombinedReadWriteCloser) Close() error { + return trace.NewAggregate(o.r.Close(), o.w.Close()) +} + +// CombineReadWriteCloser creates a CombinedReadWriteCloser from the provided +// [io.ReadCloser] and [io.WriteCloser] that implements [io.ReadWriteCloser] +func CombineReadWriteCloser(r io.ReadCloser, w io.WriteCloser) CombinedReadWriteCloser { + return CombinedReadWriteCloser{ + r: r, + w: w, + } +} + // ProxyConn launches a double-copy loop that proxies traffic between the // provided client and server connections. //