diff --git a/ssh/server.go b/ssh/server.go index e2ae4f891b..3ca9e89e22 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -462,6 +462,24 @@ func (p *PartialSuccessError) Error() string { // It is returned in ServerAuthError.Errors from NewServerConn. var ErrNoAuth = errors.New("ssh: no auth passed yet") +// BannerError is an error that can be returned by authentication handlers in +// ServerConfig to send a banner message to the client. +type BannerError struct { + Err error + Message string +} + +func (b *BannerError) Unwrap() error { + return b.Err +} + +func (b *BannerError) Error() string { + if b.Err == nil { + return b.Message + } + return b.Err.Error() +} + func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { sessionID := s.transport.getSessionID() var cache pubKeyCache @@ -734,6 +752,18 @@ userAuthLoop: config.AuthLogCallback(s, userAuthReq.Method, authErr) } + var bannerErr *BannerError + if errors.As(authErr, &bannerErr) { + if bannerErr.Message != "" { + bannerMsg := &userAuthBannerMsg{ + Message: bannerErr.Message, + } + if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil { + return nil, err + } + } + } + if authErr == nil { break userAuthLoop } diff --git a/ssh/server_test.go b/ssh/server_test.go index 5b47b9e0af..9057a9b5f0 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -6,8 +6,10 @@ package ssh import ( "errors" + "fmt" "io" "net" + "slices" "strings" "sync/atomic" "testing" @@ -225,6 +227,78 @@ func TestNewServerConnValidationErrors(t *testing.T) { } } +func TestBannerError(t *testing.T) { + serverConfig := &ServerConfig{ + BannerCallback: func(ConnMetadata) string { + return "banner from BannerCallback" + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + err := &BannerError{ + Err: errors.New("error from NoClientAuthCallback"), + Message: "banner from NoClientAuthCallback", + } + return nil, fmt.Errorf("wrapped: %w", err) + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, &BannerError{ + Err: errors.New("error from PublicKeyCallback"), + Message: "banner from PublicKeyCallback", + } + }, + KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { + return nil, &BannerError{ + Err: nil, // make sure that a nil inner error is allowed + Message: "banner from KeyboardInteractiveCallback", + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{"letmein"}, nil + }), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "banner from BannerCallback", + "banner from NoClientAuthCallback", + "banner from PublicKeyCallback", + "banner from KeyboardInteractiveCallback", + } + if !slices.Equal(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } +} + type markerConn struct { closed uint32 used uint32