Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ unreleased
- Support protocol 3.2, and the `min_protocol_version` and
`max_protocol_version` DSN parameters ([#1258]).

- Support `sslmode=prefer` and `sslmode=allow` ([#1270]).

### Fixes

- Fix SSL key permission check to allow modes stricter than 0600/0640#1265 ([#1265]).

[#1258]: https://github.com/lib/pq/pull/1258
[#1265]: https://github.com/lib/pq/pull/1265
[#1270]: https://github.com/lib/pq/pull/1270

v1.11.2 (2026-02-10)
--------------------
Expand Down
58 changes: 33 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,24 +243,27 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {

func (c *Connector) open(ctx context.Context) (*conn, error) {
tsa := c.cfg.TargetSessionAttrs
restart:
restartAll:
var (
errs []error
app = func(err error, cfg Config) bool {
if err != nil {
if debugProto {
fmt.Println("CONNECT (error)", err)
fmt.Fprintln(os.Stderr, "CONNECT (error)", err)
}
errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err))
}
return err != nil
}
)
for _, cfg := range c.cfg.hosts() {
mode := cfg.SSLMode
restartHost:
if debugProto {
fmt.Println("CONNECT ", cfg.string())
fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string())
}

cfg.SSLMode = mode
cn := &conn{cfg: cfg, dialer: c.dialer}
cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password,
cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password"))
Expand All @@ -271,7 +274,11 @@ restart:
continue
}

err = cn.ssl(cn.cfg)
err = cn.ssl(cn.cfg, mode)
if err != nil && mode == SSLModePrefer {
mode = SSLModeDisable
goto restartHost
}
if app(err, cfg) {
if cn.c != nil {
_ = cn.c.Close()
Expand All @@ -281,6 +288,10 @@ restart:

cn.buf = bufio.NewReader(cn.c)
err = cn.startup(cn.cfg)
if err != nil && mode == SSLModeAllow {
mode = SSLModeRequire
goto restartHost
}
if app(err, cfg) {
_ = cn.c.Close()
continue
Expand Down Expand Up @@ -308,7 +319,7 @@ restart:
// ran out of hosts so none are on standby. Clear the setting and try again.
if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby {
tsa = TargetSessionAttrsAny
goto restart
goto restartAll
}

if len(c.cfg.Multi) == 0 {
Expand Down Expand Up @@ -568,8 +579,8 @@ func (cn *conn) gname() string {

func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) {
if debugProto {
fmt.Fprintf(os.Stderr, " START conn.simpleExec\n")
defer fmt.Fprintf(os.Stderr, " END conn.simpleExec\n")
fmt.Fprintln(os.Stderr, " START conn.simpleExec")
defer fmt.Fprintln(os.Stderr, " END conn.simpleExec")
}

b := cn.writeBuf(proto.Query)
Expand Down Expand Up @@ -611,8 +622,8 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resE

func (cn *conn) simpleQuery(q string) (*rows, error) {
if debugProto {
fmt.Fprintf(os.Stderr, " START conn.simpleQuery\n")
defer fmt.Fprintf(os.Stderr, " END conn.simpleQuery\n")
fmt.Fprintln(os.Stderr, " START conn.simpleQuery")
defer fmt.Fprintln(os.Stderr, " END conn.simpleQuery")
}

b := cn.writeBuf(proto.Query)
Expand Down Expand Up @@ -740,8 +751,8 @@ func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format,

func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) {
if debugProto {
fmt.Fprintf(os.Stderr, " START conn.prepareTo\n")
defer fmt.Fprintf(os.Stderr, " END conn.prepareTo\n")
fmt.Fprintln(os.Stderr, " START conn.prepareTo")
defer fmt.Fprintln(os.Stderr, " END conn.prepareTo")
}

st := &stmt{cn: cn, name: stmtName}
Expand Down Expand Up @@ -865,8 +876,8 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {

func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) {
if debugProto {
fmt.Fprintf(os.Stderr, " START conn.query\n")
defer fmt.Fprintf(os.Stderr, " END conn.query\n")
fmt.Fprintln(os.Stderr, " START conn.query")
defer fmt.Fprintln(os.Stderr, " END conn.query")
}
if err := cn.err.get(); err != nil {
return nil, err
Expand Down Expand Up @@ -1000,9 +1011,7 @@ func (cn *conn) sendStartupPacket(m *writeBuf) error {
if debugProto {
w := m.wrap()
fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n",
"Startup",
int(binary.BigEndian.Uint32(w[1:5]))-4,
w[5:])
"Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:])
}
_, err := cn.c.Write((m.wrap())[1:])
return err
Expand All @@ -1012,8 +1021,7 @@ func (cn *conn) sendStartupPacket(m *writeBuf) error {
// should have no payload. This method does not use the scratch buffer.
func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error {
if debugProto {
fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n",
proto.RequestCode(typ), 0, []byte{})
fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", proto.RequestCode(typ), 0, []byte{})
}
_, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'})
return err
Expand Down Expand Up @@ -1079,7 +1087,7 @@ func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) {
}
*r = y
if debugProto {
fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y)
fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", proto.ResponseCode(t), n, y)
}
return t, nil
}
Expand Down Expand Up @@ -1150,19 +1158,19 @@ func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) {
return t, r, nil
}

func (cn *conn) ssl(cfg Config) error {
upgrade, err := ssl(cfg)
// Don't refer to Config.SSLMode here, as the mode in arguments may be different
// in case of sslmode=allow or prefer.
func (cn *conn) ssl(cfg Config, mode SSLMode) error {
upgrade, err := ssl(cfg, mode)
if err != nil {
return err
}

if upgrade == nil {
// Nothing to do
return nil
return nil // Nothing to do
}

// Only negotiate the ssl handshake if requested (which is the default).
// sllnegotiation=direct is supported by pg17 and above.
// sslnegotiation=direct is supported by pg17 and above.
if cfg.SSLNegotiation != SSLNegotiationDirect {
w := cn.writeBuf(0)
w.int32(proto.NegotiateSSLCode)
Expand Down
4 changes: 2 additions & 2 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (cn *conn) Ping(ctx context.Context) error {
}
rows, err := cn.simpleQuery(";")
if err != nil {
return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
return driver.ErrBadConn
}
_ = rows.Close()
return nil
Expand Down Expand Up @@ -144,7 +144,7 @@ func (cn *conn) cancel(ctx context.Context) error {
defer func() { _ = c.Close() }()

cn2 := conn{c: c}
err = cn2.ssl(cfg)
err = cn2.ssl(cfg, cfg.SSLMode)
if err != nil {
return err
}
Expand Down
28 changes: 21 additions & 7 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,28 @@ type (

// Values for [SSLMode] that pq supports.
const (
// disable: No SSL
// No SSL
SSLModeDisable = SSLMode("disable")

// require: require SSL, but skip verification.
// First try a non-SSL connection and if that fails try an SSL connection.
SSLModeAllow = SSLMode("allow")

// First try an SSL connection and if that fails try a non-SSL connection.
SSLModePrefer = SSLMode("prefer")

// Require SSL, but skip verification. This is the default.
SSLModeRequire = SSLMode("require")

// verify-ca: require SSL and verify that the certificate was signed by a
// trusted CA.
// Require SSL and verify that the certificate was signed by a trusted CA.
SSLModeVerifyCA = SSLMode("verify-ca")

// verify-full: require SSK and verify that the certificate was signed by a
// trusted CA and the server host name matches the one in the certificate.
// Require SSL and verify that the certificate was signed by a trusted CA
// and the server host name matches the one in the certificate.
SSLModeVerifyFull = SSLMode("verify-full")
)

var sslModes = []SSLMode{SSLModeDisable, SSLModeRequire, SSLModeVerifyFull, SSLModeVerifyCA}
var sslModes = []SSLMode{SSLModeDisable, SSLModeAllow, SSLModePrefer, SSLModeRequire,
SSLModeVerifyFull, SSLModeVerifyCA}

// Values for [SSLNegotiation] that pq supports.
const (
Expand Down Expand Up @@ -536,6 +542,14 @@ func newConfig(dsn string, env []string) (Config, error) {
return Config{}, fmt.Errorf("pq: min_protocol_version %q cannot be greater than max_protocol_version %q",
cfg.MinProtocolVersion, cfg.MaxProtocolVersion)
}
if cfg.SSLNegotiation == SSLNegotiationDirect {
switch cfg.SSLMode {
case SSLModeDisable, SSLModeAllow, SSLModePrefer:
return Config{}, fmt.Errorf(
`pq: weak sslmode %q may not be used with sslnegotiation=direct (use "require", "verify-ca", or "verify-full")`,
cfg.SSLMode)
}
}

return cfg, nil
}
Expand Down
7 changes: 7 additions & 0 deletions deprecated.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package pq

import "database/sql"

// PGError is an interface used by previous versions of pq.
//
// Deprecated: use the Error type. This is never used.
Expand Down Expand Up @@ -57,3 +59,8 @@ func (e *Error) Get(k byte) (v string) {
// Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...")
// now works, and calling this manually is no longer required.
func ParseURL(url string) (string, error) { return convertURL(url) }

// NullTime represents a [time.Time] that may be null.
//
// Deprecated: this is an alias for [sql.NullTime].
type NullTime = sql.NullTime
8 changes: 0 additions & 8 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pq

import (
"bytes"
"database/sql"
"encoding/binary"
"encoding/hex"
"errors"
Expand Down Expand Up @@ -603,10 +602,3 @@ func encodeBytea(v []byte) (result []byte) {
hex.Encode(result[2:], v)
return result
}

// NullTime represents a [time.Time] that may be null.
// NullTime implements the [sql.Scanner] interface so
// it can be used as a scan destination, similar to [sql.NullString].
//
// Deprecated: this is an alias for [sql.NullTime].
type NullTime = sql.NullTime
7 changes: 7 additions & 0 deletions internal/pqtest/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ func (f Fake) Startup(cn net.Conn, params map[string]string) {
// ReadStartup reads the startup message.
func (f Fake) ReadStartup(cn net.Conn) (float32, map[string]string, bool) {
_, msg, ok := f.read(cn, true)

if len(msg) == 4 && binary.BigEndian.Uint32(msg) == proto.NegotiateSSLCode {
f.WriteMsg(cn, proto.ErrorResponse, "SFATAL\x00VFATAL\x00C28000\x00"+
"encryption not supported\x00Fauth.c\x00L462\x00RClientAuthentication\x00\x00")
return 3.0, nil, false
}

var (
params = make(map[string]string)
m = strings.Split(string(msg[4:len(msg)-2]), "\x00")
Expand Down
45 changes: 23 additions & 22 deletions ssl.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,31 @@ func getTLSConfigClone(key string) *tls.Config {

// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
// related settings. The function is nil when no upgrade should take place.
func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) {
//
// Don't refer to Config.SSLMode here, as the mode in arguments may be different
// in case of sslmode=allow or prefer.
func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) {
var (
verifyCaOnly = false
tlsConf = &tls.Config{}
mode = cfg.SSLMode
)
switch {
// "require" is the default.
case mode == "" || mode == SSLModeRequire:
case mode == SSLModeDisable || mode == SSLModeAllow:
return nil, nil

case mode == "" || mode == SSLModeRequire || mode == SSLModePrefer:
// We must skip TLS's own verification since it requires full
// verification since Go 1.3.
tlsConf.InsecureSkipVerify = true

// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
//
// Note: For backwards compatibility with earlier versions of
// PostgreSQL, if a root CA file exists, the behavior of
// sslmode=require will be the same as that of verify-ca, meaning the
// server certificate is validated against the CA. Relying on this
// behavior is discouraged, and applications that need certificate
// validation should always use verify-ca or verify-full.
// For backwards compatibility with earlier versions of PostgreSQL, if a
// root CA file exists, the behavior of sslmode=require will be the same
// as that of verify-ca, meaning the server certificate is validated
// against the CA. Relying on this behavior is discouraged, and
// applications that need certificate validation should always use
// verify-ca or verify-full.
if cfg.SSLRootCert != "" {
if _, err := os.Stat(cfg.SSLRootCert); err == nil {
verifyCaOnly = true
Expand All @@ -94,24 +98,19 @@ func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) {
verifyCaOnly = true
case mode == SSLModeVerifyFull:
tlsConf.ServerName = cfg.Host
case mode == SSLModeDisable:
return nil, nil
case strings.HasPrefix(string(mode), "pqgo-"):
tlsConf = getTLSConfigClone(string(mode[5:]))
if tlsConf == nil {
return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode)
}
default:
return nil, fmt.Errorf(
`pq: unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`,
mode)
panic("unreachable")
}

// Set Server Name Indication (SNI), if enabled by connection parameters.
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or
// IPv6). This check is coded already crypto.tls.hostnameInSNI, so just
// always set ServerName here and let crypto/tls do the filtering.
if cfg.SSLSNI {
// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4
// or IPv6). This check is coded already crypto.tls.hostnameInSNI, so
// just always set ServerName here and let crypto/tls do the filtering.
tlsConf.ServerName = cfg.Host
}

Expand All @@ -126,9 +125,11 @@ func ssl(cfg Config) (func(net.Conn) (net.Conn, error), error) {

// Accept renegotiation requests initiated by the backend.
//
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but
// the default configuration of older versions has it enabled. Redshift
// also initiates renegotiations and cannot be reconfigured.
// Renegotiation was deprecated then removed from PostgreSQL 9.5, but the
// default configuration of older versions has it enabled. Redshift also
// initiates renegotiations and cannot be reconfigured.
//
// TODO: I think this can be removed?
tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient

return func(conn net.Conn) (net.Conn, error) {
Expand Down
Loading
Loading