-
Notifications
You must be signed in to change notification settings - Fork 20
/
options.go
263 lines (226 loc) · 8.49 KB
/
options.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
package wire
import (
"context"
"crypto/tls"
"log/slog"
"regexp"
"strconv"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jeroenrinzema/psql-wire/pkg/buffer"
"github.com/lib/pq/oid"
)
type (
	// ParseFn parses the given query into the prepared statements which may be
	// executed at a later point in time.
	ParseFn func(ctx context.Context, query string) (PreparedStatements, error)

	// PreparedStatementFn executes a statement which has already been
	// prepared. It can be invoked at any later point in time with the given
	// parameters and DataWriter.
	PreparedStatementFn func(ctx context.Context, writer DataWriter, parameters []Parameter) error
)
// Prepared bundles the given prepared statements into a PreparedStatements
// list. The simple query protocol may produce more than one prepared statement
// for a single query. Within the [extended query protocol], returning more than
// one prepared statement results in an error.
//
// [extended query protocol]: https://www.postgresql.org/docs/15/protocol-flow.html#PROTOCOL-FLOW-MULTI-STATEMENT
func Prepared(stmts ...*PreparedStatement) PreparedStatements {
	return PreparedStatements(stmts)
}
// NewStatement constructs a prepared statement executing fn and applies each
// of the given options to it, in the order they were passed.
func NewStatement(fn PreparedStatementFn, options ...PreparedOptionFn) *PreparedStatement {
	stmt := new(PreparedStatement)
	stmt.fn = fn

	for i := range options {
		options[i](stmt)
	}

	return stmt
}
// PreparedOptionFn configures a PreparedStatement while it is being
// constructed by NewStatement (functional options pattern).
type PreparedOptionFn func(stmt *PreparedStatement)
// WithColumns returns an option declaring columns as the columns returned by
// the prepared statement.
func WithColumns(columns Columns) PreparedOptionFn {
	apply := func(stmt *PreparedStatement) {
		stmt.columns = columns
	}
	return apply
}
// WithParameters returns an option declaring parameters as the parameter
// types expected by the prepared statement.
func WithParameters(parameters []oid.Oid) PreparedOptionFn {
	apply := func(stmt *PreparedStatement) {
		stmt.parameters = parameters
	}
	return apply
}
type (
	// PreparedStatements is an ordered list of prepared statements, as
	// returned by a ParseFn.
	PreparedStatements []*PreparedStatement

	// PreparedStatement holds the function executing a prepared query
	// together with the parameter types it expects and the columns it returns.
	PreparedStatement struct {
		fn         PreparedStatementFn
		parameters []oid.Oid
		columns    Columns
	}
)
// SessionHandler wraps the state of a single session. It receives the shared
// context and returns a (possibly derived) context carrying any additional
// metadata, or an error.
type SessionHandler func(context.Context) (context.Context, error)
// StatementCache stores prepared statements bound to a name so that they can
// be retrieved later on.
type StatementCache interface {
	// Set binds statement to name. Any statement previously bound to the
	// same name is overridden.
	Set(ctx context.Context, name string, statement *PreparedStatement) error
	// Get returns the statement bound to name. An error is returned when no
	// statement has been bound to that name.
	Get(ctx context.Context, name string) (*Statement, error)
}
// PortalCache stores prepared statements bound to parameters under a name so
// that they can later be looked up and executed.
type PortalCache interface {
	// Bind binds statement, together with its parameters and the column
	// format codes, to name.
	Bind(ctx context.Context, name string, statement *Statement, parameters []Parameter, columns []FormatCode) error
	// Get returns the portal bound to name.
	Get(ctx context.Context, name string) (*Portal, error)
	// Execute executes the portal bound to name using the given buffer
	// reader and writer.
	Execute(ctx context.Context, name string, reader *buffer.Reader, writer *buffer.Writer) error
}
type (
	// CloseFn is the connection handler type accepted by CloseConn and
	// TerminateConn.
	CloseFn func(ctx context.Context) error

	// OptionFn configures the given PostgreSQL server, following the
	// functional options pattern.
	OptionFn func(*Server) error
)
// Statements sets the constructor of the statement cache in which prepared
// statements are kept for later use. By default [DefaultStatementCache] is
// used.
func Statements(handler func() StatementCache) OptionFn {
	option := func(s *Server) error {
		s.Statements = handler
		return nil
	}
	return option
}
// Portals sets the constructor of the portal cache in which statements bound
// to parameters are kept for later execution. By default
// [DefaultPortalCache] is used.
func Portals(handler func() PortalCache) OptionFn {
	option := func(s *Server) error {
		s.Portals = handler
		return nil
	}
	return option
}
// CloseConn registers fn as the close-connection handler of the server.
func CloseConn(fn CloseFn) OptionFn {
	option := func(s *Server) error {
		s.CloseConn = fn
		return nil
	}
	return option
}
// TerminateConn registers fn as the terminate-connection handler of the
// server.
func TerminateConn(fn CloseFn) OptionFn {
	option := func(s *Server) error {
		s.TerminateConn = fn
		return nil
	}
	return option
}
// MessageBufferSize sets the size of the message buffer which is allocated
// once a new connection gets constructed. A zero or negative size selects the
// default message buffer size.
//
// NOTE(review): the size is stored unvalidated here; the default fallback must
// be applied wherever BufferedMsgSize is consumed — confirm.
func MessageBufferSize(size int) OptionFn {
	option := func(s *Server) error {
		s.BufferedMsgSize = size
		return nil
	}
	return option
}
// TLSConfig sets the TLS configuration used to establish a secure connection
// between the front-end (client) and the back-end (server).
func TLSConfig(config *tls.Config) OptionFn {
	option := func(s *Server) error {
		s.TLSConfig = config
		return nil
	}
	return option
}
// SessionAuthStrategy installs fn as the authentication strategy of the
// server. The strategy is called when a handshake is initiated.
func SessionAuthStrategy(fn AuthStrategy) OptionFn {
	option := func(s *Server) error {
		s.Auth = fn
		return nil
	}
	return option
}
// GlobalParameters sets the server parameters which are sent back to the
// front-end (client) once a handshake has been established.
func GlobalParameters(params Parameters) OptionFn {
	option := func(s *Server) error {
		s.Parameters = params
		return nil
	}
	return option
}
// Logger sets the given [slog.Logger] as the logger used by the server.
func Logger(logger *slog.Logger) OptionFn {
	option := func(s *Server) error {
		s.logger = logger
		return nil
	}
	return option
}
// Version sets the PostgreSQL version reported by the server, which is sent
// back to the front-end (client) once a handshake has been established.
func Version(version string) OptionFn {
	option := func(s *Server) error {
		s.Version = version
		return nil
	}
	return option
}
// ExtendTypes provides the ability to extend the underlying connection types.
// fn receives the server's [github.com/jackc/pgx/v5/pgtype.Map]; types
// registered inside it are registered to all incoming connections.
func ExtendTypes(fn func(*pgtype.Map)) OptionFn {
	return func(s *Server) error {
		types := s.types
		fn(types)
		return nil
	}
}
// SessionMiddleware registers fn as a session handler within the underlying
// server. Session handlers are called when a new connection is opened and
// authenticated, allowing additional metadata to be wrapped around the
// connection context.
//
// When a handler is already registered, fn is chained after it. The existing
// handler runs first, and fn only runs, with the context the existing handler
// produced, if that handler returned no error.
func SessionMiddleware(fn SessionHandler) OptionFn {
	return func(srv *Server) error {
		previous := srv.Session
		if previous == nil {
			srv.Session = fn
			return nil
		}

		srv.Session = func(ctx context.Context) (context.Context, error) {
			next, err := previous(ctx)
			if err != nil {
				// Short-circuit: later middleware never sees a failed session.
				return next, err
			}
			return fn(next)
		}
		return nil
	}
}
var (
	// QueryParameters is a regex identifying the parameters defined inside a
	// query. It matches [positional parameters] such as $1, whose position is
	// captured by the first submatch. It also matches anonymous ? parameters,
	// for which that submatch is empty.
	//
	// [positional parameters]: https://www.postgresql.org/docs/15/sql-expressions.html#SQL-EXPRESSIONS-PARAMETERS-POSITIONAL
	QueryParameters = regexp.MustCompile(`\$(\d+)|\?`)
)
// ParseParameters attempts to parse the parameters in the given query and
// returns the expected parameter types. This is necessary for the extended
// query protocol, where the types of the parameters have to be described.
//
// Each parameter is reported with a zero oid, indicating that it may contain
// a value of any type. Anonymous parameters (?) each occupy the next slot.
// A positional parameter ($n) ensures that at least n slots exist, so a query
// only referencing $3 yields three parameters. Positions of $0, positions that
// overflow an int, and positions above 65535 are ignored.
func ParseParameters(query string) []oid.Oid {
	// maxPositionalParameter bounds the number of slots a single $n reference
	// may allocate. The PostgreSQL wire protocol transmits parameter counts as
	// 16-bit integers, so higher positions can never be bound. Without this
	// bound, a client query such as "SELECT $2000000000" would force a
	// multi-gigabyte allocation.
	const maxPositionalParameter = 65535

	// NOTE(review): the regex is not aware of SQL quoting, so $n or ? inside
	// string literals or comments are counted as parameters as well.
	matches := QueryParameters.FindAllStringSubmatch(query, -1)
	parameters := make([]oid.Oid, 0, len(matches))

	for _, match := range matches {
		// Anonymous parameter, e.g. SELECT * FROM users WHERE id = ?
		if match[1] == "" {
			parameters = append(parameters, 0)
			continue
		}

		// Positional parameter, e.g. SELECT * FROM users WHERE id = $1.
		// match[1] only consists of ASCII digits, so Atoi can only fail when
		// the number overflows an int; such positions are skipped.
		position, err := strconv.Atoi(match[1])
		if err != nil || position > maxPositionalParameter {
			continue
		}

		if position > len(parameters) {
			// Grow with zero oids up to the referenced position. Reslicing
			// (parameters[:position]) is not sufficient: the capacity only
			// equals the number of matches, so a query such as "SELECT $2"
			// would slice beyond capacity and panic.
			parameters = append(parameters, make([]oid.Oid, position-len(parameters))...)
		}
	}

	return parameters
}