@@ -4,8 +4,10 @@ import (
4
4
"bytes"
5
5
"encoding/hex"
6
6
"fmt"
7
+ "slices"
7
8
"strconv"
8
9
"strings"
10
+ "sync"
9
11
"time"
10
12
"unicode/utf8"
11
13
)
@@ -24,53 +26,75 @@ type Query struct {
24
26
// https://github.com/jackc/pgx/issues/1380
25
27
const replacementcharacterwidth = 3
26
28
29
+ const maxBufSize = 16384 // 16 Ki
30
+
31
+ var bufPool = & pool [* bytes.Buffer ]{
32
+ new : func () * bytes.Buffer {
33
+ return & bytes.Buffer {}
34
+ },
35
+ reset : func (b * bytes.Buffer ) bool {
36
+ n := b .Len ()
37
+ b .Reset ()
38
+ return n < maxBufSize
39
+ },
40
+ }
41
+
42
+ var null = []byte ("null" )
43
+
27
44
func (q * Query ) Sanitize (args ... any ) (string , error ) {
28
45
argUse := make ([]bool , len (args ))
29
- buf := & bytes.Buffer {}
46
+ buf := bufPool .get ()
47
+ defer bufPool .put (buf )
30
48
31
49
for _ , part := range q .Parts {
32
- var str string
33
50
switch part := part .(type ) {
34
51
case string :
35
- str = part
52
+ buf . WriteString ( part )
36
53
case int :
37
54
argIdx := part - 1
38
-
55
+ var p [] byte
39
56
if argIdx < 0 {
40
57
return "" , fmt .Errorf ("first sql argument must be > 0" )
41
58
}
42
59
43
60
if argIdx >= len (args ) {
44
61
return "" , fmt .Errorf ("insufficient arguments" )
45
62
}
63
+
64
+ // Prevent SQL injection via Line Comment Creation
65
+ // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
66
+ buf .WriteByte (' ' )
67
+
46
68
arg := args [argIdx ]
47
69
switch arg := arg .(type ) {
48
70
case nil :
49
- str = " null"
71
+ p = null
50
72
case int64 :
51
- str = strconv .FormatInt ( arg , 10 )
73
+ p = strconv .AppendInt ( buf . AvailableBuffer (), arg , 10 )
52
74
case float64 :
53
- str = strconv .FormatFloat ( arg , 'f' , - 1 , 64 )
75
+ p = strconv .AppendFloat ( buf . AvailableBuffer (), arg , 'f' , - 1 , 64 )
54
76
case bool :
55
- str = strconv .FormatBool ( arg )
77
+ p = strconv .AppendBool ( buf . AvailableBuffer (), arg )
56
78
case []byte :
57
- str = QuoteBytes (arg )
79
+ p = QuoteBytes (buf . AvailableBuffer (), arg )
58
80
case string :
59
- str = QuoteString (arg )
81
+ p = QuoteString (buf . AvailableBuffer (), arg )
60
82
case time.Time :
61
- str = arg .Truncate (time .Microsecond ).Format ("'2006-01-02 15:04:05.999999999Z07:00:00'" )
83
+ p = arg .Truncate (time .Microsecond ).
84
+ AppendFormat (buf .AvailableBuffer (), "'2006-01-02 15:04:05.999999999Z07:00:00'" )
62
85
default :
63
86
return "" , fmt .Errorf ("invalid arg type: %T" , arg )
64
87
}
65
88
argUse [argIdx ] = true
66
89
90
+ buf .Write (p )
91
+
67
92
// Prevent SQL injection via Line Comment Creation
68
93
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
69
- str = " " + str + " "
94
+ buf . WriteByte ( ' ' )
70
95
default :
71
96
return "" , fmt .Errorf ("invalid Part type: %T" , part )
72
97
}
73
- buf .WriteString (str )
74
98
}
75
99
76
100
for i , used := range argUse {
@@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
82
106
}
83
107
84
108
func NewQuery (sql string ) (* Query , error ) {
85
- l := & sqlLexer {
86
- src : sql ,
87
- stateFn : rawState ,
109
+ query := & Query {}
110
+ query .init (sql )
111
+
112
+ return query , nil
113
+ }
114
+
115
+ var sqlLexerPool = & pool [* sqlLexer ]{
116
+ new : func () * sqlLexer {
117
+ return & sqlLexer {}
118
+ },
119
+ reset : func (sl * sqlLexer ) bool {
120
+ * sl = sqlLexer {}
121
+ return true
122
+ },
123
+ }
124
+
125
+ func (q * Query ) init (sql string ) {
126
+ parts := q .Parts [:0 ]
127
+ if parts == nil {
128
+ // dirty, but fast heuristic to preallocate for ~90% usecases
129
+ n := strings .Count (sql , "$" ) + strings .Count (sql , "--" ) + 1
130
+ parts = make ([]Part , 0 , n )
88
131
}
89
132
133
+ l := sqlLexerPool .get ()
134
+ defer sqlLexerPool .put (l )
135
+
136
+ l .src = sql
137
+ l .stateFn = rawState
138
+ l .parts = parts
139
+
90
140
for l .stateFn != nil {
91
141
l .stateFn = l .stateFn (l )
92
142
}
93
143
94
- query := & Query {Parts : l .parts }
95
-
96
- return query , nil
144
+ q .Parts = l .parts
97
145
}
98
146
99
- func QuoteString (str string ) string {
100
- return "'" + strings .ReplaceAll (str , "'" , "''" ) + "'"
147
+ func QuoteString (dst []byte , str string ) []byte {
148
+ const quote = '\''
149
+
150
+ // Preallocate space for the worst case scenario
151
+ dst = slices .Grow (dst , len (str )* 2 + 2 )
152
+
153
+ // Add opening quote
154
+ dst = append (dst , quote )
155
+
156
+ // Iterate through the string without allocating
157
+ for i := 0 ; i < len (str ); i ++ {
158
+ if str [i ] == quote {
159
+ dst = append (dst , quote , quote )
160
+ } else {
161
+ dst = append (dst , str [i ])
162
+ }
163
+ }
164
+
165
+ // Add closing quote
166
+ dst = append (dst , quote )
167
+
168
+ return dst
101
169
}
102
170
103
- func QuoteBytes (buf []byte ) string {
104
- return `'\x` + hex .EncodeToString (buf ) + "'"
171
+ func QuoteBytes (dst , buf []byte ) []byte {
172
+ if len (buf ) == 0 {
173
+ return append (dst , `'\x'` ... )
174
+ }
175
+
176
+ // Calculate required length
177
+ requiredLen := 3 + hex .EncodedLen (len (buf )) + 1
178
+
179
+ // Ensure dst has enough capacity
180
+ if cap (dst )- len (dst ) < requiredLen {
181
+ newDst := make ([]byte , len (dst ), len (dst )+ requiredLen )
182
+ copy (newDst , dst )
183
+ dst = newDst
184
+ }
185
+
186
+ // Record original length and extend slice
187
+ origLen := len (dst )
188
+ dst = dst [:origLen + requiredLen ]
189
+
190
+ // Add prefix
191
+ dst [origLen ] = '\''
192
+ dst [origLen + 1 ] = '\\'
193
+ dst [origLen + 2 ] = 'x'
194
+
195
+ // Encode bytes directly into dst
196
+ hex .Encode (dst [origLen + 3 :len (dst )- 1 ], buf )
197
+
198
+ // Add suffix
199
+ dst [len (dst )- 1 ] = '\''
200
+
201
+ return dst
105
202
}
106
203
107
204
type sqlLexer struct {
@@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
319
416
}
320
417
}
321
418
419
+ var queryPool = & pool [* Query ]{
420
+ new : func () * Query {
421
+ return & Query {}
422
+ },
423
+ reset : func (q * Query ) bool {
424
+ n := len (q .Parts )
425
+ q .Parts = q .Parts [:0 ]
426
+ return n < 64 // drop too large queries
427
+ },
428
+ }
429
+
322
430
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
323
431
// as necessary. This function is only safe when standard_conforming_strings is
324
432
// on.
325
433
func SanitizeSQL (sql string , args ... any ) (string , error ) {
326
- query , err := NewQuery ( sql )
327
- if err != nil {
328
- return "" , err
329
- }
434
+ query := queryPool . get ( )
435
+ query . init ( sql )
436
+ defer queryPool . put ( query )
437
+
330
438
return query .Sanitize (args ... )
331
439
}
440
+
441
+ type pool [E any ] struct {
442
+ p sync.Pool
443
+ new func () E
444
+ reset func (E ) bool
445
+ }
446
+
447
+ func (pool * pool [E ]) get () E {
448
+ v , ok := pool .p .Get ().(E )
449
+ if ! ok {
450
+ v = pool .new ()
451
+ }
452
+
453
+ return v
454
+ }
455
+
456
+ func (p * pool [E ]) put (v E ) {
457
+ if p .reset (v ) {
458
+ p .p .Put (v )
459
+ }
460
+ }
0 commit comments