diff --git a/CHANGELOG.md b/CHANGELOG.md index ba6c1a99b..a0e93ec62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,12 +6,15 @@ unreleased - Support protocol 3.2, and the `min_protocol_version` and `max_protocol_version` DSN parameters ([#1258]). +- Support `sslmode=prefer` and `sslmode=allow` ([#1270]). + ### Fixes - Fix SSL key permission check to allow modes stricter than 0600/0640#1265 ([#1265]). [#1258]: https://github.com/lib/pq/pull/1258 [#1265]: https://github.com/lib/pq/pull/1265 +[#1270]: https://github.com/lib/pq/pull/1270 v1.11.2 (2026-02-10) -------------------- diff --git a/conn.go b/conn.go index a28803fb3..e713ec38b 100644 --- a/conn.go +++ b/conn.go @@ -243,13 +243,13 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { func (c *Connector) open(ctx context.Context) (*conn, error) { tsa := c.cfg.TargetSessionAttrs -restart: +restartAll: var ( errs []error app = func(err error, cfg Config) bool { if err != nil { if debugProto { - fmt.Println("CONNECT (error)", err) + fmt.Fprintln(os.Stderr, "CONNECT (error)", err) } errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) } @@ -257,10 +257,13 @@ restart: } ) for _, cfg := range c.cfg.hosts() { + mode := cfg.SSLMode + restartHost: if debugProto { - fmt.Println("CONNECT ", cfg.string()) + fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string()) } + cfg.SSLMode = mode cn := &conn{cfg: cfg, dialer: c.dialer} cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password")) @@ -271,7 +274,11 @@ restart: continue } - err = cn.ssl(cn.cfg) + err = cn.ssl(cn.cfg, mode) + if err != nil && mode == SSLModePrefer { + mode = SSLModeDisable + goto restartHost + } if app(err, cfg) { if cn.c != nil { _ = cn.c.Close() @@ -281,6 +288,10 @@ restart: cn.buf = bufio.NewReader(cn.c) err = cn.startup(cn.cfg) + if err != nil && mode == SSLModeAllow { + mode = SSLModeRequire + goto restartHost + } if app(err, cfg) { _ = cn.c.Close() continue @@ -308,7 +319,7 @@ restart: // ran out of hosts so none are on standby. Clear the setting and try again. if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { tsa = TargetSessionAttrsAny - goto restart + goto restartAll } if len(c.cfg.Multi) == 0 { @@ -568,8 +579,8 @@ func (cn *conn) gname() string { func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { if debugProto { - fmt.Fprintf(os.Stderr, " START conn.simpleExec\n") - defer fmt.Fprintf(os.Stderr, " END conn.simpleExec\n") + fmt.Fprintln(os.Stderr, " START conn.simpleExec") + defer fmt.Fprintln(os.Stderr, " END conn.simpleExec") } b := cn.writeBuf(proto.Query) @@ -611,8 +622,8 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resE func (cn *conn) simpleQuery(q string) (*rows, error) { if debugProto { - fmt.Fprintf(os.Stderr, " START conn.simpleQuery\n") - defer fmt.Fprintf(os.Stderr, " END conn.simpleQuery\n") + fmt.Fprintln(os.Stderr, " START conn.simpleQuery") + defer fmt.Fprintln(os.Stderr, " END conn.simpleQuery") } b := cn.writeBuf(proto.Query) @@ -740,8 +751,8 @@ func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { if debugProto { - fmt.Fprintf(os.Stderr, " START conn.prepareTo\n") - defer fmt.Fprintf(os.Stderr, " END conn.prepareTo\n") + fmt.Fprintln(os.Stderr, " START conn.prepareTo") + defer fmt.Fprintln(os.Stderr, " END conn.prepareTo") } st := &stmt{cn: cn, name: stmtName} @@ -865,8 +876,8 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { if debugProto { - fmt.Fprintf(os.Stderr, " START conn.query\n") - defer fmt.Fprintf(os.Stderr, " END conn.query\n") + fmt.Fprintln(os.Stderr, " START conn.query") + defer fmt.Fprintln(os.Stderr, " END conn.query") } if err := cn.err.get(); err != nil { return nil, err @@ -1000,9 +1011,7 @@ func (cn *conn) sendStartupPacket(m *writeBuf) error { if debugProto { w := m.wrap() fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", - "Startup", - int(binary.BigEndian.Uint32(w[1:5]))-4, - w[5:]) + "Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:]) } _, err := cn.c.Write((m.wrap())[1:]) return err @@ -1012,8 +1021,7 @@ func (cn *conn) sendStartupPacket(m *writeBuf) error { // should have no payload. This method does not use the scratch buffer. func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { if debugProto { - fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", - proto.RequestCode(typ), 0, []byte{}) + fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", proto.RequestCode(typ), 0, []byte{}) } _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) return err @@ -1079,7 +1087,7 @@ func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { } *r = y if debugProto { - fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) + fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", proto.ResponseCode(t), n, y) } return t, nil } @@ -1150,19 +1158,19 @@ func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { return t, r, nil } -func (cn *conn) ssl(cfg Config) error { - upgrade, err := ssl(cfg) +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. +func (cn *conn) ssl(cfg Config, mode SSLMode) error { + upgrade, err := ssl(cfg, mode) if err != nil { return err } - if upgrade == nil { - // Nothing to do - return nil + return nil // Nothing to do } // Only negotiate the ssl handshake if requested (which is the default). - // sllnegotiation=direct is supported by pg17 and above. + // sslnegotiation=direct is supported by pg17 and above. if cfg.SSLNegotiation != SSLNegotiationDirect { w := cn.writeBuf(0) w.int32(proto.NegotiateSSLCode) diff --git a/conn_go18.go b/conn_go18.go index d776175e7..16de38ebe 100644 --- a/conn_go18.go +++ b/conn_go18.go @@ -86,7 +86,7 @@ func (cn *conn) Ping(ctx context.Context) error { } rows, err := cn.simpleQuery(";") if err != nil { - return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + return driver.ErrBadConn } _ = rows.Close() return nil @@ -144,7 +144,7 @@ func (cn *conn) cancel(ctx context.Context) error { defer func() { _ = c.Close() }() cn2 := conn{c: c} - err = cn2.ssl(cfg) + err = cn2.ssl(cfg, cfg.SSLMode) if err != nil { return err } diff --git a/connector.go b/connector.go index 9b1c193a6..51931fbee 100644 --- a/connector.go +++ b/connector.go @@ -42,22 +42,28 @@ type ( // Values for [SSLMode] that pq supports. const ( - // disable: No SSL + // No SSL SSLModeDisable = SSLMode("disable") - // require: require SSL, but skip verification. + // First try a non-SSL connection and if that fails try an SSL connection. + SSLModeAllow = SSLMode("allow") + + // First try an SSL connection and if that fails try a non-SSL connection. + SSLModePrefer = SSLMode("prefer") + + // Require SSL, but skip verification. This is the default. SSLModeRequire = SSLMode("require") - // verify-ca: require SSL and verify that the certificate was signed by a - // trusted CA. + // Require SSL and verify that the certificate was signed by a trusted CA. SSLModeVerifyCA = SSLMode("verify-ca") - // verify-full: require SSK and verify that the certificate was signed by a - // trusted CA and the server host name matches the one in the certificate. + // Require SSL and verify that the certificate was signed by a trusted CA + // and the server host name matches the one in the certificate. SSLModeVerifyFull = SSLMode("verify-full") ) -var sslModes = []SSLMode{SSLModeDisable, SSLModeRequire, SSLModeVerifyFull, SSLModeVerifyCA} +var sslModes = []SSLMode{SSLModeDisable, SSLModeAllow, SSLModePrefer, SSLModeRequire, + SSLModeVerifyFull, SSLModeVerifyCA} // Values for [SSLNegotiation] that pq supports. const ( @@ -536,6 +542,14 @@ func newConfig(dsn string, env []string) (Config, error) { return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q", cfg.MinProtocolVersion, cfg.MaxProtocolVersion) } + if cfg.SSLNegotiation == SSLNegotiationDirect { + switch cfg.SSLMode { + case SSLModeDisable, SSLModeAllow, SSLModePrefer: + return Config{}, fmt.Errorf( + `pq: weak sslmode %q may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full")`, + cfg.SSLMode) + } + } return cfg, nil } diff --git a/deprecated.go b/deprecated.go index 0def49de9..e197922c6 100644 --- a/deprecated.go +++ b/deprecated.go @@ -1,5 +1,7 @@ package pq +import "database/sql" + // PGError is an interface used by previous versions of pq. // // Deprecated: use the Error type. This is never used. @@ -57,3 +59,8 @@ func (e *Error) Get(k byte) (v string) { // Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...") // now works, and calling this manually is no longer required. func ParseURL(url string) (string, error) { return convertURL(url) } + +// NullTime represents a [time.Time] that may be null. +// +// Deprecated: this is an alias for [sql.NullTime]. +type NullTime = sql.NullTime diff --git a/encode.go b/encode.go index e43fc93d6..05c5c0102 100644 --- a/encode.go +++ b/encode.go @@ -2,7 +2,6 @@ package pq import ( "bytes" - "database/sql" "encoding/binary" "encoding/hex" "errors" @@ -603,10 +602,3 @@ func encodeBytea(v []byte) (result []byte) { hex.Encode(result[2:], v) return result } - -// NullTime represents a [time.Time] that may be null. -// NullTime implements the [sql.Scanner] interface so -// it can be used as a scan destination, similar to [sql.NullString]. -// -// Deprecated: this is an alias for [sql.NullTime]. -type NullTime = sql.NullTime diff --git a/internal/pqtest/fake.go b/internal/pqtest/fake.go index 2507f08b7..a3df92079 100644 --- a/internal/pqtest/fake.go +++ b/internal/pqtest/fake.go @@ -106,6 +106,13 @@ func (f Fake) Startup(cn net.Conn, params map[string]string) { // ReadStartup reads the startup message. func (f Fake) ReadStartup(cn net.Conn) (float32, map[string]string, bool) { _, msg, ok := f.read(cn, true) + + if len(msg) == 4 && binary.BigEndian.Uint32(msg) == proto.NegotiateSSLCode { + f.WriteMsg(cn, proto.ErrorResponse, "SFATAL\x00VFATAL\x00C28000\x00"+ + "encryption not supported\x00Fauth.c\x00L462\x00RClientAuthentication\x00\x00") + return 3.0, nil, false + } + var ( params = make(map[string]string) m = strings.Split(string(msg[4:len(msg)-2]), "\x00") diff --git a/ssl.go b/ssl.go index 3aea110eb..caad8015a 100644 --- a/ssl.go +++ b/ssl.go @@ -59,27 +59,31 @@ func getTLSConfigClone(key string) *tls.Config { // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. -func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) { +// +// Don't refer to Config.SSLMode here, as the mode in arguments may be different +// in case of sslmode=allow or prefer. +func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) { var ( verifyCaOnly = false tlsConf = &tls.Config{} - mode = cfg.SSLMode ) switch { - // "require" is the default. - case mode == "" || mode == SSLModeRequire: + case mode == SSLModeDisable || mode == SSLModeAllow: + return nil, nil + + case mode == "" || mode == SSLModeRequire || mode == SSLModePrefer: // We must skip TLS's own verification since it requires full // verification since Go 1.3. tlsConf.InsecureSkipVerify = true // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: // - // Note: For backwards compatibility with earlier versions of - // PostgreSQL, if a root CA file exists, the behavior of - // sslmode=require will be the same as that of verify-ca, meaning the - // server certificate is validated against the CA. Relying on this - // behavior is discouraged, and applications that need certificate - // validation should always use verify-ca or verify-full. + // For backwards compatibility with earlier versions of PostgreSQL, if a + // root CA file exists, the behavior of sslmode=require will be the same + // as that of verify-ca, meaning the server certificate is validated + // against the CA. Relying on this behavior is discouraged, and + // applications that need certificate validation should always use + // verify-ca or verify-full. if cfg.SSLRootCert != "" { if _, err := os.Stat(cfg.SSLRootCert); err == nil { verifyCaOnly = true @@ -94,24 +98,19 @@ func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) { verifyCaOnly = true case mode == SSLModeVerifyFull: tlsConf.ServerName = cfg.Host - case mode == SSLModeDisable: - return nil, nil case strings.HasPrefix(string(mode), "pqgo-"): tlsConf = getTLSConfigClone(string(mode[5:])) if tlsConf == nil { return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode) } default: - return nil, fmt.Errorf( - `pq: unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, - mode) + panic("unreachable") } - // Set Server Name Indication (SNI), if enabled by connection parameters. + // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or + // IPv6). This check is coded already crypto.tls.hostnameInSNI, so just + // always set ServerName here and let crypto/tls do the filtering. if cfg.SSLSNI { - // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 - // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so - // just always set ServerName here and let crypto/tls do the filtering. tlsConf.ServerName = cfg.Host } @@ -126,9 +125,11 @@ func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) { // Accept renegotiation requests initiated by the backend. // - // Renegotiation was deprecated then removed from PostgreSQL 9.5, but - // the default configuration of older versions has it enabled. Redshift - // also initiates renegotiations and cannot be reconfigured. + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but the + // default configuration of older versions has it enabled. Redshift also + // initiates renegotiations and cannot be reconfigured. + // + // TODO: I think this can be removed? tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { diff --git a/ssl_test.go b/ssl_test.go index cb497ea02..af038cebc 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/lib/pq/internal/pqtest" + "github.com/lib/pq/internal/proto" ) func openSSLConn(t *testing.T, conninfo ...string) (*sql.DB, error) { @@ -39,35 +40,70 @@ func startSSLTest(t *testing.T, user string) { } func TestSSLMode(t *testing.T) { + f := pqtest.NewFake(t, func(f pqtest.Fake, cn net.Conn) { + f.Startup(cn, nil) + for { + code, _, ok := f.ReadMsg(cn) + if !ok { + return + } + switch code { + case proto.Query: + f.WriteMsg(cn, proto.EmptyQueryResponse, "") + f.WriteMsg(cn, proto.ReadyForQuery, "I") + case proto.Terminate: + cn.Close() + return + } + } + }) + tests := []struct { connect string - wantErr bool + wantErr string }{ // sslmode=require: require SSL, but don't verify certificate. - {"sslmode=require user=pqgossl", false}, + {"sslmode=require user=pqgossl", ""}, + {"sslmode=require " + f.DSN(), "pq: SSL is not enabled on the server"}, // sslmode=verify-ca: verify that the certificate was signed by a trusted CA - {"host=postgres sslmode=verify-ca user=pqgossl", true}, - {"host=postgres sslmode=verify-ca user=pqgossl sslrootcert=''", true}, - - {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=127.0.0.1", false}, - {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=postgres-invalid", false}, - {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=postgres", false}, + {"host=postgres sslmode=verify-ca user=pqgossl", "invalid-cert"}, + {"host=postgres sslmode=verify-ca user=pqgossl sslrootcert=''", "invalid-cert"}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=127.0.0.1", ""}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=postgres-invalid", ""}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-ca user=pqgossl host=postgres", ""}, // sslmode=verify-full: verify that the certification was signed by a trusted CA and the host matches - {"sslmode=verify-full user=pqgossl host=postgres", true}, - {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=127.0.0.1", true}, - {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=postgres-invalid", true}, - - {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=postgres", false}, + {"sslmode=verify-full user=pqgossl host=postgres", "invalid-cert"}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=127.0.0.1", "invalid-cert"}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=postgres-invalid", "invalid-cert"}, + {"sslrootcert=testdata/init/root.crt sslmode=verify-full user=pqgossl host=postgres", ""}, // With root cert - {"sslrootcert=testdata/init/bogus_root.crt host=postgres sslmode=require user=pqgossl", true}, - - {"sslrootcert=testdata/init/non_existent.crt host=127.0.0.1 sslmode=require user=pqgossl", false}, - {"sslrootcert=testdata/init/root.crt host=127.0.0.1 sslmode=require user=pqgossl", false}, - {"sslrootcert=testdata/init/root.crt host=postgres sslmode=require user=pqgossl", false}, - {"sslrootcert=testdata/init/root.crt host=postgres-invalid sslmode=require user=pqgossl", false}, + {"sslrootcert=testdata/init/bogus_root.crt host=postgres sslmode=require user=pqgossl", "invalid-cert"}, + {"sslrootcert=testdata/init/non_existent.crt host=127.0.0.1 sslmode=require user=pqgossl", ""}, + {"sslrootcert=testdata/init/root.crt host=127.0.0.1 sslmode=require user=pqgossl", ""}, + {"sslrootcert=testdata/init/root.crt host=postgres sslmode=require user=pqgossl", ""}, + {"sslrootcert=testdata/init/root.crt host=postgres-invalid sslmode=require user=pqgossl", ""}, + + // sslmode=prefer + {"sslmode=prefer user=pqgossl", ""}, + {"sslmode=prefer", ""}, + {"sslmode=prefer user=pqgossl " + f.DSN(), ""}, // Doesn't support SSL, so try again without. + + // sslmode=allow + {"sslmode=allow user=pqgossl", ""}, // Requires SSL, so will try again + {"sslmode=allow", ""}, // Doesn't need SSL, should just work. + {"sslmode=allow " + f.DSN(), ""}, // Idem + + // sslmode=disable + {"sslmode=disable user=pqgossl", "no encryption"}, + + // sslnegotiation=direct should fail if ssl isn't required, like libpq: + // psql: error: weak sslmode "allow" may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full") + {"sslmode=disable sslnegotiation=direct", "weak sslmode"}, + {"sslmode=allow sslnegotiation=direct", "weak sslmode"}, + {"sslmode=prefer sslnegotiation=direct", "weak sslmode"}, } startSSLTest(t, "pqgossl") @@ -76,13 +112,27 @@ func TestSSLMode(t *testing.T) { tt := tt t.Run("", func(t *testing.T) { t.Parallel() + + if tt.wantErr == "no encryption" && pqtest.Pgbouncer() { + // PostgreSQL repsonds with: + // pq: pg_hba.conf rejects connection for host "172.18.0.1", user "pqgossl", database "pqgo", no encryption (28000) + // + // But pgbouncer has a different message and code: + // pq: login rejected (08P01) + tt.wantErr = "login rejected" + } + _, err := openSSLConn(t, tt.connect) - if tt.wantErr { + t.Log(tt.connect) + switch { + case tt.wantErr == "" && err != nil: + t.Fatalf("\nfailed for %q\n%s", tt.connect, err) + case tt.wantErr == "invalid-cert": if !pqtest.InvalidCertificate(err) { t.Fatalf("wrong error type %T: %[1]s", err) } - } else if err != nil { - t.Errorf("\nfailed for %q\n%s", tt.connect, err) + case !pqtest.ErrorContains(err, tt.wantErr): + t.Fatalf("wrong error\nwant: %s\nhave: %s", tt.wantErr, err) } }) }