From 8d26546ef669c98c9d4cd92e664f2b6fa6853731 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 10 May 2023 10:11:17 -0700 Subject: [PATCH] [v12] fix: use errors.Is for all EOF comparisons Backport #26012 to branch/v12 This commit updates all `err == io.EOF` comparisons to use `errors.Is(err, io.EOF)`. This is necessary when the error may have been wrapped and fixes at least one current breakage (in `tsh request ls`). `golang.org/x/tools/refactor/eg` was very handy for this, I used the following template: ```go package teleport import ( "errors" "io" ) func before(err error) bool { return err == io.EOF } func after(err error) bool { return errors.Is(err, io.EOF) } ``` --- api/client/client.go | 6 +++--- api/utils/sshutils/ssh.go | 3 ++- integrations/lib/errors.go | 3 ++- integrations/lib/tar/extract.go | 3 ++- integrations/lib/tctl/resources.go | 3 ++- integrations/lib/testing/integration/authservice.go | 5 +++-- integrations/lib/testing/integration/proxyservice.go | 5 +++-- integrations/lib/testing/integration/sshservice.go | 5 +++-- lib/auth/grpcserver.go | 7 ++++--- lib/client/session.go | 3 ++- lib/config/configuration.go | 3 ++- lib/events/auditlog.go | 3 ++- lib/events/filesessions/fileasync.go | 2 +- lib/events/playback.go | 7 ++++--- lib/events/stream.go | 4 ++-- lib/kube/proxy/testing/kube_server/kube_mock.go | 3 ++- lib/proxy/peer/conn.go | 3 ++- lib/srv/ctx.go | 3 ++- lib/srv/exec.go | 3 ++- lib/srv/reexec.go | 2 +- lib/sshutils/marshal.go | 3 ++- lib/sshutils/scp/scp.go | 3 ++- lib/utils/archive_test.go | 3 ++- lib/utils/jsontools.go | 3 ++- lib/utils/unpack.go | 3 ++- lib/web/apiserver_test.go | 2 +- lib/web/terminal.go | 2 +- tool/tctl/common/resource_command.go | 3 ++- tool/tctl/sso/tester/command.go | 3 ++- 29 files changed, 62 insertions(+), 39 deletions(-) diff --git a/api/client/client.go b/api/client/client.go index f7066fe730274..9f728a2f572a6 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -867,7 +867,7 @@ func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessReque var reqs []types.AccessRequest for { req, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } @@ -3025,7 +3025,7 @@ func (c *Client) GetActiveSessionTrackers(ctx context.Context) ([]types.SessionT for { session, err := stream.Recv() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } @@ -3049,7 +3049,7 @@ func (c *Client) GetActiveSessionTrackersWithFilter(ctx context.Context, filter for { session, err := stream.Recv() if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } diff --git a/api/utils/sshutils/ssh.go b/api/utils/sshutils/ssh.go index d516e0d5e0af7..76f8622ca64e6 100644 --- a/api/utils/sshutils/ssh.go +++ b/api/utils/sshutils/ssh.go @@ -21,6 +21,7 @@ package sshutils import ( "crypto" "crypto/subtle" + "errors" "io" "net" "regexp" @@ -67,7 +68,7 @@ func ParseKnownHosts(knownHosts [][]byte, matchHostnames ...string) ([]ssh.Publi for _, line := range knownHosts { for { _, hosts, publicKey, _, bytes, err := ssh.ParseKnownHosts(line) - if err == io.EOF { + if errors.Is(err, io.EOF) { break } else if err != nil { return nil, trace.Wrap(err, "failed parsing known hosts: %v; raw line: %q", err, line) diff --git a/integrations/lib/errors.go b/integrations/lib/errors.go index c1f9d640a04ce..8696c25ca2aeb 100644 --- a/integrations/lib/errors.go +++ b/integrations/lib/errors.go @@ -16,6 +16,7 @@ package lib import ( "context" + "errors" "io" "github.com/gravitational/trace" @@ -27,7 +28,7 @@ import ( // TODO: remove this when trail.FromGRPC will understand additional error codes func FromGRPC(err error) error { switch { - case err == io.EOF: + case errors.Is(err, io.EOF): fallthrough case status.Code(err) == codes.Canceled, err == context.Canceled: fallthrough diff --git a/integrations/lib/tar/extract.go b/integrations/lib/tar/extract.go index 9a305a43ea164..73b92e36e1052 100644 --- a/integrations/lib/tar/extract.go +++ b/integrations/lib/tar/extract.go @@ -19,6 +19,7 @@ package tar import ( "archive/tar" "compress/gzip" + "errors" "io" "os" "path" @@ -95,7 +96,7 @@ func Extract(reader io.Reader, options ExtractOptions) error { } for filesDone == nil || filesDone.Len() > 0 { header, err := tarReader.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } if err != nil { diff --git a/integrations/lib/tctl/resources.go b/integrations/lib/tctl/resources.go index 418b692c29491..c9b5876aa081f 100644 --- a/integrations/lib/tctl/resources.go +++ b/integrations/lib/tctl/resources.go @@ -18,6 +18,7 @@ package tctl import ( "encoding/json" + "errors" "io" "github.com/ghodss/yaml" @@ -52,7 +53,7 @@ func readResourcesYAMLOrJSON(r io.Reader) ([]types.Resource, error) { var res streamResource err := decoder.Decode(&res) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } return nil, trace.Wrap(err) diff --git a/integrations/lib/testing/integration/authservice.go b/integrations/lib/testing/integration/authservice.go index a9c59437391d5..9140976b36c51 100644 --- a/integrations/lib/testing/integration/authservice.go +++ b/integrations/lib/testing/integration/authservice.go @@ -20,6 +20,7 @@ import ( "bufio" "bytes" "context" + "errors" "io" "os/exec" "regexp" @@ -146,7 +147,7 @@ func (auth *AuthService) Run(ctx context.Context) error { stdout := bufio.NewReader(stdoutPipe) for { line, err := stdout.ReadString('\n') - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { @@ -177,7 +178,7 @@ func (auth *AuthService) Run(ctx context.Context) error { for { n, err := stderr.Read(data) auth.saveStderr(data[:n]) - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { diff --git a/integrations/lib/testing/integration/proxyservice.go b/integrations/lib/testing/integration/proxyservice.go index 87335b06dde2a..084dc6b38146d 100644 --- a/integrations/lib/testing/integration/proxyservice.go +++ b/integrations/lib/testing/integration/proxyservice.go @@ -20,6 +20,7 @@ import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "os/exec" @@ -151,7 +152,7 @@ func (proxy *ProxyService) Run(ctx context.Context) error { stdout := bufio.NewReader(stdoutPipe) for { line, err := stdout.ReadString('\n') - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { @@ -187,7 +188,7 @@ func (proxy *ProxyService) Run(ctx context.Context) error { for { n, err := stderr.Read(data) proxy.saveStderr(data[:n]) - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { diff --git a/integrations/lib/testing/integration/sshservice.go b/integrations/lib/testing/integration/sshservice.go index af249960c2f59..a5797e928004e 100644 --- a/integrations/lib/testing/integration/sshservice.go +++ b/integrations/lib/testing/integration/sshservice.go @@ -20,6 +20,7 @@ import ( "bufio" "bytes" "context" + "errors" "io" "os/exec" "regexp" @@ -146,7 +147,7 @@ func (ssh *SSHService) Run(ctx context.Context) error { stdout := bufio.NewReader(stdoutPipe) for { line, err := stdout.ReadString('\n') - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { @@ -177,7 +178,7 @@ func (ssh *SSHService) Run(ctx context.Context) error { for { n, err := stderr.Read(data) ssh.saveStderr(data[:n]) - if err == io.EOF { + if errors.Is(err, io.EOF) { return } if err := trace.Wrap(err); err != nil { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 09da37c0d1c34..70167d4ca11d7 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -19,6 +19,7 @@ package auth import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -165,7 +166,7 @@ func (g *GRPCServer) SendKeepAlives(stream proto.AuthService_SendKeepAlivesServe return trace.Wrap(err) } keepAlive, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { g.Debugf("Connection closed.") return nil } @@ -223,7 +224,7 @@ func (g *GRPCServer) CreateAuditStream(stream proto.AuthService_CreateAuditStrea for { request, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } if err != nil { @@ -2249,7 +2250,7 @@ func (g *GRPCServer) MaintainSessionPresence(stream proto.AuthService_MaintainSe for { req, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } diff --git a/lib/client/session.go b/lib/client/session.go index 776b0494f3844..fce9e73d107e4 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -18,6 +18,7 @@ package client import ( "context" + "errors" "fmt" "io" "net" @@ -643,7 +644,7 @@ func handleNonPeerControls(mode types.SessionParticipantMode, term *terminal.Ter for { buf := make([]byte, 1) _, err := term.Stdin().Read(buf) - if err == io.EOF { + if errors.Is(err, io.EOF) { return } diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 84e01d15f0009..c395726b00f9c 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -22,6 +22,7 @@ package config import ( "crypto/x509" + "errors" "io" stdlog "log" "net" @@ -218,7 +219,7 @@ func ReadResources(filePath string) ([]types.Resource, error) { var raw services.UnknownResource err := decoder.Decode(&raw) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } return nil, trace.Wrap(err) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index fee24fb1c37a5..074e11a6e7412 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -22,6 +22,7 @@ import ( "compress/gzip" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -643,7 +644,7 @@ func (l *AuditLog) GetSessionChunk(namespace string, sid session.ID, offsetBytes for { out, err := l.getSessionChunk(namespace, sid, offsetBytes, maxBytes) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return data, nil } return nil, trace.Wrap(err) diff --git a/lib/events/filesessions/fileasync.go b/lib/events/filesessions/fileasync.go index e23fa64df67d5..fae27aed7fd93 100644 --- a/lib/events/filesessions/fileasync.go +++ b/lib/events/filesessions/fileasync.go @@ -506,7 +506,7 @@ func (u *Uploader) upload(ctx context.Context, up *upload) error { for { event, err := up.reader.Read(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { break } return sessionError{err: trace.Wrap(err)} diff --git a/lib/events/playback.go b/lib/events/playback.go index 89cb9231d08d1..8d9231324c563 100644 --- a/lib/events/playback.go +++ b/lib/events/playback.go @@ -22,6 +22,7 @@ import ( "compress/gzip" "context" "encoding/binary" + "errors" "fmt" "io" "os" @@ -96,7 +97,7 @@ func Export(ctx context.Context, rs io.ReadSeeker, w io.Writer, exportFormat str for { event, err := protoReader.Read(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return trace.Wrap(err) @@ -159,7 +160,7 @@ func (w *SSHPlaybackWriter) SessionEvents() ([]EventFields, error) { var f EventFields err := utils.FastUnmarshal(scanner.Bytes(), &f) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return sessionEvents, nil } return nil, trace.Wrap(err) @@ -250,7 +251,7 @@ func (w *SSHPlaybackWriter) Write(ctx context.Context) error { for { event, err := w.reader.Read(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return trace.Wrap(err) diff --git a/lib/events/stream.go b/lib/events/stream.go index 85d54a167c718..97ade5ea1b5a2 100644 --- a/lib/events/stream.go +++ b/lib/events/stream.go @@ -1015,7 +1015,7 @@ func (r *ProtoReader) Read(ctx context.Context) (apievents.AuditEvent, error) { _, err := io.ReadFull(r.reader, r.sizeBytes[:Int64Size]) if err != nil { // reached the end of the stream - if err == io.EOF { + if errors.Is(err, io.EOF) { r.state = protoReaderStateEOF return nil, err } @@ -1116,7 +1116,7 @@ func (r *ProtoReader) ReadAll(ctx context.Context) ([]apievents.AuditEvent, erro for { event, err := r.Read(ctx) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return events, nil } return nil, trace.Wrap(err) diff --git a/lib/kube/proxy/testing/kube_server/kube_mock.go b/lib/kube/proxy/testing/kube_server/kube_mock.go index 90b436692c6f1..26c4ac9215110 100644 --- a/lib/kube/proxy/testing/kube_server/kube_mock.go +++ b/lib/kube/proxy/testing/kube_server/kube_mock.go @@ -21,6 +21,7 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "net" @@ -215,7 +216,7 @@ func (s *KubeMockServer) exec(w http.ResponseWriter, req *http.Request, p httpro for { buffer = buffer[:cap(buffer)] n, err := proxy.stdinStream.Read(buffer) - if err == io.EOF && n == 0 { + if errors.Is(err, io.EOF) && n == 0 { break } else if err != nil && n == 0 { s.log.WithError(err).Errorf("unable to receive from stdin") diff --git a/lib/proxy/peer/conn.go b/lib/proxy/peer/conn.go index 0e7bf1bfde5a4..8524262c6bc03 100644 --- a/lib/proxy/peer/conn.go +++ b/lib/proxy/peer/conn.go @@ -16,6 +16,7 @@ package peer import ( "context" + "errors" "io" "net" "sync" @@ -69,7 +70,7 @@ func (c *streamConn) Read(b []byte) (n int, err error) { if len(c.rBytes) == 0 { frame, err := c.stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { return 0, io.EOF } if err != nil { diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 9f3a161175a0a..e42616c6e9594 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -19,6 +19,7 @@ package srv import ( "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -873,7 +874,7 @@ func (c *ServerContext) x11Ready() (bool, error) { // Wait for child process to send signal (1 byte) // or EOF if signal was already received. _, err := io.ReadFull(c.x11rdyr, make([]byte, 1)) - if err == io.EOF { + if errors.Is(err, io.EOF) { return true, nil } else if err != nil { return false, trace.Wrap(err) diff --git a/lib/srv/exec.go b/lib/srv/exec.go index ebc193b5429ab..b0b284eca39a6 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -19,6 +19,7 @@ package srv import ( "bufio" "context" + "errors" "fmt" "io" "os" @@ -270,7 +271,7 @@ func waitForContinue(contfd *os.File) error { // won't be closed until the parent has placed it in a cgroup. buf := make([]byte, 1) _, err := contfd.Read(buf) - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } waitCh <- err diff --git a/lib/srv/reexec.go b/lib/srv/reexec.go index 72c3b521df312..f9050aec5a97f 100644 --- a/lib/srv/reexec.go +++ b/lib/srv/reexec.go @@ -441,7 +441,7 @@ func waitForShell(termiantefd *os.File, cmd *exec.Cmd) error { // Wait for the terminate file descriptor to be closed. The FD will be closed when Teleport // parent process wants to terminate the remote command and all childs. _, err := termiantefd.Read(buf) - if err == io.EOF { + if errors.Is(err, io.EOF) { // Kill the shell process err = trace.Errorf("shell process has been killed: %w", cmd.Process.Kill()) } else { diff --git a/lib/sshutils/marshal.go b/lib/sshutils/marshal.go index 91a90a8ff5cde..2787f7c686981 100644 --- a/lib/sshutils/marshal.go +++ b/lib/sshutils/marshal.go @@ -17,6 +17,7 @@ limitations under the License. package sshutils import ( + "errors" "fmt" "io" "net/url" @@ -100,7 +101,7 @@ func UnmarshalKnownHosts(knownHostsFile [][]byte) ([]KnownHost, error) { for _, line := range knownHostsFile { for { _, hosts, publicKey, commentString, rest, err := ssh.ParseKnownHosts(line) - if err == io.EOF { + if errors.Is(err, io.EOF) { break } else if err != nil { return nil, trace.Wrap(err, "failed parsing known hosts: %v; raw line: %q", err, line) diff --git a/lib/sshutils/scp/scp.go b/lib/sshutils/scp/scp.go index c63771b218ea4..5aa3329c11a58 100644 --- a/lib/sshutils/scp/scp.go +++ b/lib/sshutils/scp/scp.go @@ -24,6 +24,7 @@ package scp import ( "bufio" + "errors" "fmt" "io" "os" @@ -415,7 +416,7 @@ func (cmd *command) serveSink(ch io.ReadWriter) error { for { n, err := ch.Read(b[:]) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return trace.Wrap(err) diff --git a/lib/utils/archive_test.go b/lib/utils/archive_test.go index ccc5cbe6505cc..bd9e30a65934e 100644 --- a/lib/utils/archive_test.go +++ b/lib/utils/archive_test.go @@ -19,6 +19,7 @@ package utils import ( "archive/tar" "compress/gzip" + "errors" "io" "io/fs" "testing" @@ -102,7 +103,7 @@ func TestCompressAsTarGzArchive(t *testing.T) { tarReader := tar.NewReader(gzipReader) for { header, err := tarReader.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } require.NoError(t, err) diff --git a/lib/utils/jsontools.go b/lib/utils/jsontools.go index 13cde27bc7d8e..065fdf8aa9bc1 100644 --- a/lib/utils/jsontools.go +++ b/lib/utils/jsontools.go @@ -19,6 +19,7 @@ package utils import ( "bytes" "encoding/json" + "errors" "io" "reflect" "unicode" @@ -213,7 +214,7 @@ func ReadYAML(reader io.Reader) (interface{}, error) { var val interface{} err := decoder.Decode(&val) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { if len(values) == 0 { return nil, trace.BadParameter("no resources found, empty input?") } diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go index c48947d4b56df..351a8ecd691c3 100644 --- a/lib/utils/unpack.go +++ b/lib/utils/unpack.go @@ -18,6 +18,7 @@ package utils import ( "archive/tar" + "errors" "io" "os" "path/filepath" @@ -38,7 +39,7 @@ func Extract(r io.Reader, dir string) error { for { header, err := tarball.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } else if err != nil { return trace.Wrap(err) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index dd8e53211cada..90e5d13af1da3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -2972,7 +2972,7 @@ func TestSignMTLS(t *testing.T) { tarContentFileNames := []string{} for { header, err := tarReader.Next() - if err == io.EOF { + if errors.Is(err, io.EOF) { break } require.NoError(t, err) diff --git a/lib/web/terminal.go b/lib/web/terminal.go index a09b997bf1800..2ba10e2671ff8 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -1117,7 +1117,7 @@ func (t *TerminalStream) Read(out []byte) (n int, err error) { if err != nil { // if the connection has closed, we must return io.EOF in order to abort // the websocket copy loop - if err == io.EOF || errors.Is(err, net.ErrClosed) || + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || websocket.IsCloseError(err, websocket.CloseAbnormalClosure, websocket.CloseGoingAway, websocket.CloseNormalClosure) { return 0, io.EOF } diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 1b8824b683721..aa291cf629e35 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -18,6 +18,7 @@ package common import ( "context" + "errors" "fmt" "io" "os" @@ -277,7 +278,7 @@ func (rc *ResourceCommand) Create(ctx context.Context, client auth.ClientI) (err var raw services.UnknownResource err := decoder.Decode(&raw) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { if count == 0 { return trace.BadParameter("no resources found, empty input?") } diff --git a/tool/tctl/sso/tester/command.go b/tool/tctl/sso/tester/command.go index afcbad48b79c8..5508b99d3db9d 100644 --- a/tool/tctl/sso/tester/command.go +++ b/tool/tctl/sso/tester/command.go @@ -16,6 +16,7 @@ package tester import ( "context" + "errors" "fmt" "io" "os" @@ -110,7 +111,7 @@ func (cmd *SSOTestCommand) ssoTestCommand(ctx context.Context, c auth.ClientI) e var raw services.UnknownResource err := decoder.Decode(&raw) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { return nil } return trace.Wrap(err, "Unable to load resource. Make sure the file is in correct format.")