Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/pgmock/pgmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ func WaitForClose() Step {

func AcceptUnauthenticatedConnRequestSteps() []Step {
return []Step{
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}),
SendMessage(&pgproto3.AuthenticationOk{}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
}
}
47 changes: 47 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler

// MinProtocolVersion is the minimum acceptable PostgreSQL protocol version.
// If the server does not support at least this version, the connection will fail.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0".
MinProtocolVersion string

// MaxProtocolVersion is the maximum PostgreSQL protocol version to request from the server.
// Valid values: "3.0", "3.2", "latest". Defaults to "3.0" for compatibility.
MaxProtocolVersion string

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -213,6 +222,8 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
// PGTZ
// PGMINPROTOCOLVERSION
// PGMAXPROTOCOLVERSION
//
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
//
Expand Down Expand Up @@ -338,6 +349,8 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"target_session_attrs": {},
"service": {},
"servicefile": {},
"min_protocol_version": {},
"max_protocol_version": {},
}

// Adding kerberos configuration
Expand Down Expand Up @@ -430,6 +443,27 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}

minProto, err := parseProtocolVersion(settings["min_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid min_protocol_version", err: err}
}
maxProto, err := parseProtocolVersion(settings["max_protocol_version"])
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid max_protocol_version", err: err}
}
if minProto > maxProto {
return nil, &ParseConfigError{ConnString: connString, msg: "min_protocol_version cannot be greater than max_protocol_version"}
}

config.MinProtocolVersion = settings["min_protocol_version"]
config.MaxProtocolVersion = settings["max_protocol_version"]
if config.MinProtocolVersion == "" {
config.MinProtocolVersion = "3.0"
}
if config.MaxProtocolVersion == "" {
config.MaxProtocolVersion = "3.0"
}

return config, nil
}

Expand Down Expand Up @@ -467,6 +501,8 @@ func parseEnvSettings() map[string]string {
"PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
"PGMINPROTOCOLVERSION": "min_protocol_version",
"PGMAXPROTOCOLVERSION": "max_protocol_version",
}

for envname, realname := range nameMap {
Expand Down Expand Up @@ -960,3 +996,14 @@ func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn

return nil
}

func parseProtocolVersion(s string) (uint32, error) {
switch s {
case "", "3.0":
return pgproto3.ProtocolVersion30, nil
case "3.2", "latest":
return pgproto3.ProtocolVersion32, nil
default:
return 0, fmt.Errorf("invalid protocol version: %q", s)
}
}
116 changes: 116 additions & 0 deletions pgconn/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1198,3 +1198,119 @@ func TestParseConfigExplicitEmptyUserDefaultsToOSUser(t *testing.T) {
)
}
}

func TestParseConfigProtocolVersion(t *testing.T) {
tests := []struct {
name string
connString string
envMin string
envMax string
expectedMin string
expectedMax string
expectError bool
expectedErrContain string
}{
{
name: "defaults to 3.0",
connString: "postgres://localhost/test",
expectedMin: "3.0",
expectedMax: "3.0",
},
{
name: "max_protocol_version=3.2",
connString: "postgres://localhost/test?max_protocol_version=3.2",
expectedMin: "3.0",
expectedMax: "3.2",
},
{
name: "min_protocol_version=3.2 and max_protocol_version=3.2",
connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.2",
expectedMin: "3.2",
expectedMax: "3.2",
},
{
name: "max_protocol_version=latest",
connString: "postgres://localhost/test?max_protocol_version=latest",
expectedMin: "3.0",
expectedMax: "latest",
},
{
name: "min and max = latest",
connString: "postgres://localhost/test?min_protocol_version=latest&max_protocol_version=latest",
expectedMin: "latest",
expectedMax: "latest",
},
{
name: "invalid min_protocol_version",
connString: "postgres://localhost/test?min_protocol_version=2.0",
expectError: true,
expectedErrContain: "invalid min_protocol_version",
},
{
name: "invalid max_protocol_version",
connString: "postgres://localhost/test?max_protocol_version=4.0",
expectError: true,
expectedErrContain: "invalid max_protocol_version",
},
{
name: "min > max",
connString: "postgres://localhost/test?min_protocol_version=3.2&max_protocol_version=3.0",
expectError: true,
expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version",
},
{
name: "environment variable PGMINPROTOCOLVERSION without matching max fails",
connString: "postgres://localhost/test",
envMin: "3.2",
expectError: true,
expectedErrContain: "min_protocol_version cannot be greater than max_protocol_version",
},
{
name: "environment variables PGMINPROTOCOLVERSION and PGMAXPROTOCOLVERSION together",
connString: "postgres://localhost/test",
envMin: "3.2",
envMax: "3.2",
expectedMin: "3.2",
expectedMax: "3.2",
},
{
name: "environment variable PGMAXPROTOCOLVERSION",
connString: "postgres://localhost/test",
envMax: "3.2",
expectedMin: "3.0",
expectedMax: "3.2",
},
{
name: "conn string overrides environment variable",
connString: "postgres://localhost/test?max_protocol_version=3.0",
envMax: "3.2",
expectedMin: "3.0",
expectedMax: "3.0",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear protocol version env vars
t.Setenv("PGMINPROTOCOLVERSION", "")
t.Setenv("PGMAXPROTOCOLVERSION", "")

if tt.envMin != "" {
t.Setenv("PGMINPROTOCOLVERSION", tt.envMin)
}
if tt.envMax != "" {
t.Setenv("PGMAXPROTOCOLVERSION", tt.envMax)
}

config, err := pgconn.ParseConfig(tt.connString)
if tt.expectError {
require.ErrorContains(t, err, tt.expectedErrContain)
return
}

require.NoError(t, err)
assert.Equal(t, tt.expectedMin, config.MinProtocolVersion, "MinProtocolVersion")
assert.Equal(t, tt.expectedMax, config.MaxProtocolVersion, "MaxProtocolVersion")
})
}
}
29 changes: 22 additions & 7 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification)
type PgConn struct {
conn net.Conn
pid uint32 // backend pid
secretKey uint32 // key to use to send a cancel query message to the server
secretKey []byte // key to use to send a cancel query message to the server
parameterStatuses map[string]string // parameters that have been reported by the server
txStatus byte
frontend *pgproto3.Frontend
Expand Down Expand Up @@ -319,6 +319,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
return e
}

