diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..bb69c2ad6 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: arp242 diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e00164d0..e7f9bc019 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ newer. Previously PostgreSQL 8.4 and newer were supported. - Allow using a custom `tls.Config`, for example for encrypted keys ([#1228]). +- Add `Config`, `NewConfig()`, and `NewConnectorConfig()` to supply connection + details in a more structured way ([#1240]). + - Add `PQGO_DEBUG=1` print the communication with PostgreSQL to stderr, to aid in debugging, testing, and bug reports ([#1223]). @@ -100,6 +103,7 @@ newer. Previously PostgreSQL 8.4 and newer were supported. [#1234]: https://github.com/lib/pq/pull/1234 [#1238]: https://github.com/lib/pq/pull/1238 [#1239]: https://github.com/lib/pq/pull/1239 +[#1240]: https://github.com/lib/pq/pull/1240 v1.10.9 (2023-04-26) diff --git a/conn.go b/conn.go index 3fc9072df..bb406ea15 100644 --- a/conn.go +++ b/conn.go @@ -315,7 +315,7 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) { } func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { - network, address := network(o) + network, address := o.network() // Zero or not specified means wait indefinitely. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { @@ -722,7 +722,7 @@ func toNamedValue(v []driver.Value) []driver.NamedValue { } // CheckNamedValue implements [driver.NamedValueChecker]. -func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { +func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { // Ignore Valuer, for backward compatibility with pq.Array(). if _, ok := nv.Value.(driver.Valuer); ok { return driver.ErrSkip @@ -1077,23 +1077,10 @@ func (cn *conn) ssl(o values) error { // startup packet. func isDriverSetting(key string) bool { switch key { - case "host", "port": - return true - case "password": - return true - case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni": - return true - case "fallback_application_name": - return true - case "connect_timeout": - return true - case "disable_prepared_binary_result": - return true - case "binary_parameters": - return true - case "krbsrvname": - return true - case "krbspn": + case "host", "port", "password", "fallback_application_name", + "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni", + "connect_timeout", "binary_parameters", "disable_prepared_binary_result", + "krbsrvname", "krbspn": return true default: return false diff --git a/conn_test.go b/conn_test.go index bc19f6bca..19194f699 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1334,133 +1334,6 @@ func TestErrorClass(t *testing.T) { } } -func TestParseOpts(t *testing.T) { - tests := []struct { - in string - expected values - valid bool - }{ - {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, - {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true}, - {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true}, - {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true}, - {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true}, - // The last option value is an empty string if there's no non-whitespace after its = - {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, true}, - - // The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value. - {"user= password=foo", values{"user": "password=foo"}, true}, - - // Backslash escapes next char - {`user=a\ \'\\b`, values{"user": `a '\b`}, true}, - {`user='a \'b'`, values{"user": `a 'b`}, true}, - - // Incomplete escape - {`user=x\`, values{}, false}, - - // No '=' after the key - {"postgre://marko@internet", values{}, false}, - {"dbname user=goodbye", values{}, false}, - {"user=foo blah", values{}, false}, - {"user=foo blah ", values{}, false}, - - // Unterminated quoted value - {"dbname=hello user='unterminated", values{}, false}, - } - - for _, test := range tests { - o := make(values) - err := parseOpts(test.in, o) - - switch { - case err != nil && test.valid: - t.Errorf("%q got unexpected error: %s", test.in, err) - case err == nil && test.valid && !reflect.DeepEqual(test.expected, o): - t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected) - case err == nil && !test.valid: - t.Errorf("%q expected an error", test.in) - } - } -} - -func TestRuntimeParameters(t *testing.T) { - t.Parallel() - - tests := []struct { - conninfo string - param string - want string - success bool - skipPgbouncer bool - }{ - // invalid parameter - {"DOESNOTEXIST=foo", "", "", false, false}, - - // we can only work with a specific value for these two - {"client_encoding=SQL_ASCII", "", "", false, false}, - {"datestyle='ISO, YDM'", "", "", false, false}, - - // "options" should work exactly as it does in libpq - // Skipped on pgbouncer as it errors with: - // pq: unsupported startup parameter in options: search_path - {"options='-c search_path=pqgotest'", "search_path", "pqgotest", true, true}, - - // pq should override client_encoding in this case - // TODO: not set consistently with pgbouncer - {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", true, true}, - - // allow client_encoding to be set explicitly - {"client_encoding=UTF8", "client_encoding", "UTF8", true, false}, - - // test a runtime parameter not supported by libpq - // Skipped on pgbouncer as it errors with: - // pq: unsupported startup parameter: work_mem - {"work_mem='139kB'", "work_mem", "139kB", true, true}, - - // test fallback_application_name - {"application_name=foo fallback_application_name=bar", "application_name", "foo", true, false}, - {"application_name='' fallback_application_name=bar", "application_name", "", true, false}, - {"fallback_application_name=bar", "application_name", "bar", true, false}, - } - - for _, tt := range tests { - t.Run("", func(t *testing.T) { - if tt.skipPgbouncer { - pqtest.SkipPgbouncer(t) - } - db, err := pqtest.DB(tt.conninfo) - if err != nil { - t.Fatal(err) - } - - tryGetParameterValue := func() (value string, success bool) { - defer db.Close() - row := db.QueryRow("SELECT current_setting($1)", tt.param) - err = row.Scan(&value) - if err != nil { - return "", false - } - return value, true - } - - have, success := tryGetParameterValue() - if success != tt.success && !success { - t.Fatal(err) - } - if success != tt.success { - t.Fatalf("\nhave: %v\nwant: %v", success, tt.success) - } - if have != tt.want { - t.Fatalf("\nhave: %v\nwant: %v", have, tt.want) - } - }) - } -} - func TestRowsResultTag(t *testing.T) { type ResultTag interface { Result() driver.Result diff --git a/connector.go b/connector.go index fa1732f12..a084377d5 100644 --- a/connector.go +++ b/connector.go @@ -3,89 +3,252 @@ package pq import ( "context" "database/sql/driver" - "errors" "fmt" "net" neturl "net/url" "os" "path/filepath" + "reflect" "sort" + "strconv" "strings" + "time" "unicode" "github.com/lib/pq/internal/pqutil" ) +type ( + // SSLMode is a sslmode setting. + SSLMode string + + // SSLNegotiation is a sslnegotiation setting. + SSLNegotiation string +) + +// Values for [SSLMode] that pq supports. +const ( + // disable: No SSL + SSLModeDisable = SSLMode("disable") + + // require: require SSL, but skip verification. + SSLModeRequire = SSLMode("require") + + // verify-ca: require SSL and verify that the certificate was signed by a + // trusted CA. + SSLModeVerifyCA = SSLMode("verify-ca") + + // verify-full: require SSK and verify that the certificate was signed by a + // trusted CA and the server host name matches the one in the certificate. + SSLModeVerifyFull = SSLMode("verify-full") +) + +var sslModes = []SSLMode{SSLModeDisable, SSLModeRequire, SSLModeVerifyFull, SSLModeVerifyCA} + +// Values for [SSLNegotiation] that pq supports. +const ( + // Negotiate whether SSL should be used. This is the default. + SSLNegotiationPostgres = SSLNegotiation("postgres") + + // Always use SSL, don't try to negotiate. + SSLNegotiationDirect = SSLNegotiation("direct") +) + +var sslNegotiations = []SSLNegotiation{SSLNegotiationPostgres, SSLNegotiationDirect} + // Connector represents a fixed configuration for the pq driver with a given -// name. Connector satisfies the database/sql/driver Connector interface and -// can be used to create any number of DB Conn's via the database/sql OpenDB -// function. -// -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. +// dsn. Connector satisfies the [database/sql/driver.Connector] interface and +// can be used to create any number of DB Conn's via [sql.OpenDB]. type Connector struct { opts values dialer Dialer } -// Connect returns a connection to the database using the fixed configuration -// of this Connector. Context is not used. -func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { - return c.open(ctx) +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// [sql.OpenDB]. +func NewConnector(dsn string) (*Connector, error) { + cfg, err := NewConfig(dsn) + if err != nil { + return nil, err + } + return NewConnectorConfig(cfg) } -// Dialer allows change the dialer used to open connections. -func (c *Connector) Dialer(dialer Dialer) { - c.dialer = dialer +// NewConnectorConfig returns a connector for the pq driver in a fixed +// configuration with the given [Config]. The returned connector can be used to +// create any number of equivalent Conn's. The returned connector is intended to +// be used with [sql.OpenDB]. +func NewConnectorConfig(cfg Config) (*Connector, error) { + return &Connector{opts: cfg.tomap(), dialer: defaultDialer{}}, nil } +// Connect returns a connection to the database using the fixed configuration of +// this Connector. Context is not used. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } + +// Dialer allows change the dialer used to open connections. +func (c *Connector) Dialer(dialer Dialer) { c.dialer = dialer } + // Driver returns the underlying driver of this Connector. -func (c *Connector) Driver() driver.Driver { - return &Driver{} -} +func (c *Connector) Driver() driver.Driver { return &Driver{} } -// NewConnector returns a connector for the pq driver in a fixed configuration -// with the given dsn. The returned connector can be used to create any number -// of equivalent Conn's. The returned connector is intended to be used with -// database/sql.OpenDB. +// Config holds options pq supports when connecting to PostgreSQL. // -// See https://golang.org/pkg/database/sql/driver/#Connector. -// See https://golang.org/pkg/database/sql/#OpenDB. -func NewConnector(dsn string) (*Connector, error) { - var err error - o := make(values) +// The postgres struct tag is used for the value from the DSN (e.g. +// "dbname=abc"), and the env struct tag is used for the environment variable +// (e.g. "PGDATABASE=abc") +type Config struct { + // The host to connect to. Absolute paths and values that start with @ are + // for unix domain sockets. Defaults to localhost. + Host string `postgres:"host" env:"PGHOST"` + + // The port to connect to. Defaults to 5432. + Port uint16 `postgres:"port" env:"PGPORT"` + + // The name of the database to connect to. + Database string `postgres:"dbname" env:"PGDATABASE"` + + // The user to sign in as. Defaults to the current user. + User string `postgres:"user" env:"PGUSER"` - // A number of defaults are applied here, in this order: + // The user's password. + Password string `postgres:"password" env:"PGPASSWORD"` + + // Path to [pgpass] file to store passwords; overrides Password. // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o["host"] = "localhost" - o["port"] = "5432" - env, err := parseEnviron(os.Environ()) - if err != nil { - return nil, err - } - for k, v := range env { - o[k] = v - } + // [pgpass]: http://www.postgresql.org/docs/current/static/libpq-pgpass.html + Passfile string `postgres:"passfile" env:"PGPASSFILE"` - if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { - dsn, err = ParseURL(dsn) - if err != nil { - return nil, err - } - } + // Commandline options to send to the server at connection start. + Options string `postgres:"options" env:"PGOPTIONS"` - if err := parseOpts(dsn, o); err != nil { - return nil, err + // Application name, displayed in pg_stat_activity and log entries. + ApplicationName string `postgres:"application_name" env:"PGAPPNAME"` + + // Used if application_name is not given. Specifying a fallback name is + // useful in generic utility programs that wish to set a default application + // name but allow it to be overridden by the user. + FallbackApplicationName string `postgres:"fallback_application_name" env:"-"` + + // Whether to use SSL. Defaults to "require" (different from libpq's default + // of "prefer"). + // + // [RegisterTLSConfig] can be used to registers a custom [tls.Config], which + // can be used by setting sslmode=pqgo-«key» in the connection string. + SSLMode SSLMode `postgres:"sslmode" env:"PGSSLMODE"` + + // When set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only). + SSLNegotiation SSLNegotiation `postgres:"sslnegotiation" env:"PGSSLNEGOTIATION"` + + // Cert file location. The file must contain PEM encoded data. + SSLCert string `postgres:"sslcert" env:"PGSSLCERT"` + + // Key file location. The file must contain PEM encoded data. + SSLKey string `postgres:"sslkey" env:"PGSSLKEY"` + + // The location of the root certificate file. The file must contain PEM encoded data. + SSLRootCert string `postgres:"sslrootcert" env:"PGSSLROOTCERT"` + + // By default SNI is on, any value which is not starting with "1" disables + // SNI. + SSLSNI bool `postgres:"sslsni" env:"PGSSLSNI"` + + // Interpert sslcert and sslkey as PEM encoded data, rather than a path to a + // PEM file. This is a pq extension, not supported in libpq. + SSLInline bool `postgres:"sslinline" env:"-"` + + // GSS (Kerberos) service name when constructing the SPN (default is + // postgres). This will be combined with the host to form the full SPN: + // krbsrvname/host. + KrbSrvname string `postgres:"krbsrvname" env:"PGKRBSRVNAME"` + + // GSS (Kerberos) SPN. This takes priority over krbsrvname if present. This + // is a pq extension, not supported in libpq. + KrbSpn string `postgres:"krbspn" env:"-"` + + // Maximum time to wait while connecting, in seconds. Zero, negative, or not + // specified means wait indefinitely + ConnectTimeout time.Duration `postgres:"connect_timeout" env:"PGCONNECT_TIMEOUT"` + + // Whether to always send []byte parameters over as binary. Enables single + // round-trip mode for non-prepared Query calls. This is a pq extension, not + // supported in libpq. + BinaryParameters bool `postgres:"binary_parameters" env:"-"` + + // This connection should never use the binary format when receiving query + // results from prepared statements. Only provided for debugging. This is a + // pq extension, not supported in libpq. + DisablePreparedBinaryResult bool `postgres:"disable_prepared_binary_result" env:"-"` + + // Client encoding; pq only supports UTF8 and this must be blank or "UTF8". + ClientEncoding string `postgres:"client_encoding" env:"PGCLIENTENCODING"` + + // Date/time representation to use; pq only supports "ISO, MDY" and this + // must be blank or "ISO, MDY". + Datestyle string `postgres:"datestyle" env:"PGDATESTYLE"` + + // Default time zone. + TZ string `postgres:"tz" env:"PGTZ"` + + // Default mode for the genetic query optimizer. + Geqo string `postgres:"geqo" env:"PGGEQO"` + + // Runtime parameters: any unrecognized parameter in the DSN will be added + // to this and sent to PostgreSQL during startup. + Runtime map[string]string `postgres:"-" env:"-"` + + // Record which parameters were given, so we can distinguish between an + // empty string "not given at all". + // + // The alternative is to use pointers or sql.Null[..], but that's more + // awkward to use. + set []string `env:"set"` +} + +// NewConfig creates a new [Config] from the current environment and given DSN. +// +// A subset of the connection parameters supported by PostgreSQL are supported +// by pq; see the [Config] struct fields for supported parameters. pq also lets +// you specify any [run-time parameter] (such as search_path or work_mem) +// directly in the connection string. This is different from libpq, which does +// not allow run-time parameters in the connection string, instead requiring you +// to supply them in the options parameter. +// +// pq supports both key=value type connection strings and postgres:// URL style +// connection strings. For key=value strings, use single quotes for values that +// contain whitespace or empty values. A backslash will escape the next +// character: +// +// "user=pqgo password='with spaces'" +// "user=''" +// "user=space\ man password='it\'s valid'" +// +// Most [PostgreSQL environment variables] are supported by pq. Environment +// variables have a lower precedence than explicitly provided connection +// parameters. pq will return an error if environment variables it does not +// support are set. Environment variables have a lower precedence than +// explicitly provided connection parameters. +// +// [run-time parameter]: http://www.postgresql.org/docs/current/static/runtime-config.html +// [PostgreSQL environment variables]: http://www.postgresql.org/docs/current/static/libpq-envars.html +func NewConfig(dsn string) (Config, error) { + return newConfig(dsn, os.Environ()) +} + +func newConfig(dsn string, env []string) (Config, error) { + cfg := Config{Host: "localhost", Port: 5432, SSLSNI: true} + if err := cfg.fromEnv(env); err != nil { + return Config{}, err + } + if err := cfg.fromDSN(dsn); err != nil { + return Config{}, err } // Use the "fallback" application name if necessary - if fallback, ok := o["fallback_application_name"]; ok { - if _, ok := o["application_name"]; !ok { - o["application_name"] = fallback - } + if cfg.isset("fallback_application_name") && !cfg.isset("application_name") { + cfg.ApplicationName = cfg.FallbackApplicationName } // We can't work with any client_encoding other than UTF-8 currently. @@ -95,90 +258,98 @@ func NewConnector(dsn string) (*Connector, error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") + if cfg.isset("client_encoding") && !isUTF8(cfg.ClientEncoding) { + return Config{}, fmt.Errorf(`pq: unsupported client_encoding %q: must be absent or "UTF8"`, cfg.ClientEncoding) } - o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) - } - } else { - o["datestyle"] = "ISO, MDY" + if cfg.isset("datestyle") && cfg.Datestyle != "ISO, MDY" { + return Config{}, fmt.Errorf(`pq: unsupported datestyle %q: must be absent or "ISO, MDY"`, cfg.Datestyle) } + cfg.ClientEncoding, cfg.Datestyle = "UTF8", "ISO, MDY" - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { + // Set default user if not explicitly provided. + if !cfg.isset("user") { u, err := pqutil.User() if err != nil { - return nil, ErrCouldNotDetectUsername + return Config{}, err } - o["user"] = u + cfg.User = u } - // SSL is not necessary or supported over UNIX domain sockets - if network, _ := network(o); network == "unix" { - o["sslmode"] = "disable" + // SSL is not necessary or supported over UNIX domain sockets. + if nw, _ := cfg.network(); nw == "unix" { + cfg.SSLMode = SSLModeDisable } - return &Connector{opts: o, dialer: defaultDialer{}}, nil + return cfg, nil } -func network(o values) (string, string) { - host := o["host"] - +func (cfg Config) network() (string, string) { // UNIX domain sockets are either represented by an (absolute) file system // path or they live in the abstract name space (starting with an @). - if filepath.IsAbs(host) || strings.HasPrefix(host, "@") { - sockPath := filepath.Join(host, ".s.PGSQL."+o["port"]) + if filepath.IsAbs(cfg.Host) || strings.HasPrefix(cfg.Host, "@") { + sockPath := filepath.Join(cfg.Host, ".s.PGSQL."+strconv.Itoa(int(cfg.Port))) return "unix", sockPath } - - return "tcp", net.JoinHostPort(host, o["port"]) -} - -type values map[string]string - -// scanner implements a tokenizer for libpq-style option strings. -type scanner struct { - s []rune - i int -} - -// newScanner returns a new scanner initialized with the option string s. -func newScanner(s string) *scanner { - return &scanner{[]rune(s), 0} + return "tcp", net.JoinHostPort(cfg.Host, strconv.Itoa(int(cfg.Port))) } -// Next returns the next rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) Next() (rune, bool) { - if s.i >= len(s.s) { - return 0, false - } - r := s.s[s.i] - s.i++ - return r, true -} - -// SkipSpaces returns the next non-whitespace rune. -// It returns 0, false if the end of the text has been reached. -func (s *scanner) SkipSpaces() (rune, bool) { - r, ok := s.Next() - for unicode.IsSpace(r) && ok { - r, ok = s.Next() +func (cfg *Config) fromEnv(env []string) error { + e := make(map[string]string) + for _, v := range env { + k, v, ok := strings.Cut(v, "=") + if !ok { + continue + } + switch k { + case "PGHOSTADDR", "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM", + "PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", + "PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION", + "PGGSSENCMODE", "PGGSSDELEGATION", "PGTARGETSESSIONATTRS", "PGLOADBALANCEHOSTS", "PGMINPROTOCOLVERSION", + "PGMAXPROTOCOLVERSION", "PGGSSLIB": + return fmt.Errorf("pq: environment variable $%s is not supported", k) + case "PGKRBSRVNAME": + if newGss == nil { + return fmt.Errorf("pq: environment variable $%s is not supported as Kerberos is not enabled", k) + } + } + e[k] = v } - return r, ok + return cfg.setFromTag(e, "env") } // parseOpts parses the options from name and adds them to the values. // // The parsing code is based on conninfo_parse from libpq's fe-connect.c -func parseOpts(name string, o values) error { - s := newScanner(name) +func (cfg *Config) fromDSN(dsn string) error { + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + var err error + dsn, err = convertURL(dsn) + if err != nil { + return err + } + } + + var ( + opt = make(map[string]string) + s = []rune(dsn) + i int + next = func() (rune, bool) { + if i >= len(s) { + return 0, false + } + r := s[i] + i++ + return r, true + } + skipSpaces = func() (rune, bool) { + r, ok := next() + for unicode.IsSpace(r) && ok { + r, ok = next() + } + return r, ok + } + ) for { var ( @@ -187,21 +358,21 @@ func parseOpts(name string, o values) error { ok bool ) - if r, ok = s.SkipSpaces(); !ok { + if r, ok = skipSpaces(); !ok { break } // Scan the key for !unicode.IsSpace(r) && r != '=' { keyRunes = append(keyRunes, r) - if r, ok = s.Next(); !ok { + if r, ok = next(); !ok { break } } // Skip any whitespace if we're not at the = yet if r != '=' { - r, ok = s.SkipSpaces() + r, ok = skipSpaces() } // The current character should be = @@ -210,36 +381,36 @@ func parseOpts(name string, o values) error { } // Skip any whitespace after the = - if r, ok = s.SkipSpaces(); !ok { + if r, ok = skipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o[string(keyRunes)] = "" + opt[string(keyRunes)] = "" break } if r != '\'' { for !unicode.IsSpace(r) { if r == '\\' { - if r, ok = s.Next(); !ok { + if r, ok = next(); !ok { return fmt.Errorf(`missing character after backslash`) } } valRunes = append(valRunes, r) - if r, ok = s.Next(); !ok { + if r, ok = next(); !ok { break } } } else { quote: for { - if r, ok = s.Next(); !ok { + if r, ok = next(); !ok { return fmt.Errorf(`unterminated quoted string literal in connection string`) } switch r { case '\'': break quote case '\\': - r, _ = s.Next() + r, _ = next() fallthrough default: valRunes = append(valRunes, r) @@ -247,12 +418,206 @@ func parseOpts(name string, o values) error { } } - o[string(keyRunes)] = string(valRunes) + opt[string(keyRunes)] = string(valRunes) + } + + return cfg.setFromTag(opt, "postgres") +} + +func (cfg *Config) setFromTag(o map[string]string, tag string) error { + f := "pq: wrong value for %q: " + if tag == "env" { + f = "pq: wrong value for $%s: " + } + var ( + types = reflect.TypeOf(cfg).Elem() + values = reflect.ValueOf(cfg).Elem() + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get(tag) + ) + if k == "" || k == "-" { + continue + } + + v, ok := o[k] + delete(o, k) + if ok { + if t, ok := rt.Tag.Lookup("postgres"); ok && t != "" && t != "-" { + cfg.set = append(cfg.set, t) + } + switch rt.Type.Kind() { + default: + return fmt.Errorf("don't know how to set %s: unknown type %s", rt.Name, rt.Type) + case reflect.String: + if ((tag == "postgres" && k == "sslmode") || (tag == "env" && k == "PGSSLMODE")) && + !pqutil.Contains(sslModes, SSLMode(v)) && + !(strings.HasPrefix(v, "pqgo-") && hasTLSConfig(v[5:])) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslModes)) + } + if ((tag == "postgres" && k == "sslnegotiation") || (tag == "env" && k == "PGSSLNEGOTIATION")) && + !pqutil.Contains(sslNegotiations, SSLNegotiation(v)) { + return fmt.Errorf(f+`%q is not supported; supported values are %s`, k, v, pqutil.Join(sslNegotiations)) + } + rv.SetString(v) + case reflect.Int64: + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + if (tag == "postgres" && k == "connect_timeout") || (tag == "env" && k == "PGCONNECT_TIMEOUT") { + n = int64(time.Duration(n) * time.Second) + } + rv.SetInt(n) + case reflect.Uint16: + n, err := strconv.ParseUint(v, 10, 16) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetUint(n) + case reflect.Bool: + b, err := pqutil.ParseBool(v) + if err != nil { + return fmt.Errorf(f+"%w", k, err) + } + rv.SetBool(b) + } + } + } + + // Set run-time; we delete map keys as they're set in the struct. + if tag == "postgres" { + // Make sure database= sets dbname=; in startup() we send database for + // dbname, and if we have both set it's inconsistent as the loop order + // is a map. + if d, ok := o["database"]; ok { + delete(o, "database") + if o["dbname"] == "" { + o["dbname"] = d + } + } + cfg.Runtime = o } return nil } +func (cfg Config) isset(name string) bool { + return pqutil.Contains(cfg.set, name) +} + +// Convert to a map; mostly so we don't need to rewrite all the code. +func (cfg Config) tomap() values { + var ( + o = make(values) + values = reflect.ValueOf(cfg) + types = reflect.TypeOf(cfg) + ) + for i := 0; i < types.NumField(); i++ { + var ( + rt = types.Field(i) + rv = values.Field(i) + k = rt.Tag.Get("postgres") + ) + if k == "" || k == "-" { + continue + } + if !rv.IsZero() || pqutil.Contains(cfg.set, k) { + switch rt.Type.Kind() { + default: + o[k] = rv.String() + case reflect.Uint16: + n := rv.Uint() + o[k] = strconv.FormatUint(n, 10) + case reflect.Int64: + n := rv.Int() + if k == "connect_timeout" { + n = int64(time.Duration(n) / time.Second) + } + o[k] = strconv.FormatInt(n, 10) + case reflect.Bool: + if rv.Bool() { + o[k] = "yes" + } else { + o[k] = "no" + } + } + } + } + for k, v := range cfg.Runtime { + o[k] = v + } + return o +} + +// Create DSN for this config; primarily for tests. +func (cfg Config) string() string { + var ( + m = cfg.tomap() + keys = make([]string, 0, len(m)) + ) + for k := range m { + switch k { + case "datestyle", "client_encoding": + continue + case "host", "port", "user", "sslsni": + if !cfg.isset(k) { + continue + } + } + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for i, k := range keys { + if i > 0 { + b.WriteByte(' ') + } + b.WriteString(k) + b.WriteByte('=') + var ( + v = m[k] + nv = make([]rune, 0, len(v)+2) + quote = v == "" + ) + for _, c := range v { + if c == ' ' { + quote = true + } + if c == '\'' { + nv = append(nv, '\\') + } + nv = append(nv, c) + } + if quote { + b.WriteByte('\'') + } + b.WriteString(string(nv)) + if quote { + b.WriteByte('\'') + } + } + return b.String() +} + +// Recognize all sorts of silly things as "UTF-8", like Postgres does +func isUTF8(name string) bool { + s := strings.Map(func(c rune) rune { + if 'A' <= c && c <= 'Z' { + return c + ('a' - 'A') + } + if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { + return c + } + return -1 // discard + }, name) + return s == "utf8" || s == "unicode" +} + func convertURL(url string) (string, error) { u, err := neturl.Parse(url) if err != nil { @@ -272,11 +637,9 @@ func convertURL(url string) (string, error) { } if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) + pw, _ := u.User.Password() + accrue("user", u.User.Username()) + accrue("password", pw) } if host, port, err := net.SplitHostPort(u.Host); err != nil { @@ -298,98 +661,3 @@ func convertURL(url string) (string, error) { sort.Strings(kvs) // Makes testing easier (not a performance concern) return strings.Join(kvs, " "), nil } - -// parseEnviron tries to mimic some of libpq's environment handling -// -// To ease testing, it does not directly reference os.Environ, but is designed -// to accept its output. -// -// Environment-set connection information is intended to have a higher -// precedence than a library default but lower than any explicitly passed -// information (such as in the URL or connection string). -func parseEnviron(env []string) (map[string]string, error) { - out := make(map[string]string) - for _, e := range env { - k, v, ok := strings.Cut(e, "=") - if !ok { - return nil, fmt.Errorf("invalid environment: %q", e) - } - - accrue := func(key string) { out[key] = v } - - // Last updated for PostgreSQL 18 - switch k { - case "PGHOSTADDR", "PGREQUIREAUTH", "PGCHANNELBINDING", "PGSERVICE", "PGSERVICEFILE", "PGREALM", - "PGSSLCERTMODE", "PGSSLCOMPRESSION", "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", - "PGSYSCONFDIR", "PGLOCALEDIR", "PGSSLCRLDIR", "PGSSLMINPROTOCOLVERSION", "PGSSLMAXPROTOCOLVERSION", - "PGGSSENCMODE", "PGGSSDELEGATION", "PGTARGETSESSIONATTRS", "PGLOADBALANCEHOSTS", "PGMINPROTOCOLVERSION", - "PGMAXPROTOCOLVERSION": - return nil, fmt.Errorf("setting %q not supported", k) - - case "PGHOST": - accrue("host") - case "PGSSLNEGOTIATION": - accrue("sslnegotiation") - case "PGPORT": - accrue("port") - case "PGDATABASE": - accrue("dbname") - case "PGUSER": - accrue("user") - case "PGPASSWORD": - accrue("password") - case "PGPASSFILE": - accrue("passfile") - case "PGOPTIONS": - accrue("options") - case "PGAPPNAME": - accrue("application_name") - case "PGSSLMODE": - accrue("sslmode") - case "PGSSLCERT": - accrue("sslcert") - case "PGSSLKEY": - accrue("sslkey") - case "PGSSLROOTCERT": - accrue("sslrootcert") - case "PGSSLSNI": - accrue("sslsni") - case "PGGSSLIB": - if newGss == nil { - return nil, fmt.Errorf("setting %q not supported", k) - } - accrue("gsslib") - case "PGKRBSRVNAME": - if newGss == nil { - return nil, fmt.Errorf("setting %q not supported", k) - } - accrue("krbsrvname") - case "PGCONNECT_TIMEOUT": - accrue("connect_timeout") - case "PGCLIENTENCODING": - accrue("client_encoding") - case "PGDATESTYLE": - accrue("datestyle") - case "PGTZ": - accrue("timezone") - case "PGGEQO": - accrue("geqo") - } - } - return out, nil -} - -// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". -func isUTF8(name string) bool { - // Recognize all sorts of silly things as "UTF-8", like Postgres does - s := strings.Map(func(c rune) rune { - if 'A' <= c && c <= 'Z' { - return c + ('a' - 'A') - } - if 'a' <= c && c <= 'z' || '0' <= c && c <= '9' { - return c - } - return -1 // discard - }, name) - return s == "utf8" || s == "unicode" -} diff --git a/connector_test.go b/connector_test.go index cc2a77da9..ecc6e8c82 100644 --- a/connector_test.go +++ b/connector_test.go @@ -6,89 +6,229 @@ import ( "context" "database/sql" "database/sql/driver" + "fmt" "os" "reflect" "testing" + "time" "github.com/lib/pq/internal/pqtest" ) -func TestNewConnector_WorksWithOpenDB(t *testing.T) { - name := "" - c, err := NewConnector(name) - if err != nil { - t.Fatal(err) - } - db := sql.OpenDB(c) - defer db.Close() +func TestNewConnector(t *testing.T) { // database/sql might not call our Open at all unless we do something with // the connection - txn, err := db.Begin() - if err != nil { - t.Fatal(err) + useConn := func(t *testing.T, db any) { + t.Helper() + switch db := db.(type) { + default: + t.Fatalf("unknown type: %T", db) + case *sql.DB: + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Rollback() + case driver.Conn: + tx, err := db.Begin() //lint:ignore SA1019 x + if err != nil { + t.Fatal(err) + } + tx.Rollback() + } } - txn.Rollback() -} -func TestNewConnector_Connect(t *testing.T) { - c, err := NewConnector("") - if err != nil { - t.Fatal(err) - } - db, err := c.Connect(context.Background()) - if err != nil { - t.Fatal(err) - } - defer db.Close() - // database/sql might not call our Open at all unless we do something with - // the connection - txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) - if err != nil { - t.Fatal(err) - } - txn.Rollback() + t.Run("WorksWithOpenDB", func(t *testing.T) { + c, err := NewConnector("") + if err != nil { + t.Fatal(err) + } + db := sql.OpenDB(c) + defer db.Close() + useConn(t, db) + }) + t.Run("Connect", func(t *testing.T) { + c, err := NewConnector("") + if err != nil { + t.Fatal(err) + } + db, err := c.Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + useConn(t, db) + }) + t.Run("Driver", func(t *testing.T) { + c, err := NewConnector("") + if err != nil { + t.Fatal(err) + } + db, err := c.Driver().Open("") + if err != nil { + t.Fatal(err) + } + defer db.Close() + useConn(t, db) + }) + t.Run("Environ", func(t *testing.T) { + os.Setenv("PGPASSFILE", "/tmp/.pgpass") + defer os.Unsetenv("PGPASSFILE") + c, err := NewConnector("") + if err != nil { + t.Fatal(err) + } + if have := c.opts["passfile"]; have != "/tmp/.pgpass" { + t.Fatalf("wrong option for pgassfile: %q", have) + } + }) + + t.Run("WithConfig", func(t *testing.T) { + cfg, err := NewConfig("") + if err != nil { + t.Fatal(err) + } + cfg.SSLMode = SSLModeDisable + cfg.Runtime = map[string]string{"search_path": "foo"} + + c, err := NewConnectorConfig(cfg) + if err != nil { + t.Fatal(err) + } + want := fmt.Sprintf( + `map[client_encoding:UTF8 connect_timeout:20 datestyle:ISO, MDY dbname:pqgo host:localhost port:%d search_path:foo sslmode:disable sslsni:yes user:pqgo]`, + cfg.Port) + if have := fmt.Sprintf("%v", c.opts); have != want { + t.Errorf("\nhave: %s\nwant: %s", have, want) + } + + // pq: unsupported startup parameter: search_path (08P01) + pqtest.SkipPgbouncer(t) + + db := sql.OpenDB(c) + defer db.Close() + useConn(t, db) + }) } -func TestNewConnector_Driver(t *testing.T) { - c, err := NewConnector("") - if err != nil { - t.Fatal(err) - } - db, err := c.Driver().Open("") - if err != nil { - t.Fatal(err) +func TestParseOpts(t *testing.T) { + tests := []struct { + in string + want values + wantErr string + }{ + {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, ""}, + {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, ""}, + {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, ""}, + {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, ""}, + {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, ""}, + {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, ""}, + {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, ""}, + {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, ""}, + {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, ""}, + // The last option value is an empty string if there's no non-whitespace after its = + {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, ""}, + + // The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value. + {"user= password=foo", values{"user": "password=foo"}, ""}, + + // Backslash escapes next char + {`user=a\ \'\\b`, values{"user": `a '\b`}, ""}, + {`user='a \'b'`, values{"user": `a 'b`}, ""}, + + // Incomplete escape + {`user=x\`, values{}, "missing character after backslash"}, + + // No '=' after the key + {"postgre://marko@internet", values{}, `missing "="`}, + {"dbname user=goodbye", values{}, `missing "="`}, + {"user=foo blah", values{}, `missing "="`}, + {"user=foo blah ", values{}, `missing "="`}, + + // Unterminated quoted value + {"dbname=hello user='unterminated", values{}, `unterminated quoted string`}, } - defer db.Close() - // database/sql might not call our Open at all unless we do something with - // the connection - txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) - if err != nil { - t.Fatal(err) + + t.Parallel() + for _, tt := range tests { + t.Run("", func(t *testing.T) { + var cfg Config + err := cfg.fromDSN(tt.in) + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr) + } + if have := cfg.tomap(); !reflect.DeepEqual(have, tt.want) { + t.Errorf("\nhave: %#v\nwant: %#v", have, tt.want) + } + }) } - txn.Rollback() } -func TestNewConnector_Environ(t *testing.T) { - os.Setenv("PGPASSFILE", "/tmp/.pgpass") - defer os.Unsetenv("PGPASSFILE") - c, err := NewConnector("") - if err != nil { - t.Fatal(err) - } - for key, expected := range map[string]string{ - "passfile": "/tmp/.pgpass", - } { - if got := c.opts[key]; got != expected { - t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got) - } +func TestRuntimeParameters(t *testing.T) { + tests := []struct { + conninfo string + param string + want string + wantErr string + skipPgbouncer bool + }{ + {"DOESNOTEXIST=foo", "", "", "unrecognized configuration parameter", false}, + + // we can only work with a specific value for these two + {"client_encoding=SQL_ASCII", "", "", `unsupported client_encoding "SQL_ASCII": must be absent or "UTF8"`, false}, + {"datestyle='ISO, YDM'", "", "", `unsupported datestyle "ISO, YDM": must be absent or "ISO, MDY"`, false}, + + // "options" should work exactly as it does in libpq + // Skipped on pgbouncer as it errors with: + // pq: unsupported startup parameter in options: search_path + {"options='-c search_path=pqgotest'", "search_path", "pqgotest", "", true}, + + // pq should override client_encoding in this case + // TODO: not set consistently with pgbouncer + {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", "", true}, + + // allow client_encoding to be set explicitly + {"client_encoding=UTF8", "client_encoding", "UTF8", "", false}, + + // test a runtime parameter not supported by libpq + // Skipped on pgbouncer as it errors with: + // pq: unsupported startup parameter: work_mem + {"work_mem='139kB'", "work_mem", "139kB", "", true}, + + // test fallback_application_name + {"application_name=foo fallback_application_name=bar", "application_name", "foo", "", false}, + {"application_name='' fallback_application_name=bar", "application_name", "", "", false}, + {"fallback_application_name=bar", "application_name", "bar", "", false}, } -} + t.Parallel() + for _, tt := range tests { + t.Run("", func(t *testing.T) { + if tt.skipPgbouncer { + pqtest.SkipPgbouncer(t) + } + if pqtest.Pgbouncer() && tt.wantErr == "unrecognized configuration parameter" { + tt.wantErr = `unsupported startup parameter` + } + + db := pqtest.MustDB(t, tt.conninfo) + var have string + row := db.QueryRow("select current_setting($1)", tt.param) + err := row.Scan(&have) + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr) + } + if have != tt.want { + t.Fatalf("\nhave: %v\nwant: %v", have, tt.want) + } + }) + } +} func TestParseEnviron(t *testing.T) { tests := []struct { in []string - want map[string]string + want values }{ {[]string{"PGDATABASE=hello", "PGUSER=goodbye"}, map[string]string{"dbname": "hello", "user": "goodbye"}}, @@ -98,21 +238,24 @@ func TestParseEnviron(t *testing.T) { map[string]string{"connect_timeout": "30"}}, } + t.Parallel() for _, tt := range tests { t.Run("", func(t *testing.T) { - have, err := parseEnviron(tt.in) + var cfg Config + err := cfg.fromEnv(tt.in) if err != nil { t.Fatal(err) } + have := cfg.tomap() if !reflect.DeepEqual(tt.want, have) { - t.Errorf("want: %#v; have: %#v", tt.want, have) + t.Errorf("\nwant: %#v\nhave: %#v", tt.want, have) } }) } } func TestIsUTF8(t *testing.T) { - var cases = []struct { + tests := []struct { name string want bool }{ @@ -128,10 +271,14 @@ func TestIsUTF8(t *testing.T) { {"punycode", false}, } - for _, test := range cases { - if g := isUTF8(test.name); g != test.want { - t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want) - } + t.Parallel() + for _, tt := range tests { + t.Run("", func(t *testing.T) { + have := isUTF8(tt.name) + if have != tt.want { + t.Errorf("\nhave: %v\nwant: %v", have, tt.want) + } + }) } } @@ -154,6 +301,7 @@ func TestParseURL(t *testing.T) { //{"postgres://host/db ", "dbname='db' host='host'", ""}, } + t.Parallel() for _, tt := range tests { t.Run("", func(t *testing.T) { have, err := ParseURL(tt.in) @@ -166,3 +314,106 @@ func TestParseURL(t *testing.T) { }) } } + +func TestNewConfig(t *testing.T) { + tests := []struct { + inDSN string + inEnv []string + want string + wantErr string + }{ + // Override defaults + {"", nil, "", ""}, + {"user=u port=1 host=example.com", nil, + "host=example.com port=1 user=u", ""}, + {"", []string{"PGUSER=u", "PGPORT=1", "PGHOST=example.com"}, + "host=example.com port=1 user=u", ""}, + + // Socket + {"host=/var/run/psql", nil, "host=/var/run/psql sslmode=disable", ""}, + {"host=@/var/run/psql", nil, "host=@/var/run/psql sslmode=disable", ""}, + {"host=/var/run/psql sslmode=require", nil, "host=/var/run/psql sslmode=disable", ""}, + + // Empty value, value with space, and value with escaped \' + {"user=''", nil, "user=''", ""}, + {`user='with\' space'`, nil, `user='with\' space'`, ""}, + + // Bool + {"sslsni=0", nil, "sslsni=no", ""}, + {"sslsni=1", nil, "sslsni=yes", ""}, + {"sslinline=yes", nil, "sslinline=yes", ""}, + {"sslinline=no", nil, "sslinline=no", ""}, + {"sslinline=lol", nil, "", `pq: wrong value for "sslinline": strconv.ParseBool: parsing "lol": invalid syntax`}, + + // application_name and fallback_application_name + {"application_name=acme", nil, "application_name=acme", ""}, + {"application_name=acme fallback_application_name=roadrunner", nil, "application_name=acme fallback_application_name=roadrunner", ""}, + {"fallback_application_name=roadrunner", []string{"PGAPPNAME=acme"}, "application_name=acme fallback_application_name=roadrunner", ""}, + {"fallback_application_name=roadrunner", nil, "application_name=roadrunner fallback_application_name=roadrunner", ""}, + + // Timeout and port + {"connect_timeout=5", nil, "connect_timeout=5", ""}, + {"", []string{"PGCONNECT_TIMEOUT=5"}, "connect_timeout=5", ""}, + {"connect_timeout=5s", nil, "", `pq: wrong value for "connect_timeout": strconv.ParseInt: parsing "5s": invalid syntax`}, + {"", []string{"PGCONNECT_TIMEOUT=5s"}, "", `pq: wrong value for $PGCONNECT_TIMEOUT: strconv.ParseInt: parsing "5s": invalid syntax`}, + {"port=5s", nil, "", `pq: wrong value for "port": strconv.ParseUint: parsing "5s": invalid syntax`}, + {"", []string{"PGPORT=5s"}, "", `pq: wrong value for $PGPORT: strconv.ParseUint: parsing "5s": invalid syntax`}, + + // Runtime + {"user=u search_path=abc", nil, "search_path=abc user=u", ""}, + {"database=db", nil, "dbname=db", ``}, + + // URL + {"postgres://u@example.com:1/db", nil, + "dbname=db host=example.com port=1 user=u", ""}, + {"postgres://u:pw@example.com:1/db?opt=val&sslmode=require", nil, + "dbname=db host=example.com opt=val password=pw port=1 sslmode=require user=u", ""}, + + // Unsupported env vars + {"", []string{"PGREALM=abc"}, "", `pq: environment variable $PGREALM is not supported`}, + {"", []string{"PGKRBSRVNAME=abc"}, "", `pq: environment variable $PGKRBSRVNAME is not supported`}, + + // Unsupported enums + {"sslmode=sslmeharder", nil, "", `pq: wrong value for "sslmode"`}, + {"postgres://u:pw@example.com:1/db?sslmode=sslmeharder", nil, "", `pq: wrong value for "sslmode"`}, + {"", []string{"PGSSLMODE=sslmeharder"}, "", `pq: wrong value for $PGSSLMODE`}, + {"sslnegotiation=sslmeharder", nil, "", `pq: wrong value for "sslnegotiation"`}, + {"postgres://u:pw@example.com:1/db?sslnegotiation=sslmeharder", nil, "", `pq: wrong value for "sslnegotiation"`}, + {"", []string{"PGSSLNEGOTIATION=sslmeharder"}, "", `pq: wrong value for $PGSSLNEGOTIATION`}, + } + + t.Parallel() + for _, tt := range tests { + t.Run("", func(t *testing.T) { + have, err := newConfig(tt.inDSN, tt.inEnv) + if !pqtest.ErrorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %v\nwant: %v", err, tt.wantErr) + } + if have.string() != tt.want { + t.Errorf("\nhave: %q\nwant: %q", have.string(), tt.want) + } + }) + } + + // Make sure connect_timeout is parsed as seconds. + t.Run("connect_timeout", func(t *testing.T) { + { + have, err := newConfig("connect_timeout=3", []string{}) + if err != nil { + t.Fatal(err) + } + if have.ConnectTimeout != 3*time.Second { + t.Errorf("\nhave: %q\nwant: %q", have.ConnectTimeout, 3*time.Second) + } + } + { + have, err := newConfig("", []string{"PGCONNECT_TIMEOUT=4"}) + if err != nil { + t.Fatal(err) + } + if have.ConnectTimeout != 4*time.Second { + t.Errorf("\nhave: %q\nwant: %q", have.ConnectTimeout, 4*time.Second) + } + } + }) +} diff --git a/deprecated.go b/deprecated.go index 19530eb97..7ea816426 100644 --- a/deprecated.go +++ b/deprecated.go @@ -1,5 +1,11 @@ package pq +import ( + "net" + "path/filepath" + "strings" +) + // PGError is an interface used by previous versions of pq. // // Deprecated: use the Error type. This is never used. @@ -65,3 +71,16 @@ func (e *Error) Get(k byte) (v string) { // Deprecated: directly passing an URL to sql.Open("postgres", "postgres://...") // now works, and calling this manually is no longer required. func ParseURL(url string) (string, error) { return convertURL(url) } + +type values map[string]string + +func (o values) network() (string, string) { + host := o["host"] + // UNIX domain sockets are either represented by an (absolute) file system + // path or they live in the abstract name space (starting with an @). + if filepath.IsAbs(host) || strings.HasPrefix(host, "@") { + sockPath := filepath.Join(host, ".s.PGSQL."+o["port"]) + return "unix", sockPath + } + return "tcp", net.JoinHostPort(host, o["port"]) +} diff --git a/doc.go b/doc.go index db49ea167..d70ec2d2d 100644 --- a/doc.go +++ b/doc.go @@ -1,8 +1,8 @@ /* -Package pq is a pure Go Postgres driver for the database/sql package. +Package pq is a Go PostgreSQL driver for database/sql. -In most cases clients will use the database/sql package instead of -using this package directly. For example: +Most clients will use the database/sql package instead of using this package +directly. For example: import ( "database/sql" @@ -11,124 +11,49 @@ using this package directly. For example: ) func main() { - connStr := "user=pqgo dbname=pqgo sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "user=pqgo dbname=pqgo sslmode=verify-full" + db, err := sql.Open("postgres", dsn) if err != nil { log.Fatal(err) } age := 21 - rows, err := db.Query("SELECT name FROM users WHERE age = $1", age) - … + rows, err := db.Query("select name from users where age = $1", age) + // … } -You can also connect to a database using a URL. For example: +You can also connect with an URL: - connStr := "postgres://pqgo:password@localhost/pqgo?sslmode=verify-full" - db, err := sql.Open("postgres", connStr) + dsn := "postgres://pqgo:password@localhost/pqgo?sslmode=verify-full" + db, err := sql.Open("postgres", dsn) # Connection String Parameters -Similarly to libpq, when establishing a connection using pq you are expected to -supply a connection string containing zero or more parameters. -A subset of the connection parameters supported by libpq are also supported by pq. -Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem) -directly in the connection string. This is different from libpq, which does not allow -run-time parameters in the connection string, instead requiring you to supply -them in the options parameter. - -For compatibility with libpq, the following special connection parameters are -supported: - - - dbname - The name of the database to connect to - - user - The user to sign in as - - password - The user's password - - host - The host to connect to. Values that start with / are for unix - domain sockets. (default is localhost) - - port - The port to bind to. (default is 5432) - - sslmode - Whether or not to use SSL (default is require, this is not - the default for libpq) - - fallback_application_name - An application_name to fall back to if one isn't provided. - - connect_timeout - Maximum wait for connection, in seconds. Zero or - not specified means wait indefinitely. - - sslcert - Cert file location. The file must contain PEM encoded data. - - sslkey - Key file location. The file must contain PEM encoded data. - - sslrootcert - The location of the root certificate file. The file - must contain PEM encoded data. - - sslnegotiation - when set to "direct" it will use SSL without negotiation (PostgreSQL ≥17 only). - -Valid values for sslmode are: - - - disable - No SSL - - require - Always SSL (skip verification) - - verify-ca - Always SSL (verify that the certificate presented by the - server was signed by a trusted CA) - - verify-full - Always SSL (verify that the certification presented by - the server was signed by a trusted CA and the server host name - matches the one in the certificate) - - A custom TLS configuration registered with [RegisterTLSConfig]. These must - be prefixed with "pqgo-". - -See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING -for more information about connection string parameters. - -Use single quotes for values that contain whitespace: - - "user=pqgo password='with spaces'" - -A backslash will escape the next character in values: - - "user=space\ man password='it\'s valid'" - -Note that the connection parameter client_encoding (which sets the -text encoding for the connection) may be set but must be "UTF8", -matching with the same rules as Postgres. It is an error to provide -any other value. - -In addition to the parameters listed above, any run-time parameter that can be -set at backend start time can be set in the connection string. For more -information, see -http://www.postgresql.org/docs/current/static/runtime-config.html. - -Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html -supported by libpq are also supported by pq. If any of the environment -variables not supported by pq are set, pq will panic during connection -establishment. Environment variables have a lower precedence than explicitly -provided connection parameters. - -The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html -is supported, but on Windows PGPASSFILE must be specified explicitly. +See [NewConfig]. # Queries -database/sql does not dictate any specific format for parameter -markers in query strings, and pq uses the Postgres-native ordinal markers, -as shown above. The same marker can be reused for the same parameter: +database/sql does not dictate any specific format for parameter placeholders, +and pq uses the PostgreSQL-native ordinal markers ($1, $2, etc.). The same +placeholder can be used more than once: - rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 - OR age BETWEEN $2 AND $2 + 3`, "orange", 64) + rows, err := db.Query( + `select * from users where name = $1 or age between $2 and $2 + 3`, + "Duck", 64) -pq does not support the LastInsertId() method of the Result type in database/sql. -To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres -RETURNING clause with a standard Query or QueryRow call: +pq does not support [sql.Result.LastInsertId]. Use the RETURNING clause with a +Query or QueryRow call instead to return the identifier: - var userid int - err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) - VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) - -For more details on RETURNING, see the Postgres documentation: - - http://www.postgresql.org/docs/current/static/sql-insert.html - http://www.postgresql.org/docs/current/static/sql-update.html - http://www.postgresql.org/docs/current/static/sql-delete.html + row := db.QueryRow(`insert into users(name, age) values('Scrooge McDuck', 93) returning id`) -For additional instructions on querying see the documentation for the database/sql package. + var userid int + err := row.Scan(&userid) # Data Types -Parameters pass through driver.DefaultParameterConverter before they are handled -by this package. When the binary_parameters connection option is enabled, -[]byte values are sent directly to the backend as data in binary format. +Parameters pass through [driver.DefaultParameterConverter] before they are handled +by this package. When the binary_parameters connection option is enabled, []byte +values are sent directly to the backend as data in binary format. This package returns the following types for values from the PostgreSQL backend: @@ -144,18 +69,17 @@ All other types are returned directly from the backend as []byte values in text # Errors -pq may return errors of type *pq.Error which can be interrogated for error details: +pq may return errors of type [*pq.Error] which contain error details: - if err, ok := err.(*pq.Error); ok { - fmt.Println("pq error:", err.Code.Name()) + pqErr := new(pq.Error) + if errors.As(err, &pqErr) { + fmt.Println("pq error:", pqErr.Code.Name()) } -See the pq.Error type for details. - # Bulk imports -You can perform bulk imports by preparing a statement returned by pq.CopyIn (or -pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement +You can perform bulk imports by preparing a statement returned by [CopyIn] (or +[CopyInSchema]) in an explicit transaction ([sql.Tx]). The returned statement handle can then be repeatedly "executed" to copy data into the target table. After all data has been processed you should call Exec() once with no arguments to flush all buffered data. Any call to Exec() might return an error which @@ -168,12 +92,12 @@ explicit transaction in pq. Usage example: - txn, err := db.Begin() + tx, err := db.Begin() if err != nil { log.Fatal(err) } - stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) + stmt, err := tx.Prepare(pq.CopyIn("users", "name", "age")) if err != nil { log.Fatal(err) } @@ -195,43 +119,38 @@ Usage example: log.Fatal(err) } - err = txn.Commit() + err = tx.Commit() if err != nil { log.Fatal(err) } # Notifications -PostgreSQL supports a simple publish/subscribe model over database -connections. See http://www.postgresql.org/docs/current/static/sql-notify.html -for more information about the general mechanism. +PostgreSQL supports a simple publish/subscribe model using PostgreSQL's [NOTIFY] mechanism. -To start listening for notifications, you first have to open a new connection -to the database by calling NewListener. This connection can not be used for -anything other than LISTEN / NOTIFY. Calling Listen will open a "notification +To start listening for notifications, you first have to open a new connection to +the database by calling [NewListener]. This connection can not be used for +anything other than LISTEN / NOTIFY. Calling Listen will open a "notification channel"; once a notification channel is open, a notification generated on that -channel will effect a send on the Listener.Notify channel. A notification +channel will effect a send on the Listener.Notify channel. A notification channel will remain open until Unlisten is called, though connection loss might -result in some notifications being lost. To solve this problem, Listener sends -a nil pointer over the Notify channel any time the connection is re-established -following a connection loss. The application can get information about the -state of the underlying connection by setting an event callback in the call to +result in some notifications being lost. To solve this problem, Listener sends a +nil pointer over the Notify channel any time the connection is re-established +following a connection loss. The application can get information about the state +of the underlying connection by setting an event callback in the call to NewListener. -A single Listener can safely be used from concurrent goroutines, which means +A single [Listener] can safely be used from concurrent goroutines, which means that there is often no need to create more than one Listener in your -application. However, a Listener is always connected to a single database, so +application. However, a Listener is always connected to a single database, so you will need to create a new Listener instance for every database you want to receive notifications in. The channel name in both Listen and Unlisten is case sensitive, and can contain -any characters legal in an identifier (see -http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -for more information). Note that the channel name will be truncated to 63 -bytes by the PostgreSQL server. +any characters legal in an [identifier]. Note that the channel name will be +truncated to 63 bytes by the PostgreSQL server. -You can find a complete, working example of Listener usage at -https://godoc.org/github.com/lib/pq/example/listen. +You can find a complete, working example of Listener usage at [cmd/pqlisten]. # Kerberos Support @@ -244,15 +163,11 @@ package: pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) } -This package is in a separate module so that users who don't need Kerberos -don't have to download unnecessary dependencies. - -When imported, additional connection string parameters are supported: +This package is in a separate module so that users who don't need Kerberos don't +have to add unnecessary dependencies. - - krbsrvname - GSS (Kerberos) service name when constructing the - SPN (default is `postgres`). This will be combined with the host - to form the full SPN: `krbsrvname/host`. - - krbspn - GSS (Kerberos) SPN. This takes priority over - `krbsrvname` if present. +[cmd/pqlisten]: https://github.com/lib/pq/tree/master/cmd/pqlisten +[identifier]: http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +[NOTIFY]: http://www.postgresql.org/docs/current/static/sql-notify.html */ package pq diff --git a/encode.go b/encode.go index a5d3a5607..e43fc93d6 100644 --- a/encode.go +++ b/encode.go @@ -50,7 +50,7 @@ func encode(x any, pgtypOid oid.Oid) ([]byte, error) { case bool: return strconv.AppendBool(nil, v), nil case time.Time: - return formatTs(v), nil + return formatTS(v), nil default: return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } @@ -117,9 +117,9 @@ func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { } return b, err case oid.T_timestamptz: - return parseTs(ps.currentLocation, string(s)) + return parseTS(ps.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: - return parseTs(nil, string(s)) + return parseTS(nil, string(s)) case oid.T_time: return parseTime("15:04:05", typ, s) case oid.T_timetz: @@ -161,7 +161,7 @@ func appendEncodedText(buf []byte, x any) ([]byte, error) { case bool: return strconv.AppendBool(buf, v), nil case time.Time: - return append(buf, formatTs(v)...), nil + return append(buf, formatTS(v)...), nil case nil: return append(buf, "\\N"...), nil default: @@ -314,14 +314,14 @@ func (c *locationCache) getLocation(offset int) *time.Location { } var ( - infinityTsEnabled = false - infinityTsNegative time.Time - infinityTsPositive time.Time + infinityTSEnabled = false + infinityTSNegative time.Time + infinityTSPositive time.Time ) const ( - infinityTsEnabledAlready = "pq: infinity timestamp enabled already" - infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" + infinityTSEnabledAlready = "pq: infinity timestamp enabled already" + infinityTSNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" ) // EnableInfinityTs controls the handling of Postgres' "-infinity" and @@ -345,36 +345,36 @@ const ( // undefined behavior. If EnableInfinityTs is called more than once, it will // panic. func EnableInfinityTs(negative time.Time, positive time.Time) { - if infinityTsEnabled { - panic(infinityTsEnabledAlready) + if infinityTSEnabled { + panic(infinityTSEnabledAlready) } if !negative.Before(positive) { - panic(infinityTsNegativeMustBeSmaller) + panic(infinityTSNegativeMustBeSmaller) } - infinityTsEnabled = true - infinityTsNegative = negative - infinityTsPositive = positive + infinityTSEnabled = true + infinityTSNegative = negative + infinityTSPositive = positive } -// Testing might want to toggle infinityTsEnabled -func disableInfinityTs() { - infinityTsEnabled = false +// Testing might want to toggle infinityTSEnabled +func disableInfinityTS() { + infinityTSEnabled = false } // This is a time function specific to the Postgres default DateStyle // setting ("ISO, MDY"), the only one we currently support. This // accounts for the discrepancies between the parsing available with // time.Parse and the Postgres date formatting quirks. -func parseTs(currentLocation *time.Location, str string) (any, error) { +func parseTS(currentLocation *time.Location, str string) (any, error) { switch str { case "-infinity": - if infinityTsEnabled { - return infinityTsNegative, nil + if infinityTSEnabled { + return infinityTSNegative, nil } return []byte(str), nil case "infinity": - if infinityTsEnabled { - return infinityTsPositive, nil + if infinityTSEnabled { + return infinityTSPositive, nil } return []byte(str), nil } @@ -498,15 +498,15 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro return t, p.err } -// formatTs formats t into a format postgres understands. -func formatTs(t time.Time) []byte { - if infinityTsEnabled { +// formatTS formats t into a format postgres understands. +func formatTS(t time.Time) []byte { + if infinityTSEnabled { // t <= -infinity : ! (t > -infinity) - if !t.After(infinityTsNegative) { + if !t.After(infinityTSNegative) { return []byte("-infinity") } // t >= infinity : ! (!t < infinity) - if !t.Before(infinityTsPositive) { + if !t.Before(infinityTSPositive) { return []byte("infinity") } } diff --git a/encode_test.go b/encode_test.go index ec84ea125..5ad2d429c 100644 --- a/encode_test.go +++ b/encode_test.go @@ -173,7 +173,7 @@ var formatTimeTests = []struct { func TestFormatTs(t *testing.T) { for i, tt := range formatTimeTests { - val := string(formatTs(tt.time)) + val := string(formatTS(tt.time)) if val != tt.expected { t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) } @@ -498,7 +498,7 @@ func TestInfinityTimestamp(t *testing.T) { t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s) } - disableInfinityTs() + disableInfinityTS() var panicErrorString string func() { @@ -507,8 +507,8 @@ func TestInfinityTimestamp(t *testing.T) { }() EnableInfinityTs(y2500, y1500) }() - if panicErrorString != infinityTsNegativeMustBeSmaller { - t.Errorf("Expected error, %q, got %q", infinityTsNegativeMustBeSmaller, panicErrorString) + if panicErrorString != infinityTSNegativeMustBeSmaller { + t.Errorf("Expected error, %q, got %q", infinityTSNegativeMustBeSmaller, panicErrorString) } } diff --git a/error.go b/error.go index c8ad0e0a9..234d39e2c 100644 --- a/error.go +++ b/error.go @@ -11,7 +11,7 @@ import ( "unicode/utf8" ) -// Error severities +// [pq.Error.Severity] values. const ( Efatal = "FATAL" Epanic = "PANIC" diff --git a/example_test.go b/example_test.go index 0356798a5..2a8253f46 100644 --- a/example_test.go +++ b/example_test.go @@ -12,7 +12,7 @@ import ( ) func ExampleNewConnector() { - c, err := pq.NewConnector("postgres://") + c, err := pq.NewConnector("host=postgres dbname=pqgo") if err != nil { log.Fatalf("could not create connector: %v", err) } @@ -29,9 +29,36 @@ func ExampleNewConnector() { // Output: } +func ExampleNewConfig() { + cfg, err := pq.NewConfig("host=postgres dbname=pqgo") + if err != nil { + log.Fatal(err) + } + if cfg.Host == "localhost" { + cfg.Host = "127.0.0.1" + } + + c, err := pq.NewConnectorConfig(cfg) + if err != nil { + log.Fatal(err) + } + + db := sql.OpenDB(c) + defer db.Close() + + // Use the DB + tx, err := db.Begin() + if err != nil { + log.Fatalf("could not start transaction: %v", err) + } + tx.Rollback() + // Output: +} + func ExampleConnectorWithNoticeHandler() { // Base connector to wrap - base, err := pq.NewConnector("postgres://") + dsn := "" + base, err := pq.NewConnector(dsn) if err != nil { log.Fatal(err) } diff --git a/internal/pqtest/fake.go b/internal/pqtest/fake.go index 8692a1201..eba56afba 100644 --- a/internal/pqtest/fake.go +++ b/internal/pqtest/fake.go @@ -66,7 +66,7 @@ func (f Fake) Startup(cn net.Conn) { f.WriteMsg(cn, proto.ReadyForQuery, 'I') } -// ReadStart reads the startup message. +// ReadStartup reads the startup message. func (f Fake) ReadStartup(cn net.Conn) bool { _, _, ok := f.read(cn, true) return ok diff --git a/internal/pqutil/path.go b/internal/pqutil/path.go index 0617bbe41..e6827a96f 100644 --- a/internal/pqutil/path.go +++ b/internal/pqutil/path.go @@ -8,7 +8,8 @@ import ( "runtime" ) -// Matches pqGetHomeDirectory() from PostgreSQL +// Home gets the user's home directory. Matches pqGetHomeDirectory() from +// PostgreSQL // // https://github.com/postgres/postgres/blob/2b117bb/src/interfaces/libpq/fe-connect.c#L8214 func Home() string { diff --git a/internal/pqutil/perm.go b/internal/pqutil/perm.go index 7022ef906..fdfa94a07 100644 --- a/internal/pqutil/perm.go +++ b/internal/pqutil/perm.go @@ -13,9 +13,9 @@ var ( ErrSSLKeyHasWorldPermissions = errors.New("pq: private key has world access; permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") ) -// KeyPermissions checks the permissions on user-supplied SSL key files, which -// should have very little access. libpq does not check key file permissions on -// Windows. +// SSLKeyPermissions checks the permissions on user-supplied SSL key files, +// which should have very little access. libpq does not check key file +// permissions on Windows. // // If the file is owned by the same user the process is running as, the file // should only have 0600. If the file is owned by root, and the group matches diff --git a/internal/pqutil/perm_test.go b/internal/pqutil/perm_test.go index e607d83a0..1c6024a2f 100644 --- a/internal/pqutil/perm_test.go +++ b/internal/pqutil/perm_test.go @@ -11,14 +11,14 @@ import ( "github.com/lib/pq/internal/pqtest" ) -type stat_t_wrapper struct{ stat syscall.Stat_t } +type statWrapper struct{ stat syscall.Stat_t } -func (stat_t *stat_t_wrapper) Name() string { return "pem.key" } -func (stat_t *stat_t_wrapper) Size() int64 { return int64(100) } -func (stat_t *stat_t_wrapper) Mode() os.FileMode { return os.FileMode(stat_t.stat.Mode) } -func (stat_t *stat_t_wrapper) ModTime() time.Time { return time.Now() } -func (stat_t *stat_t_wrapper) IsDir() bool { return true } -func (stat_t *stat_t_wrapper) Sys() any { return &stat_t.stat } +func (stat_t *statWrapper) Name() string { return "pem.key" } +func (stat_t *statWrapper) Size() int64 { return int64(100) } +func (stat_t *statWrapper) Mode() os.FileMode { return os.FileMode(stat_t.stat.Mode) } +func (stat_t *statWrapper) ModTime() time.Time { return time.Now() } +func (stat_t *statWrapper) IsDir() bool { return true } +func (stat_t *statWrapper) Sys() any { return &stat_t.stat } func TestSSLKeyPermissions(t *testing.T) { currentUID := uint32(os.Getuid()) @@ -36,7 +36,7 @@ func TestSSLKeyPermissions(t *testing.T) { for _, tt := range tests { t.Run("", func(t *testing.T) { - have := checkPermissions(&stat_t_wrapper{stat: tt.stat}) + have := checkPermissions(&statWrapper{stat: tt.stat}) if !pqtest.ErrorContains(have, tt.wantErr) { t.Errorf("\nhave: %s\nwant: %s", have, tt.wantErr) } diff --git a/internal/pqutil/pqutil.go b/internal/pqutil/pqutil.go index f83b71a4c..f14954dda 100644 --- a/internal/pqutil/pqutil.go +++ b/internal/pqutil/pqutil.go @@ -16,7 +16,8 @@ func ParseBool(str string) (bool, error) { return false, &strconv.NumError{Func: "ParseBool", Num: str, Err: strconv.ErrSyntax} } -// We want to retain compat with Go 1.18, and slices wasn't added until 1.21 +// Contains is [slices.Contains]. We want to retain compat with Go 1.18, and +// slices wasn't added until 1.21 func Contains[S ~[]E, E comparable](s S, v E) bool { for i := range s { if v == s[i] { diff --git a/ssl.go b/ssl.go index 128aac4c3..ed0ec4f31 100644 --- a/ssl.go +++ b/ssl.go @@ -41,12 +41,19 @@ func RegisterTLSConfig(key string, config *tls.Config) error { return nil } +func hasTLSConfig(key string) bool { + tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() + _, ok := tlsConfs[key] + return ok +} + func getTLSConfigClone(key string) *tls.Config { tlsConfsMu.RLock() + defer tlsConfsMu.RUnlock() if v, ok := tlsConfs[key]; ok { return v.Clone() } - tlsConfsMu.RUnlock() return nil } @@ -101,9 +108,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { } // Set Server Name Indication (SNI), if enabled by connection parameters. - // By default SNI is on, any value which is not starting with "1" disables - // SNI -- that is the same check vanilla libpq uses. - if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { + if o["sslsni"] == "yes" { // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 // or IPv6). This check is coded already crypto.tls.hostnameInSNI, so // just always set ServerName here and let crypto/tls do the filtering. @@ -152,8 +157,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) error { - sslinline := o["sslinline"] - if sslinline == "true" { + if o["sslinline"] == "yes" { cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) if err != nil { return err @@ -225,10 +229,8 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) error { if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() - sslinline := o["sslinline"] - var cert []byte - if sslinline == "true" { + if o["sslinline"] == "yes" { cert = []byte(sslrootcert) } else { var err error diff --git a/staticcheck.conf b/staticcheck.conf new file mode 100644 index 000000000..83abe48e5 --- /dev/null +++ b/staticcheck.conf @@ -0,0 +1,5 @@ +checks = [ + 'all', + '-ST1000', # "Must have at least one package comment" + '-ST1003', # "func EnableInfinityTs should be EnableInfinityTS" +]