From 93aa22b202105429cbb9edb6e87e02cc7aa26a9e Mon Sep 17 00:00:00 2001
From: Tim Ross
Date: Thu, 16 Mar 2023 14:46:37 -0400
Subject: [PATCH 1/2] Prevent tunneling if the os login doesn't exist

A user.Lookup was added to srv.RunForward to prevent dialing and
forwarding any data if the os login is not found. The check alone only
terminates the direct-tcpip ssh channel and not the underlying ssh
connection. In order for the parent process to determine whether the
ssh connection should be terminated, it needs to know why the child
exited. That was not possible from the exit code alone, and any data
written to standard error of the child process was forwarded to
standard error on the parent, which simply logged the error and moved
on.

To pass more detailed errors to the parent, the child process spawned
by srv.RunForward now JSON-marshals the trace.Error to a dedicated
error pipe, which the parent process then decodes. If the parent
detects that the error was due to a missing user, it terminates the
ssh connection.

tsh ssh -N was also modified to terminate if either the command context
of tsh OR the ssh connection to the node is closed. Previously, it
blocked on ctx.Done() and only terminated when the user cancelled the
process. While this was necessary to end the session if the os login
does not exist, it also forces tsh to exit if the node goes offline.

Note: This does not include any propagation of error messages to the
user, so there won't be any indication from tsh about why the
connection was closed. Due to the way -N is implemented, the session
also will not be terminated until the first attempt to forward data
and NOT when the session is created.

Fixes #217
---
 integration/port_forwarding_test.go | 38 +++++++++++-
 lib/client/api.go                   | 28 ++++++---
 lib/client/client.go                | 82 +++++++++--------------
 lib/srv/ctx.go                      | 63 +++++++++++++++++++
 lib/srv/ctx_test.go                 | 16 +++++
 lib/srv/reexec.go                   | 94 ++++++++++++----------------
 lib/srv/regular/sshserver.go        | 96 ++++++++++++-----------------
 lib/utils/proxyconn.go              | 30 +++++++++
 8 files changed, 274 insertions(+), 173 deletions(-)

diff --git a/integration/port_forwarding_test.go b/integration/port_forwarding_test.go
index e797e9bc700eb..50907b8ab9025 100644
--- a/integration/port_forwarding_test.go
+++ b/integration/port_forwarding_test.go
@@ -22,10 +22,12 @@ import (
 	"net/http"
 	"net/http/httptest"
 	"net/url"
+	"os/user"
 	"strconv"
 	"testing"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/gravitational/trace"
 	"github.com/stretchr/testify/require"
 
@@ -71,19 +73,49 @@ func waitForSessionToBeEstablished(ctx context.Context, namespace string, site a
 }
 
 func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
+	invalidOSLogin := uuid.NewString()[:12]
+	notFound := false
+	for i := 0; i < 10; i++ {
+		if _, err := user.Lookup(invalidOSLogin); err == nil {
+			invalidOSLogin = uuid.NewString()[:12]
+			continue
+		}
+		notFound = true
+		break
+	}
+	require.True(t, notFound, "unable to locate invalid os user")
+
+	// Providing our own logins to Teleport so we can verify that a user
+	// that exists within Teleport but does not exist on the local node
+	// cannot port forward.
+	logins := []string{
+		invalidOSLogin,
+		suite.Me.Username,
+	}
+
 	testCases := []struct {
 		desc                  string
 		portForwardingAllowed bool
 		expectSuccess         bool
+		login                 string
 	}{
 		{
 			desc:                  "Enabled",
 			portForwardingAllowed: true,
 			expectSuccess:         true,
-		}, {
+			login:                 suite.Me.Username,
+		},
+		{
 			desc:                  "Disabled",
 			portForwardingAllowed: false,
 			expectSuccess:         false,
+			login:                 suite.Me.Username,
+		},
+		{
+			desc:                  "Enabled with invalid user",
+			portForwardingAllowed: true,
+			expectSuccess:         false,
+			login:                 invalidOSLogin,
 		},
 	}
 
@@ -106,7 +138,7 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
 			cfg.SSH.Enabled = true
 			cfg.SSH.AllowTCPForwarding = tt.portForwardingAllowed
 
-			teleport := suite.NewTeleportWithConfig(t, nil, nil, cfg)
+			teleport := suite.NewTeleportWithConfig(t, logins, nil, cfg)
 			defer teleport.StopAll()
 
 			site := teleport.GetSiteAPI(helpers.Site)
@@ -127,7 +159,7 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) {
 			nodeSSHPort := helpers.Port(t, teleport.SSH)
 			cl, err := teleport.NewClient(helpers.ClientConfig{
-				Login:   suite.Me.Username,
+				Login:   tt.login,
 				Cluster: helpers.Site,
 				Host:    Host,
 				Port:    nodeSSHPort,
diff --git a/lib/client/api.go b/lib/client/api.go
index a4af7ff464b1e..c93d0809e5ac0 100644
--- a/lib/client/api.go
+++ b/lib/client/api.go
@@ -22,6 +22,7 @@ import (
 	"crypto/x509"
 	"encoding/json"
 	"encoding/pem"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -1587,16 +1588,27 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, nod
 		return trace.Wrap(err)
 	}
 
-	// If no remote command execution was requested, block on the context which
-	// will unblock upon error or SIGINT.
+	// If no remote command execution was requested, block on whichever comes first:
+	// 1) the context, which will unblock upon error or the user terminating the process
+	// 2) ssh.Client.Wait, which will unblock when the connection has shut down
 	if tc.NoRemoteExec {
-		log.Debugf("Connected to node, no remote command execution was requested, blocking until context closes.")
-		<-ctx.Done()
-
-		// Only return an error if the context was canceled by something other than SIGINT.
-		if ctx.Err() != context.Canceled {
-			return ctx.Err()
+		connClosed := make(chan error, 1)
+		go func() {
+			connClosed <- nodeClient.Client.Wait()
+		}()
+		log.Debugf("Connected to node, no remote command execution was requested, blocking indefinitely.")
+		select {
+		case <-ctx.Done():
+			// Only return an error if the context was canceled by something other than SIGINT.
+ if err := ctx.Err(); !errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } + case err := <-connClosed: + if !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } } + return nil } diff --git a/lib/client/client.go b/lib/client/client.go index 9914b57fd7cf4..b5693ccbac160 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -48,6 +48,7 @@ import ( tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -1840,75 +1841,52 @@ func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error } type netDialer interface { - Dial(string, string) (net.Conn, error) + DialContext(context.Context, string, string) (net.Conn, error) } func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dialer netDialer) error { defer conn.Close() defer log.Debugf("Finished proxy from %v to %v.", conn.RemoteAddr(), remoteAddr) - var ( - remoteConn net.Conn - err error - ) - + var remoteConn net.Conn log.Debugf("Attempting to connect proxy from %v to %v.", conn.RemoteAddr(), remoteAddr) - for attempt := 1; attempt <= 5; attempt++ { - remoteConn, err = dialer.Dial("tcp", remoteAddr) - if err != nil { - log.Debugf("Proxy connection attempt %v: %v.", attempt, err) - - timer := time.NewTimer(time.Duration(100*attempt) * time.Millisecond) - defer timer.Stop() - - // Wait and attempt to connect again, if the context has closed, exit - // right away. - select { - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) - case <-timer.C: - continue - } - } - // Connection established, break out of the loop. - break - } + + retry, err := retryutils.NewLinear(retryutils.LinearConfig{ + First: 100 * time.Millisecond, + Step: 100 * time.Millisecond, + Max: time.Second, + Jitter: retryutils.NewHalfJitter(), + }) if err != nil { - return trace.BadParameter("failed to connect to node: %v", remoteAddr) + return trace.Wrap(err) } - defer remoteConn.Close() - - // Start proxying, close the connection if a problem occurs on either leg. - errCh := make(chan error, 2) - go func() { - defer conn.Close() - defer remoteConn.Close() - _, err := io.Copy(conn, remoteConn) - errCh <- err - }() - go func() { - defer conn.Close() - defer remoteConn.Close() - - _, err := io.Copy(remoteConn, conn) - errCh <- err - }() + for attempt := 1; attempt <= 5; attempt++ { + conn, err := dialer.DialContext(ctx, "tcp", remoteAddr) + if err == nil { + // Connection established, break out of the loop. + remoteConn = conn + break + } - var errs []error - for i := 0; i < 2; i++ { + log.Debugf("Proxy connection attempt %v: %v.", attempt, err) + // Wait and attempt to connect again, if the context has closed, exit + // right away. select { - case err := <-errCh: - if err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { - log.Warnf("Failed to proxy connection: %v.", err) - errs = append(errs, err) - } case <-ctx.Done(): return trace.Wrap(ctx.Err()) + case <-retry.After(): + retry.Inc() + continue } } + if remoteConn == nil { + return trace.BadParameter("failed to connect to node: %v", remoteAddr) + } + defer remoteConn.Close() - return trace.NewAggregate(errs...) + // Start proxying, close the connection if a problem occurs on either leg. 
+ return trace.Wrap(utils.ProxyConn(ctx, remoteConn, conn)) } // acceptWithContext calls "Accept" on the listener but will unblock when the diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 7fd0a11cada84..375414aaf2710 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -18,6 +18,7 @@ package srv import ( "context" + "encoding/json" "fmt" "io" "net" @@ -182,6 +183,48 @@ type Server interface { TargetMetadata() apievents.ServerMetadata } +// childProcessError is used to provide an underlying error +// from a re-executed Teleport child process to its parent. +type childProcessError struct { + Code int `json:"code"` + RawError []byte `json:"rawError"` +} + +// writeChildError encodes the provided error +// as json and writes it to w. Special care +// is taken to preserve the error type by +// including the error code and raw message +// so that [DecodeChildError] will return +// the matching error type and message. +func writeChildError(w io.Writer, err error) { + if w == nil || err == nil { + return + } + + data, jerr := json.Marshal(err) + if jerr != nil { + return + } + + _ = json.NewEncoder(w).Encode(childProcessError{ + Code: trace.ErrorToCode(err), + RawError: data, + }) + +} + +// DecodeChildError consumes the output from a child +// process decoding it from its raw form back into +// a concrete error. +func DecodeChildError(r io.Reader) error { + var c childProcessError + if err := json.NewDecoder(r).Decode(&c); err != nil { + return nil + } + + return trace.ReadError(c.Code, c.RawError) +} + // IdentityContext holds all identity information associated with the user // logged on the connection. type IdentityContext struct { @@ -374,6 +417,12 @@ type ServerContext struct { x11rdyr *os.File x11rdyw *os.File + // err{r,w} is used to propagate errors from the child process to the + // parent process so the parent can get more information about why the child + // process failed and act accordingly. + errr *os.File + errw *os.File + // x11Config holds the xauth and XServer listener config for this session. x11Config *X11Config @@ -523,6 +572,15 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s child.AddCloser(child.x11rdyr) child.AddCloser(child.x11rdyw) + // Create pipe used to get errors from the child process. + child.errr, child.errw, err = os.Pipe() + if err != nil { + childErr := child.Close() + return nil, nil, trace.NewAggregate(err, childErr) + } + child.AddCloser(child.errr) + child.AddCloser(child.errw) + return ctx, child, nil } @@ -833,6 +891,11 @@ func (c *ServerContext) x11Ready() (bool, error) { return true, nil } +// GetChildError returns the error from the child process +func (c *ServerContext) GetChildError() error { + return DecodeChildError(c.errr) +} + // takeClosers returns all resources that should be closed and sets the properties to null // we do this to avoid calling Close() under lock to avoid potential deadlocks func (c *ServerContext) takeClosers() []io.Closer { diff --git a/lib/srv/ctx_test.go b/lib/srv/ctx_test.go index 809fa5d2a439f..0143dd3cf952b 100644 --- a/lib/srv/ctx_test.go +++ b/lib/srv/ctx_test.go @@ -17,10 +17,13 @@ limitations under the License. 
package srv import ( + "bytes" + "os/user" "testing" "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/protobuf/testing/protocmp" @@ -31,6 +34,19 @@ import ( "github.com/gravitational/teleport/lib/services" ) +// TestDecodeChildError ensures that child error message marshaling +// and unmarshaling returns the original values. +func TestDecodeChildError(t *testing.T) { + var buf bytes.Buffer + require.NoError(t, DecodeChildError(&buf)) + + targetErr := trace.NotFound(user.UnknownUserError("test").Error()) + + writeChildError(&buf, targetErr) + + require.ErrorIs(t, DecodeChildError(&buf), targetErr) +} + func TestCheckSFTPAllowed(t *testing.T) { srv := newMockServer(t) ctx := newTestServerContext(t, srv, nil) diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index 205871addf9fe..8488cde31d81b 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -68,6 +68,9 @@ const ( // X11File is used to communicate to the parent process that the child // process has set up X11 forwarding. X11File + // ErrorFile is used to communicate any errors terminating the child process + // to the parent process + ErrorFile // PTYFile is a PTY the parent process passes to the child process. PTYFile // TTYFile is a TTY the parent process passes to the child process. @@ -75,9 +78,13 @@ const ( // FirstExtraFile is the first file descriptor that will be valid when // extra files are passed to child processes without a terminal. - FirstExtraFile = X11File + 1 + FirstExtraFile FileFD = ErrorFile + 1 ) +func fdName(f FileFD) string { + return fmt.Sprintf("/proc/self/fd/%d", f) +} + // ExecCommand contains the payload to "teleport exec" which will be used to // construct and execute a shell. type ExecCommand struct { @@ -191,29 +198,23 @@ func RunCommand() (errw io.Writer, code int, err error) { errorWriter := os.Stdout // Parent sends the command payload in the third file descriptor. - cmdfd := os.NewFile(CommandFile, fmt.Sprintf("/proc/self/fd/%d", CommandFile)) + cmdfd := os.NewFile(CommandFile, fdName(CommandFile)) if cmdfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("command pipe not found") } - contfd := os.NewFile(ContinueFile, fmt.Sprintf("/proc/self/fd/%d", ContinueFile)) + contfd := os.NewFile(ContinueFile, fdName(ContinueFile)) if contfd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found") } - termiantefd := os.NewFile(TerminateFile, fmt.Sprintf("/proc/self/fd/%d", TerminateFile)) + termiantefd := os.NewFile(TerminateFile, fdName(TerminateFile)) if termiantefd == nil { return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("terminate pipe not found") } // Read in the command payload. - var b bytes.Buffer - _, err = b.ReadFrom(cmdfd) - if err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) - } var c ExecCommand - err = json.Unmarshal(b.Bytes(), &c) - if err != nil { - return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err) + if err := json.NewDecoder(cmdfd).Decode(&c); err != nil { + return io.Discard, teleport.RemoteCommandFailure, trace.Wrap(err) } auditdMsg := auditd.Message{ @@ -251,8 +252,8 @@ func RunCommand() (errw io.Writer, code int, err error) { // PTY and TTY. Extract them and set the controlling TTY. Otherwise, connect // std{in,out,err} directly. 
if c.Terminal {
-		pty = os.NewFile(PTYFile, fmt.Sprintf("/proc/self/fd/%d", PTYFile))
-		tty = os.NewFile(TTYFile, fmt.Sprintf("/proc/self/fd/%d", TTYFile))
+		pty = os.NewFile(PTYFile, fdName(PTYFile))
+		tty = os.NewFile(TTYFile, fdName(TTYFile))
 		if pty == nil || tty == nil {
 			return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("pty and tty not found")
 		}
@@ -391,7 +392,7 @@ func RunCommand() (errw io.Writer, code int, err error) {
 		cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", x11.DisplayEnv, c.X11Config.XAuthEntry.Display.String()))
 
 		// Open x11rdy fd to signal parent process once X11 forwarding is set up.
-		x11rdyfd := os.NewFile(X11File, fmt.Sprintf("/proc/self/fd/%d", X11File))
+		x11rdyfd := os.NewFile(X11File, fdName(X11File))
 		if x11rdyfd == nil {
 			return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("continue pipe not found")
 		}
@@ -568,20 +569,24 @@ func RunForward() (errw io.Writer, code int, err error) {
 	errorWriter := os.Stderr
 
 	// Parent sends the command payload in the third file descriptor.
-	cmdfd := os.NewFile(CommandFile, fmt.Sprintf("/proc/self/fd/%d", CommandFile))
+	cmdfd := os.NewFile(CommandFile, fdName(CommandFile))
 	if cmdfd == nil {
 		return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("command pipe not found")
 	}
 
-	// Read in the command payload.
-	var b bytes.Buffer
-	_, err = b.ReadFrom(cmdfd)
-	if err != nil {
-		return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err)
+	// Parent receives any errors on the sixth file descriptor.
+	errfd := os.NewFile(ErrorFile, fdName(ErrorFile))
+	if errfd == nil {
+		return errorWriter, teleport.RemoteCommandFailure, trace.BadParameter("error pipe not found")
 	}
+
+	defer func() {
+		writeChildError(errfd, err)
+	}()
+
+	// Read in the command payload.
 	var c ExecCommand
-	err = json.Unmarshal(b.Bytes(), &c)
-	if err != nil {
+	if err := json.NewDecoder(cmdfd).Decode(&c); err != nil {
 		return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err)
 	}
 
@@ -607,6 +612,10 @@ func RunForward() (errw io.Writer, code int, err error) {
 		defer pamContext.Close()
 	}
 
+	if _, err := user.Lookup(c.Login); err != nil {
+		return errorWriter, teleport.RemoteCommandFailure, trace.NotFound(err.Error())
+	}
+
 	// Connect to the target host.
 	conn, err := net.Dial("tcp", c.DestinationAddress)
 	if err != nil {
@@ -614,33 +623,12 @@ func RunForward() (errw io.Writer, code int, err error) {
 	}
 	defer conn.Close()
 
-	// Start copy routines that copy from channel to stdin pipe and from stdout
-	// pipe to channel.
-	errorCh := make(chan error, 2)
-	go func() {
-		defer conn.Close()
-		defer os.Stdout.Close()
-		defer os.Stdin.Close()
-
-		_, err := io.Copy(os.Stdout, conn)
-		errorCh <- err
-	}()
-	go func() {
-		defer conn.Close()
-		defer os.Stdout.Close()
-		defer os.Stdin.Close()
-
-		_, err := io.Copy(conn, os.Stdin)
-		errorCh <- err
-	}()
-
-	// Block until copy is complete in either direction. The other direction
-	// will get cleaned up automatically.
-	if err = <-errorCh; err != nil && err != io.EOF {
+	err = utils.ProxyConn(context.Background(), utils.CombineReadWriteCloser(os.Stdin, os.Stdout), conn)
+	if err != nil && !errors.Is(err, io.EOF) {
 		return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err)
 	}
 
-	return io.Discard, teleport.RemoteCommandSuccess, nil
+	return errorWriter, teleport.RemoteCommandSuccess, nil
 }
 
-// runCheckHomeDir check's if the active user's $HOME dir exists.
+// runCheckHomeDir checks if the active user's $HOME dir exists.
@@ -877,11 +865,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er
 		cmdmsg.ExtraFilesLen = len(extraFiles)
 	}
 
-	cmdbytes, err := json.Marshal(cmdmsg)
-	if err != nil {
-		return nil, trace.Wrap(err)
-	}
-	go copyCommand(ctx, cmdbytes)
+	go copyCommand(ctx, cmdmsg)
 
 	// Find the Teleport executable and its directory on disk.
 	executable, err := os.Executable()
@@ -911,6 +895,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er
 			ctx.contr,
 			ctx.killShellr,
 			ctx.x11rdyw,
+			ctx.errw,
 		},
 	}
 	// Add extra files if applicable.
@@ -926,7 +911,7 @@ func ConfigureCommand(ctx *ServerContext, extraFiles ...*os.File) (*exec.Cmd, er
 
 // copyCommand will copy the provided command to the child process over the
 // pipe attached to the context.
-func copyCommand(ctx *ServerContext, cmdbytes []byte) {
+func copyCommand(ctx *ServerContext, cmdmsg *ExecCommand) {
 	defer func() {
 		err := ctx.cmdw.Close()
 		if err != nil {
@@ -939,8 +924,7 @@ func copyCommand(ctx *ServerContext, cmdmsg *ExecCommand) {
 
 	// Write command bytes to pipe. The child process will read the command
 	// to execute from this pipe.
-	_, err := io.Copy(ctx.cmdw, bytes.NewReader(cmdbytes))
-	if err != nil {
+	if err := json.NewEncoder(ctx.cmdw).Encode(cmdmsg); err != nil {
 		log.Errorf("Failed to copy command over pipe: %v.", err)
 		return
 	}
diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go
index d16568b8b5ca2..cd7db4264d2a5 100644
--- a/lib/srv/regular/sshserver.go
+++ b/lib/srv/regular/sshserver.go
@@ -21,6 +21,7 @@ package regular
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -1345,8 +1346,8 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
 	defer scx.Debugf("Closing direct-tcpip channel from %v to %v.", scx.SrcAddr, scx.DstAddr)
 
 	// Create command to re-exec Teleport which will perform a net.Dial. The
-	// reason it's not done directly is because the PAM stack needs to be called
-	// from another process.
+	// reason it's not done directly is that the PAM stack needs to be called
+	// from the child process.
 	cmd, err := srv.ConfigureCommand(scx)
 	if err != nil {
 		writeStderr(channel, err.Error())
@@ -1378,63 +1379,48 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
 		return
 	}
 
-	// Start copy routines that copy from channel to stdin pipe and from stdout
-	// pipe to channel.
-	errorCh := make(chan error, 2)
-	go func() {
-		defer channel.Close()
-		defer pw.Close()
-		defer pr.Close()
-
-		_, err := io.Copy(pw, channel)
-		errorCh <- err
-	}()
-	go func() {
-		defer channel.Close()
-		defer pw.Close()
-		defer pr.Close()
-
-		_, err := io.Copy(channel, pr)
-		errorCh <- err
-	}()
-
-	// Block until copy is complete and the child process is done executing.
-Loop:
-	for i := 0; i < 2; i++ {
-		select {
-		case err := <-errorCh:
-			if err != nil && err != io.EOF {
-				s.Logger.Warnf("Connection problem in \"direct-tcpip\" channel: %v %T.", trace.DebugReport(err), err)
-			}
-		case <-ctx.Done():
-			break Loop
-		case <-s.ctx.Done():
-			break Loop
+	if err := utils.ProxyConn(ctx, utils.CombineReadWriteCloser(pr, pw), channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
+		s.Logger.Warnf("Connection problem in direct-tcpip channel: %v %T.", trace.DebugReport(err), err)
+	}
+
+	// Emit a port forwarding event if the command exited successfully.
+ if err := cmd.Wait(); err == nil { + if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ + Metadata: apievents.Metadata{ + Type: events.PortForwardEvent, + Code: events.PortForwardCode, + }, + UserMetadata: scx.Identity.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + LocalAddr: scx.ServerConn.LocalAddr().String(), + RemoteAddr: scx.ServerConn.RemoteAddr().String(), + }, + Addr: scx.DstAddr, + Status: apievents.Status{ + Success: true, + }, + }); err != nil { + s.Logger.WithError(err).Warn("Failed to emit port forward event.") } - } - err = cmd.Wait() - if err != nil { - writeStderr(channel, err.Error()) return } - // Emit a port forwarding event. - if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ - Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, - Code: events.PortForwardCode, - }, - UserMetadata: scx.Identity.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - LocalAddr: scx.ServerConn.LocalAddr().String(), - RemoteAddr: scx.ServerConn.RemoteAddr().String(), - }, - Addr: scx.DstAddr, - Status: apievents.Status{ - Success: true, - }, - }); err != nil { - s.Logger.WithError(err).Warn("Failed to emit port forward event.") + // Get the error to see why the child process failed and + // determine the correct course of action. + err = scx.GetChildError() + switch { + case err == nil: + s.Logger.Warn("Forwarding data via direct-tcpip channel failed for unknown reason") + return + // The user does not exist for the provided login. Terminate the connection. + case errors.Is(err, trace.NotFound(user.UnknownUserError(scx.Identity.Login).Error())), + errors.Is(err, trace.BadParameter("unknown user")): + s.Logger.Warnf("Forwarding data via direct-tcpip channel failed. Terminating connection because user %q does not exist", scx.Identity.Login) + if err := ccx.ServerConn.Close(); err != nil { + s.Logger.Warnf("Unable to terminate connection: %v", err) + } + default: + s.Logger.WithError(err).Error("Forwarding data via direct-tcpip channel failed") } } diff --git a/lib/utils/proxyconn.go b/lib/utils/proxyconn.go index 3856f9d998ce7..2493ad222edef 100644 --- a/lib/utils/proxyconn.go +++ b/lib/utils/proxyconn.go @@ -23,6 +23,36 @@ import ( "github.com/gravitational/trace" ) +// CombinedReadWriteCloser wraps an [io.ReadCloser] and an [io.WriteCloser] to +// implement [io.ReadWriteCloser]. Reads are performed on the [io.ReadCloser] and +// writes are performed on the [io.WriteCloser]. Closing will return the +// aggregated errors of both. +type CombinedReadWriteCloser struct { + r io.ReadCloser + w io.WriteCloser +} + +func (o CombinedReadWriteCloser) Read(p []byte) (int, error) { + return o.r.Read(p) +} + +func (o CombinedReadWriteCloser) Write(p []byte) (int, error) { + return o.w.Write(p) +} + +func (o CombinedReadWriteCloser) Close() error { + return trace.NewAggregate(o.r.Close(), o.w.Close()) +} + +// CombineReadWriteCloser creates a CombinedReadWriteCloser from the provided +// [io.ReadCloser] and [io.WriteCloser] that implements [io.ReadWriteCloser] +func CombineReadWriteCloser(r io.ReadCloser, w io.WriteCloser) CombinedReadWriteCloser { + return CombinedReadWriteCloser{ + r: r, + w: w, + } +} + // ProxyConn launches a double-copy loop that proxies traffic between the // provided client and server connections. 
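
Note: the sketch below is illustrative only and is not part of the patch. It shows
the child-to-parent error round trip that writeChildError/DecodeChildError
(lib/srv/ctx.go above) implement: the child JSON-encodes the trace error code plus
the raw error onto a pipe, and the parent reconstructs a typed error from it. It
assumes only the Go standard library and github.com/gravitational/trace; the
helper names mirror the patch, but the program itself is a simplified stand-in.

package main

import (
	"encoding/json"
	"fmt"
	"io"
	"os"

	"github.com/gravitational/trace"
)

// childProcessError mirrors the wire format from lib/srv/ctx.go: the trace
// error code plus the raw JSON form of the underlying error.
type childProcessError struct {
	Code     int    `json:"code"`
	RawError []byte `json:"rawError"`
}

// writeChildError encodes err onto w; failures are deliberately swallowed,
// matching the helper in the patch (the parent treats "nothing written" as
// "no error to report").
func writeChildError(w io.Writer, err error) {
	if w == nil || err == nil {
		return
	}
	data, jerr := json.Marshal(err)
	if jerr != nil {
		return
	}
	_ = json.NewEncoder(w).Encode(childProcessError{
		Code:     trace.ErrorToCode(err),
		RawError: data,
	})
}

// decodeChildError reverses writeChildError; a decode failure means the
// child wrote nothing, which is reported as nil.
func decodeChildError(r io.Reader) error {
	var c childProcessError
	if err := json.NewDecoder(r).Decode(&c); err != nil {
		return nil
	}
	return trace.ReadError(c.Code, c.RawError)
}

func main() {
	r, w, _ := os.Pipe() // stands in for the errr/errw pair on ServerContext
	writeChildError(w, trace.NotFound("user %q not found", "alice"))
	w.Close()

	err := decodeChildError(r)
	fmt.Printf("recovered: %v (IsNotFound=%v)\n", err, trace.IsNotFound(err))
}

Because the error type survives the trip, the parent can branch on
trace.IsNotFound (or compare with errors.Is) instead of parsing log text, which
is exactly what handleDirectTCPIPRequest does above before tearing down the
ssh connection.
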
From 393603dc891011008f2ab1e991e6df48a37c7e68 Mon Sep 17 00:00:00 2001
From: Tiago Silva
Date: Tue, 14 Mar 2023 11:17:27 +0000
Subject: [PATCH 2/2] Prevent unauthorized access to kube clusters by upserting
 kube_servers

This PR changes the behavior of the kubernetes_service when validating
access to kubernetes clusters. Previously, the kubernetes_service would
use the first kubernetes cluster it found in the Auth server backend to
validate access. This was problematic because if the first kubernetes
cluster was upserted with the same name as a kubernetes cluster the
user was trying to access but with different labels, the user would be
able to access the cluster even though they shouldn't be able to.

This PR changes the behavior of the kubernetes_service to use the
in-memory kubernetes cluster representation used for heartbeats instead
of relying on the information received from the auth server. This
blocks the user from accessing the cluster if the cluster was upserted
with a different set of labels, since the kubernetes_service would not
have the updated labels in memory and would deny access.

Fixes #469
---
 lib/authz/permissions.go         |   2 +-
 lib/cache/cache_test.go          |  28 ++++++
 lib/kube/proxy/forwarder.go      |  56 +++++++-----
 lib/kube/proxy/forwarder_test.go | 145 +++++++++++++++++++++++++++++++
 lib/kube/proxy/server.go         |  79 +++++++++++++++--
 lib/kube/proxy/utils_testing.go  |   8 +-
 lib/kube/proxy/watcher.go        |   4 +-
 lib/services/kubernetes.go       |   6 +-
 lib/services/watcher.go          |  10 +--
 9 files changed, 291 insertions(+), 47 deletions(-)

diff --git a/lib/authz/permissions.go b/lib/authz/permissions.go
index a5230ae2eddc9..248f168f5f9c4 100644
--- a/lib/authz/permissions.go
+++ b/lib/authz/permissions.go
@@ -748,7 +748,7 @@ func definitionForBuiltinRole(clusterName string, recConfig types.SessionRecordi
 				types.NewRule(types.KindRole, services.RO()),
 				types.NewRule(types.KindNamespace, services.RO()),
 				types.NewRule(types.KindLock, services.RO()),
-				types.NewRule(types.KindKubernetesCluster, services.RW()),
+				types.NewRule(types.KindKubernetesCluster, services.RO()),
 				types.NewRule(types.KindSemaphore, services.RW()),
 			},
 		},
diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go
index c52f5ace8fc52..3f1ac663c6274 100644
--- a/lib/cache/cache_test.go
+++ b/lib/cache/cache_test.go
@@ -1622,6 +1622,34 @@ func TestApplicationServers(t *testing.T) {
 	})
 }
 
+// TestKubernetesServers tests that CRUD operations on kube servers are
+// replicated from the backend to the cache.
+func TestKubernetesServers(t *testing.T) {
+	t.Parallel()
+
+	p := newTestPack(t, ForProxy)
+	t.Cleanup(p.Close)
+
+	testResources(t, p, testFuncs[types.KubeServer]{
+		newResource: func(name string) (types.KubeServer, error) {
+			app, err := types.NewKubernetesClusterV3(types.Metadata{Name: name}, types.KubernetesClusterSpecV3{})
+			require.NoError(t, err)
+			return types.NewKubernetesServerV3FromCluster(app, "host", uuid.New().String())
+		},
+		create: withKeepalive(p.presenceS.UpsertKubernetesServer),
+		list: func(ctx context.Context) ([]types.KubeServer, error) {
+			return p.presenceS.GetKubernetesServers(ctx)
+		},
+		cacheList: func(ctx context.Context) ([]types.KubeServer, error) {
+			return p.cache.GetKubernetesServers(ctx)
+		},
+		update: withKeepalive(p.presenceS.UpsertKubernetesServer),
+		deleteAll: func(ctx context.Context) error {
+			return p.presenceS.DeleteAllKubernetesServers(ctx)
+		},
+	})
+}
+
 // TestApps tests that CRUD operations on application resources are
 // replicated from the backend to the cache.
func TestApps(t *testing.T) {
diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go
index a5fedc5c82bf1..daba8a21f8771 100644
--- a/lib/kube/proxy/forwarder.go
+++ b/lib/kube/proxy/forwarder.go
@@ -352,8 +352,19 @@ type Forwarder struct {
 	sessions map[uuid.UUID]*session
 	// upgrades connections to websockets
 	upgrader websocket.Upgrader
+	// getKubernetesServersForKubeCluster is a function that returns a list of
+	// kubernetes servers for a given kube cluster but uses different methods
+	// depending on the service type.
+	// For example, if the service type is KubeService, it will use the
+	// local kubernetes clusters. If the service type is Proxy, it will
+	// use the heartbeat clusters.
+	getKubernetesServersForKubeCluster getKubeServersByNameFunc
 }
 
+// getKubeServersByNameFunc is a function that returns a list of
+// kubernetes servers for a given kube cluster.
+type getKubeServersByNameFunc = func(ctx context.Context, name string) ([]types.KubeServer, error)
+
 // Close signals close to all outstanding or background operations
 // to complete
 func (f *Forwarder) Close() error {
@@ -396,6 +407,9 @@ type authContext struct {
 	kubeResource *types.KubernetesResource
 	// httpMethod is the request HTTP Method.
 	httpMethod string
+	// kubeServers are the registered agents for the kubernetes cluster the request
+	// is targeted to.
+	kubeServers []types.KubeServer
 }
 
 func (c authContext) String() string {
@@ -729,7 +743,9 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req
 	}
 
 	kubeCluster := identity.KubernetesCluster
-	if !isRemoteCluster {
+	// Only set a default kube cluster if the user is not accessing a specific cluster.
+	// The check for kubeCluster != "" happens in the next code section.
+	if !isRemoteCluster && kubeCluster == "" {
 		kc, err := kubeutils.CheckOrSetKubeCluster(ctx, f.cfg.CachingAuthClient, identity.KubernetesCluster, teleportClusterName)
 		if err != nil {
 			if !trace.IsNotFound(err) {
@@ -746,14 +762,20 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req
 	var (
 		kubeUsers, kubeGroups []string
 		kubeLabels            map[string]string
+		kubeServers           []types.KubeServer
+		err                   error
 	)
 	// Only check k8s principals for local clusters.
 	//
 	// For remote clusters, everything will be remapped to new roles on the
 	// leaf and checked there.
 	if !isRemoteCluster {
+		kubeServers, err = f.getKubernetesServersForKubeCluster(ctx, kubeCluster)
+		if err != nil || len(kubeServers) == 0 {
+			return nil, trace.NotFound("cluster %q not found", kubeCluster)
+		}
 		// check signing TTL and return a list of allowed logins for local cluster based on Kubernetes service labels.
-		kubeAccessDetails, err := f.getKubeAccessDetails(roles, kubeCluster, sessionTTL, kubeResource)
+		kubeAccessDetails, err := f.getKubeAccessDetails(kubeServers, roles, kubeCluster, sessionTTL, kubeResource)
 		if err != nil && !trace.IsNotFound(err) {
 			return nil, trace.Wrap(err)
 			// roles.CheckKubeGroupsAndUsers returns trace.NotFound if the user does
@@ -924,7 +946,8 @@ func (f *Forwarder) setupContext(ctx context.Context, authCtx authz.Context, req
 			isRemote:       isRemoteCluster,
 			isRemoteClosed: isRemoteClosed,
 		},
-		httpMethod: req.Method,
+		httpMethod:  req.Method,
+		kubeServers: kubeServers,
 	}, nil
 }
 
@@ -1008,16 +1031,12 @@ type kubeAccessDetails struct {
 
 // getKubeAccessDetails returns the allowed kube groups/users names and the cluster labels for a local kube cluster.
func (f *Forwarder) getKubeAccessDetails( + kubeServers []types.KubeServer, roles services.AccessChecker, kubeClusterName string, sessionTTL time.Duration, kubeResource *types.KubernetesResource, ) (kubeAccessDetails, error) { - kubeServers, err := f.cfg.CachingAuthClient.GetKubernetesServers(f.ctx) - if err != nil { - return kubeAccessDetails{}, trace.Wrap(err) - } - // Find requested kubernetes cluster name and get allowed kube users/groups names. for _, s := range kubeServers { c := s.GetCluster() @@ -1123,10 +1142,7 @@ func (f *Forwarder) authorize(ctx context.Context, actx *authContext) error { f.log.WithField("auth_context", actx.String()).Debug("Skipping authorization due to unknown kubernetes cluster name") return nil } - servers, err := f.cfg.CachingAuthClient.GetKubernetesServers(ctx) - if err != nil { - return trace.Wrap(err) - } + authPref, err := f.cfg.CachingAuthClient.GetAuthPreference(ctx) if err != nil { return trace.Wrap(err) @@ -1153,7 +1169,7 @@ func (f *Forwarder) authorize(ctx context.Context, actx *authContext) error { // // We assume that users won't register two identically-named clusters with // mis-matched labels. If they do, expect weirdness. - for _, s := range servers { + for _, s := range actx.kubeServers { ks := s.GetCluster() if ks.GetName() != actx.kubeClusterName { continue @@ -2281,11 +2297,7 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx context.Context, authCtx au return sess, nil } - kubeServers, err := f.cfg.CachingAuthClient.GetKubernetesServers(f.ctx) - if err != nil && !trace.IsNotFound(err) { - return nil, trace.Wrap(err) - } - + kubeServers := authCtx.kubeServers if len(kubeServers) == 0 && authCtx.kubeClusterName == authCtx.teleportCluster.name { return nil, trace.Wrap(localErr) } @@ -2314,12 +2326,8 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx context.Context, authCtx au } func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, error) { - if len(f.clusterDetails) == 0 { - return nil, trace.NotFound("this Teleport process is not configured for direct Kubernetes access; you likely need to 'tsh login' into a leaf cluster or 'tsh kube login' into a different kubernetes cluster") - } - - details, ok := f.clusterDetails[ctx.kubeClusterName] - if !ok { + details, err := f.findKubeDetailsByClusterName(ctx.kubeClusterName) + if err != nil { return nil, trace.NotFound("kubernetes cluster %q not found", ctx.kubeClusterName) } diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index c53aaeb0648b4..9aea409cf1f02 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -163,6 +163,19 @@ func TestAuthenticate(t *testing.T) { TracerProvider: otel.GetTracerProvider(), tracer: otel.Tracer(teleport.ComponentKube), }, + getKubernetesServersForKubeCluster: func(ctx context.Context, name string) ([]types.KubeServer, error) { + servers, err := ap.GetKubernetesServers(ctx) + if err != nil { + return nil, err + } + var filtered []types.KubeServer + for _, server := range servers { + if server.GetCluster().GetName() == name { + filtered = append(filtered, server) + } + } + return filtered, nil + }, } const remoteAddr = "user.example.com" @@ -220,6 +233,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": 
"static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -243,6 +271,30 @@ func TestAuthenticate(t *testing.T) { DynamicLabels: map[string]types.CommandLabelV2{}, }, }, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "foo", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "bar", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, ), wantCtx: &authContext{ kubeUsers: utils.StringsSet([]string{"user-a"}), @@ -257,6 +309,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -289,6 +356,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -320,6 +399,19 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -352,6 +444,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -453,6 +557,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -501,6 +617,18 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + &types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "local", + Labels: map[string]string{}, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -559,6 +687,21 @@ func TestAuthenticate(t *testing.T) { name: "local", remoteAddr: *utils.MustParseAddr(remoteAddr), }, + kubeServers: newKubeServersFromKubeClusters( + t, + 
&types.KubernetesClusterV3{ + Metadata: types.Metadata{ + Name: "foo", + Labels: map[string]string{ + "static_label1": "static_value1", + "static_label2": "static_value2", + }, + }, + Spec: types.KubernetesClusterSpecV3{ + DynamicLabels: map[string]types.CommandLabelV2{}, + }, + }, + ), }, }, { @@ -958,6 +1101,8 @@ func TestNewClusterSessionDirect(t *testing.T) { f.cfg.CachingAuthClient = mockAccessPoint{ kubeServers: []types.KubeServer{publicKubeService, otherKubeService, tunnelKubeService, otherKubeService}, } + authCtx.kubeServers, err = f.cfg.CachingAuthClient.GetKubernetesServers(context.Background()) + require.NoError(t, err) sess, err := f.newClusterSession(ctx, authCtx) require.NoError(t, err) require.Equal(t, []kubeClusterEndpoint{publicEndpoint, tunnelEndpoint}, sess.kubeClusterEndpoints) diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index c8a530b7b5e86..50012e7a67f6c 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -151,8 +151,8 @@ type TLSServer struct { heartbeats map[string]*srv.Heartbeat closeContext context.Context closeFunc context.CancelFunc - // watcher monitors changes to kube cluster resources. - watcher *services.KubeClusterWatcher + // kubeClusterWatcher monitors changes to kube cluster resources. + kubeClusterWatcher *services.KubeClusterWatcher // reconciler reconciles proxied kube clusters with kube_clusters resources. reconciler *services.Reconciler // monitoredKubeClusters contains all kube clusters the proxied kube_clusters are @@ -229,6 +229,11 @@ func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) { } server.TLS.GetConfigForClient = server.GetConfigForClient server.closeContext, server.closeFunc = context.WithCancel(cfg.Context) + // register into the forwarder the method to get kubernetes servers for a kube cluster. + server.fwd.getKubernetesServersForKubeCluster, err = server.getKubernetesServersForKubeClusterFunc() + if err != nil { + return nil, trace.Wrap(err) + } return server, nil } @@ -282,7 +287,9 @@ func (t *TLSServer) Serve(listener net.Listener) error { // Initialize watcher that will be dynamically (un-)registering // proxied clusters based on the kube_cluster resources. - if t.watcher, err = t.startResourceWatcher(t.closeContext); err != nil { + // This watcher is only started for the kube_service if a resource watcher + // is configured. + if t.kubeClusterWatcher, err = t.startKubeClusterResourceWatcher(t.closeContext); err != nil { return trace.Wrap(err) } @@ -314,8 +321,8 @@ func (t *TLSServer) close(ctx context.Context) error { t.closeFunc() // Stop the kube_cluster resource watcher. - if t.watcher != nil { - t.watcher.Close() + if t.kubeClusterWatcher != nil { + t.kubeClusterWatcher.Close() } t.mu.Lock() listClose := t.listener.Close() @@ -350,7 +357,7 @@ func (t *TLSServer) getServerInfo(name string) (types.Resource, error) { addr = t.listener.Addr().String() } - cluster, err := t.getKubeClusterForHeartbeat(name) + cluster, err := t.getKubeClusterWithServiceLabels(name) if err != nil { return nil, trace.Wrap(err) } @@ -385,12 +392,12 @@ func (t *TLSServer) getServerInfo(name string) (types.Resource, error) { return srv, nil } -// getKubeClusterForHeartbeat finds the kube cluster by name, strips the credentials, +// getKubeClusterWithServiceLabels finds the kube cluster by name, strips the credentials, // replaces the cluster dynamic labels with their latest value available and updates // the cluster with the service dynamic and static labels. 
// We strip the Azure, AWS and Kubeconfig credentials so they are not leaked when
// heartbeating the cluster.
-func (t *TLSServer) getKubeClusterForHeartbeat(name string) (*types.KubernetesClusterV3, error) {
-	// it is safe do read from details since the structure is never updated.
+func (t *TLSServer) getKubeClusterWithServiceLabels(name string) (*types.KubernetesClusterV3, error) {
+	// it is safe to read from details since the structure is never updated.
 	// we replace the whole structure each time an update happens to a dynamic cluster.
 	details, err := t.fwd.findKubeDetailsByClusterName(name)
@@ -524,3 +531,59 @@ func (t *TLSServer) setServiceLabels(cluster types.KubeCluster) {
 		cluster.SetDynamicLabels(dstDynLabels)
 	}
 }
+
+// getKubernetesServersForKubeClusterFunc returns a function that returns the kubernetes servers
+// for a given kube cluster depending on the type of service.
+func (t *TLSServer) getKubernetesServersForKubeClusterFunc() (getKubeServersByNameFunc, error) {
+	switch t.KubeServiceType {
+	case KubeService:
+		return func(_ context.Context, name string) ([]types.KubeServer, error) {
+			// If this is a kube_service, we can just return the local kube servers.
+			kube, err := t.getKubeClusterWithServiceLabels(name)
+			if err != nil {
+				return nil, trace.Wrap(err)
+			}
+			srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID)
+			if err != nil {
+				return nil, trace.Wrap(err)
+			}
+			return []types.KubeServer{srv}, nil
+		}, nil
+	case ProxyService:
+		return t.getAuthKubeServers, nil
+	case LegacyProxyService:
+		return func(ctx context.Context, name string) ([]types.KubeServer, error) {
+			kube, err := t.getKubeClusterWithServiceLabels(name)
+			if err != nil {
+				servers, err := t.getAuthKubeServers(ctx, name)
+				return servers, trace.Wrap(err)
+			}
+			srv, err := types.NewKubernetesServerV3FromCluster(kube, "", t.HostID)
+			if err != nil {
+				return nil, trace.Wrap(err)
+			}
+			return []types.KubeServer{srv}, nil
+		}, nil
+	default:
+		return nil, trace.BadParameter("unknown kubernetes service type %q", t.KubeServiceType)
+	}
+}
+
+// getAuthKubeServers returns the kubernetes servers for a given kube cluster
+// using the Auth server client.
+func (t *TLSServer) getAuthKubeServers(ctx context.Context, name string) ([]types.KubeServer, error) {
+	servers, err := t.CachingAuthClient.GetKubernetesServers(ctx)
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+	var returnServers []types.KubeServer
+	for _, server := range servers {
+		if server.GetCluster().GetName() == name {
+			returnServers = append(returnServers, server)
+		}
+	}
+	if len(returnServers) == 0 {
+		return nil, trace.NotFound("no kubernetes servers found for cluster %q", name)
+	}
+	return returnServers, nil
+}
diff --git a/lib/kube/proxy/utils_testing.go b/lib/kube/proxy/utils_testing.go
index 29964eda0c515..9bd78c249a744 100644
--- a/lib/kube/proxy/utils_testing.go
+++ b/lib/kube/proxy/utils_testing.go
@@ -172,7 +172,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo
 
 	// heartbeatsWaitChannel waits for clusters heartbeats to start.
 	heartbeatsWaitChannel := make(chan struct{}, len(cfg.Clusters)+1)
-
+	client := newAuthClientWithStreamer(testCtx)
 	// Create kubernetes service server.
 	testCtx.KubeServer, err = NewTLSServer(TLSServerConfig{
 		ForwarderConfig: ForwarderConfig{
@@ -186,12 +186,12 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo
 			// directly to AuthClient solves the issue.
 			// We wrap the AuthClient with an events.TeeStreamer to send non-disk
 			// events like session.end to testCtx.emitter as well.
-			AuthClient: newAuthClientWithStreamer(testCtx),
+			AuthClient: client,
 			// StreamEmitter is required although not used because we are using
 			// "node-sync" as session recording mode.
 			StreamEmitter: testCtx.Emitter,
 			DataDir:       t.TempDir(),
-			CachingAuthClient: testCtx.AuthClient,
+			CachingAuthClient: client,
 			HostID:            testCtx.HostID,
 			Context:           testCtx.Context,
 			KubeconfigPath:    kubeConfigLocation,
@@ -206,7 +206,7 @@ func SetupTestContext(ctx context.Context, t *testing.T, cfg TestConfig) *TestCo
 		},
 		DynamicLabels: nil,
 		TLS:           tlsConfig,
-		AccessPoint:   testCtx.AuthClient,
+		AccessPoint:   client,
 		LimiterConfig: limiter.Config{
 			MaxConnections:   1000,
 			MaxNumberOfUsers: 1000,
diff --git a/lib/kube/proxy/watcher.go b/lib/kube/proxy/watcher.go
index 44a093bb9950a..04373f33d02be 100644
--- a/lib/kube/proxy/watcher.go
+++ b/lib/kube/proxy/watcher.go
@@ -82,9 +82,9 @@ func (s *TLSServer) startReconciler(ctx context.Context) (err error) {
 	return nil
 }
 
-// startResourceWatcher starts watching changes to Kube Clusters resources and
+// startKubeClusterResourceWatcher starts watching changes to Kube Clusters resources and
 // registers/unregisters the proxied Kube Cluster accordingly.
-func (s *TLSServer) startResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) {
+func (s *TLSServer) startKubeClusterResourceWatcher(ctx context.Context) (*services.KubeClusterWatcher, error) {
 	if len(s.ResourceMatchers) == 0 || s.KubeServiceType != KubeService {
 		s.log.Debug("Not initializing Kube Cluster resource watcher.")
 		return nil, nil
diff --git a/lib/services/kubernetes.go b/lib/services/kubernetes.go
index b8571220ec2a9..cd719f6d12d6f 100644
--- a/lib/services/kubernetes.go
+++ b/lib/services/kubernetes.go
@@ -34,8 +34,8 @@ import (
 	"github.com/gravitational/teleport/lib/utils"
 )
 
-// KubernetesGetter defines interface for fetching kubernetes cluster resources.
-type KubernetesGetter interface {
+// KubernetesClusterGetter defines an interface for fetching kubernetes cluster resources.
+type KubernetesClusterGetter interface {
 	// GetKubernetesClusters returns all kubernetes cluster resources.
 	GetKubernetesClusters(context.Context) ([]types.KubeCluster, error)
 	// GetKubernetesCluster returns the specified kubernetes cluster resource.
@@ -45,7 +45,7 @@ type KubernetesGetter interface {
 // Kubernetes defines an interface for managing kubernetes clusters resources.
 type Kubernetes interface {
-	// KubernetesGetter provides methods for fetching kubernetes resources.
-	KubernetesGetter
+	// KubernetesClusterGetter provides methods for fetching kubernetes resources.
+	KubernetesClusterGetter
 	// CreateKubernetesCluster creates a new kubernetes cluster resource.
 	CreateKubernetesCluster(context.Context, types.KubeCluster) error
 	// UpdateKubernetesCluster updates an existing kubernetes cluster resource.
diff --git a/lib/services/watcher.go b/lib/services/watcher.go
index c89b4ff7762eb..015d9dc05a595 100644
--- a/lib/services/watcher.go
+++ b/lib/services/watcher.go
@@ -1016,7 +1016,7 @@ type KubeClusterWatcherConfig struct {
 	// ResourceWatcherConfig is the resource watcher configuration.
 	ResourceWatcherConfig
-	// KubernetesGetter is responsible for fetching kube_cluster resources.
-	KubernetesGetter
+	// KubernetesClusterGetter is responsible for fetching kube_cluster resources.
+	KubernetesClusterGetter
 	// KubeClustersC receives up-to-date list of all kube_cluster resources.
 	KubeClustersC chan types.KubeClusters
 }
@@ -1026,12 +1026,12 @@ func (cfg *KubeClusterWatcherConfig) CheckAndSetDefaults() error {
 	if err := cfg.ResourceWatcherConfig.CheckAndSetDefaults(); err != nil {
 		return trace.Wrap(err)
 	}
-	if cfg.KubernetesGetter == nil {
-		getter, ok := cfg.Client.(KubernetesGetter)
+	if cfg.KubernetesClusterGetter == nil {
+		getter, ok := cfg.Client.(KubernetesClusterGetter)
 		if !ok {
-			return trace.BadParameter("missing parameter KubernetesGetter and Client not usable as KubernetesGetter")
+			return trace.BadParameter("missing parameter KubernetesClusterGetter and Client not usable as KubernetesClusterGetter")
 		}
-		cfg.KubernetesGetter = getter
+		cfg.KubernetesClusterGetter = getter
 	}
 	if cfg.KubeClustersC == nil {
 		cfg.KubeClustersC = make(chan types.KubeClusters)
@@ -1087,7 +1087,7 @@ func (k *kubeCollector) resourceKind() string {
 
 // getResourcesAndUpdateCurrent refreshes the list of current resources.
 func (k *kubeCollector) getResourcesAndUpdateCurrent(ctx context.Context) error {
-	clusters, err := k.KubernetesGetter.GetKubernetesClusters(ctx)
+	clusters, err := k.KubernetesClusterGetter.GetKubernetesClusters(ctx)
 	if err != nil {
 		return trace.Wrap(err)
 	}
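
Note: the sketch below is illustrative only and is not part of the patch. It
models the dispatch pattern patch 2 introduces in
getKubernetesServersForKubeClusterFunc: the lookup strategy is chosen once per
service type and stored as a function-typed field, so the request path stays
identical whether the answer comes from in-memory heartbeat state
(kube_service), the auth server inventory (proxy), or local state with an auth
fallback (legacy proxy). All types here (kubeServer, serviceType, forwarder)
are simplified stand-ins, not Teleport's real types.

package main

import (
	"context"
	"fmt"
)

// kubeServer is a simplified stand-in for types.KubeServer.
type kubeServer struct{ cluster, hostID string }

type serviceType int

const (
	kubeService serviceType = iota
	proxyService
	legacyProxyService
)

// getServersFunc matches the getKubeServersByNameFunc shape from the patch:
// resolve the registered servers for one kube cluster by name.
type getServersFunc func(ctx context.Context, name string) ([]kubeServer, error)

// forwarder holds the lookup as data, so the same request path works for
// every service type once the right strategy is plugged in at startup.
type forwarder struct {
	getServers getServersFunc
}

// chooseLookup mirrors getKubernetesServersForKubeClusterFunc: the kube
// service answers from its own in-memory heartbeat state, while the proxy
// consults the (cached) auth server inventory.
func chooseLookup(st serviceType, local map[string]kubeServer, authInventory []kubeServer) (getServersFunc, error) {
	fromAuth := func(_ context.Context, name string) ([]kubeServer, error) {
		var out []kubeServer
		for _, s := range authInventory {
			if s.cluster == name {
				out = append(out, s)
			}
		}
		if len(out) == 0 {
			return nil, fmt.Errorf("no kubernetes servers found for cluster %q", name)
		}
		return out, nil
	}
	switch st {
	case kubeService:
		return func(_ context.Context, name string) ([]kubeServer, error) {
			s, ok := local[name]
			if !ok {
				return nil, fmt.Errorf("kubernetes cluster %q not found", name)
			}
			return []kubeServer{s}, nil
		}, nil
	case proxyService:
		return fromAuth, nil
	case legacyProxyService:
		// Prefer local state, fall back to the auth inventory.
		return func(ctx context.Context, name string) ([]kubeServer, error) {
			if s, ok := local[name]; ok {
				return []kubeServer{s}, nil
			}
			return fromAuth(ctx, name)
		}, nil
	default:
		return nil, fmt.Errorf("unknown kubernetes service type %d", st)
	}
}

func main() {
	local := map[string]kubeServer{"dev": {cluster: "dev", hostID: "host-1"}}
	f := forwarder{}
	f.getServers, _ = chooseLookup(kubeService, local, nil)
	srvs, err := f.getServers(context.Background(), "dev")
	fmt.Println(srvs, err)
}

Keying access decisions off the in-memory heartbeat state is what closes the
hole described in the commit message: an upserted kube_server with mismatched
labels never enters the kube_service's local state, so the lookup simply fails
and access is denied.
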