maxProtocolVersion, err := parseProtocolVersion(config.MaxProtocolVersion)
if err != nil {
return nil, newPerDialConnectError("invalid max_protocol_version", err)
}
minProtocolVersion, err := parseProtocolVersion(config.MinProtocolVersion)
if err != nil {
return nil, newPerDialConnectError("invalid min_protocol_version", err)
}

pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address)
if err != nil {
return nil, newPerDialConnectError("dial error", err)
Expand Down Expand Up @@ -371,7 +380,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)

startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
ProtocolVersion: maxProtocolVersion,
Parameters: make(map[string]string),
}

Expand Down Expand Up @@ -452,6 +461,12 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
return pgConn, nil
case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
// handled by ReceiveMessage
case *pgproto3.NegotiateProtocolVersion:
serverVersion := pgproto3.ProtocolVersion30&0xFFFF0000 | uint32(msg.NewestMinorProtocol)
if serverVersion < minProtocolVersion {
pgConn.conn.Close()
return nil, newPerDialConnectError("server protocol version too low", nil)
}
case *pgproto3.ErrorResponse:
pgConn.conn.Close()
return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg))
Expand Down Expand Up @@ -641,7 +656,7 @@ func (pgConn *PgConn) TxStatus() byte {
}

// SecretKey returns the backend secret key used to send a cancel query message to the server.
func (pgConn *PgConn) SecretKey() uint32 {
func (pgConn *PgConn) SecretKey() []byte {
return pgConn.secretKey
}

Expand Down Expand Up @@ -1040,11 +1055,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
defer contextWatcher.Unwatch()
}

buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
buf := make([]byte, 12+len(pgConn.secretKey))
binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey)
copy(buf[12:], pgConn.secretKey)

if _, err := cancelConn.Write(buf); err != nil {
return fmt.Errorf("write to connection for cancellation: %w", err)
Expand Down Expand Up @@ -2077,7 +2092,7 @@ func (pgConn *PgConn) CustomData() map[string]any {
type HijackedConn struct {
Conn net.Conn
PID uint32 // backend pid
SecretKey uint32 // key to use to send a cancel query message to the server
SecretKey []byte // key to use to send a cancel query message to the server
ParameterStatuses map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend *pgproto3.Frontend
Expand Down
40 changes: 37 additions & 3 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,10 @@ func TestConnectTimeout(t *testing.T) {
t.Parallel()
script := &pgmock.Script{
Steps: []pgmock.Step{
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersion30, Parameters: map[string]string{}}),
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
pgmockWaitStep(time.Millisecond * 500),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}),
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
},
}
Expand Down Expand Up @@ -4112,7 +4112,7 @@ func TestSNISupport(t *testing.T) {
}

srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: []byte{0, 0, 0, 0}}).Encode(nil)))
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))

serverSNINameChan <- sniHost
Expand Down Expand Up @@ -4553,3 +4553,37 @@ func TestCancelRequestContextWatcherHandler(t *testing.T) {
})
}
}

func TestConnectProtocolVersion32(t *testing.T) {
t.Parallel()
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE") + " max_protocol_version=3.2")
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on the failures in CI, I think this might need to change to parse the config, then set max protocol version directly on the config struct. I think that string concatenation is yielding an invalid connection string.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jackc Pushed!

require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)

if pgConn.ParameterStatus("crdb_version") != "" {
t.Skip("CockroachDB does not support protocol version 3.2 yet")
}

result, err := pgConn.Exec(context.Background(), "show server_version_num").ReadAll()
require.NoError(t, err)
require.Len(t, result, 1)
require.Len(t, result[0].Rows, 1)
require.Len(t, result[0].Rows[0], 1)
pgVersion, err := strconv.Atoi(string(result[0].Rows[0][0]))
require.NoError(t, err)

// Check secret key length - PG18+ returns 32 bytes, older versions return 4
secretKey := pgConn.SecretKey()

if pgVersion < 180000 {
assert.Len(t, secretKey, 4)
} else {
assert.Len(t, secretKey, 32)
}
}
2 changes: 1 addition & 1 deletion pgproto3/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
code := binary.BigEndian.Uint32(buf)

switch code {
case ProtocolVersionNumber:
case ProtocolVersion30, ProtocolVersion32:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
Expand Down
Loading