From 44c9b0ff9e71f015c49f686c68a7950fac76623c Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Thu, 25 Jan 2024 18:32:22 -0700 Subject: [PATCH] ssh: allow server auth callbacks to send additional banners Add a new BannerError error type that auth callbacks can return to send banner to the client. While the BannerCallback can send the initial banner message, auth callbacks might want to communicate more information to the client to help them diagnose failures. Updates golang/go#64962 Change-Id: I97a26480ff4064b95a0a26042b0a5e19737cfb62 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/558695 LUCI-TryBot-Result: Go LUCI Reviewed-by: Roland Shoemaker Reviewed-by: Nicola Murino Auto-Submit: Nicola Murino Reviewed-by: Dmitri Shuralyov --- ssh/server.go | 30 +++++++++++++++++++ ssh/server_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/ssh/server.go b/ssh/server.go index e2ae4f891b..3ca9e89e22 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -462,6 +462,24 @@ func (p *PartialSuccessError) Error() string { // It is returned in ServerAuthError.Errors from NewServerConn. var ErrNoAuth = errors.New("ssh: no auth passed yet") +// BannerError is an error that can be returned by authentication handlers in +// ServerConfig to send a banner message to the client. +type BannerError struct { + Err error + Message string +} + +func (b *BannerError) Unwrap() error { + return b.Err +} + +func (b *BannerError) Error() string { + if b.Err == nil { + return b.Message + } + return b.Err.Error() +} + func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { sessionID := s.transport.getSessionID() var cache pubKeyCache @@ -734,6 +752,18 @@ userAuthLoop: config.AuthLogCallback(s, userAuthReq.Method, authErr) } + var bannerErr *BannerError + if errors.As(authErr, &bannerErr) { + if bannerErr.Message != "" { + bannerMsg := &userAuthBannerMsg{ + Message: bannerErr.Message, + } + if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil { + return nil, err + } + } + } + if authErr == nil { break userAuthLoop } diff --git a/ssh/server_test.go b/ssh/server_test.go index 5b47b9e0af..9057a9b5f0 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -6,8 +6,10 @@ package ssh import ( "errors" + "fmt" "io" "net" + "slices" "strings" "sync/atomic" "testing" @@ -225,6 +227,78 @@ func TestNewServerConnValidationErrors(t *testing.T) { } } +func TestBannerError(t *testing.T) { + serverConfig := &ServerConfig{ + BannerCallback: func(ConnMetadata) string { + return "banner from BannerCallback" + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + err := &BannerError{ + Err: errors.New("error from NoClientAuthCallback"), + Message: "banner from NoClientAuthCallback", + } + return nil, fmt.Errorf("wrapped: %w", err) + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, &BannerError{ + Err: errors.New("error from PublicKeyCallback"), + Message: "banner from PublicKeyCallback", + } + }, + KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { + return nil, &BannerError{ + Err: nil, // make sure that a nil inner error is allowed + Message: "banner from KeyboardInteractiveCallback", + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{"letmein"}, nil + }), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "banner from BannerCallback", + "banner from NoClientAuthCallback", + "banner from PublicKeyCallback", + "banner from KeyboardInteractiveCallback", + } + if !slices.Equal(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } +} + type markerConn struct { closed uint32 used uint32