Skip to content

Commit ca04098

Browse files
authored
Merge pull request #2136 from ninedraft/optimize-sanitize
Reduce SQL sanitizer allocations
2 parents 4ff0a45 + e452f80 commit ca04098

File tree

6 files changed

+387
-28
lines changed

6 files changed

+387
-28
lines changed

conn_test.go

-1
Original file line numberDiff line numberDiff line change
@@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
14171417
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
14181418

14191419
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
1420-
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
14211420
}

internal/sanitize/benchmmark.sh

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/env bash
#
# Runs the Go benchmarks of the package in the current directory at each of the
# given commits and compares the results with benchstat.
#
# Usage: $0 <commit1> <commit2> ... <commitN>
#
# The working tree must be clean. The branch (or detached commit) that was
# checked out when the script started is restored on exit.

current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
    # Detached HEAD: remember the commit hash instead of a branch name.
    current_branch=$(git rev-parse HEAD)
fi

restore_branch() {
    echo "Restoring original branch/commit: $current_branch"
    git checkout "$current_branch"
}
trap restore_branch EXIT

# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
    echo "There are uncommitted changes. Please commit or stash them before running this script."
    exit 1
fi

# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
    echo "Usage: $0 <commit1> <commit2> ... <commitN>"
    exit 1
fi

# Fail now rather than after every benchmark has already been run.
if ! command -v benchstat >/dev/null 2>&1; then
    echo "benchstat not found. Install it with: go install golang.org/x/perf/cmd/benchstat@latest"
    exit 1
fi

commits=("$@")
benchmarks_dir=benchmarks

if ! mkdir -p "${benchmarks_dir}"; then
    echo "Unable to create dir for benchmarks data"
    exit 1
fi

# Benchmark results
bench_files=()

# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
    commit="${commits[i]}"
    git checkout "$commit" || {
        echo "Failed to checkout $commit"
        exit 1
    }

    # Sanitized commit message, used as part of the result file name.
    # '-' is placed last in the set so it can never be read as a range.
    commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]_-' '_')

    # Benchmark data will go there
    bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"

    # -run='^$' matches no test names, so only the benchmarks are executed.
    if ! go test -run='^$' -bench=. -count=10 >"$bench_file"; then
        echo "Benchmarking failed for commit $commit"
        exit 1
    fi

    bench_files+=("$bench_file")
done

benchstat "${bench_files[@]}"

internal/sanitize/sanitize.go

+156-27
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"bytes"
55
"encoding/hex"
66
"fmt"
7+
"slices"
78
"strconv"
89
"strings"
10+
"sync"
911
"time"
1012
"unicode/utf8"
1113
)
@@ -24,53 +26,75 @@ type Query struct {
2426
// https://github.com/jackc/pgx/issues/1380
2527
const replacementcharacterwidth = 3
2628

29+
const maxBufSize = 16384 // 16 Ki
30+
31+
var bufPool = &pool[*bytes.Buffer]{
32+
new: func() *bytes.Buffer {
33+
return &bytes.Buffer{}
34+
},
35+
reset: func(b *bytes.Buffer) bool {
36+
n := b.Len()
37+
b.Reset()
38+
return n < maxBufSize
39+
},
40+
}
41+
42+
var null = []byte("null")
43+
2744
func (q *Query) Sanitize(args ...any) (string, error) {
2845
argUse := make([]bool, len(args))
29-
buf := &bytes.Buffer{}
46+
buf := bufPool.get()
47+
defer bufPool.put(buf)
3048

3149
for _, part := range q.Parts {
32-
var str string
3350
switch part := part.(type) {
3451
case string:
35-
str = part
52+
buf.WriteString(part)
3653
case int:
3754
argIdx := part - 1
38-
55+
var p []byte
3956
if argIdx < 0 {
4057
return "", fmt.Errorf("first sql argument must be > 0")
4158
}
4259

4360
if argIdx >= len(args) {
4461
return "", fmt.Errorf("insufficient arguments")
4562
}
63+
64+
// Prevent SQL injection via Line Comment Creation
65+
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
66+
buf.WriteByte(' ')
67+
4668
arg := args[argIdx]
4769
switch arg := arg.(type) {
4870
case nil:
49-
str = "null"
71+
p = null
5072
case int64:
51-
str = strconv.FormatInt(arg, 10)
73+
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
5274
case float64:
53-
str = strconv.FormatFloat(arg, 'f', -1, 64)
75+
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
5476
case bool:
55-
str = strconv.FormatBool(arg)
77+
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
5678
case []byte:
57-
str = QuoteBytes(arg)
79+
p = QuoteBytes(buf.AvailableBuffer(), arg)
5880
case string:
59-
str = QuoteString(arg)
81+
p = QuoteString(buf.AvailableBuffer(), arg)
6082
case time.Time:
61-
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
83+
p = arg.Truncate(time.Microsecond).
84+
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
6285
default:
6386
return "", fmt.Errorf("invalid arg type: %T", arg)
6487
}
6588
argUse[argIdx] = true
6689

90+
buf.Write(p)
91+
6792
// Prevent SQL injection via Line Comment Creation
6893
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
69-
str = " " + str + " "
94+
buf.WriteByte(' ')
7095
default:
7196
return "", fmt.Errorf("invalid Part type: %T", part)
7297
}
73-
buf.WriteString(str)
7498
}
7599

76100
for i, used := range argUse {
@@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
82106
}
83107

84108
func NewQuery(sql string) (*Query, error) {
85-
l := &sqlLexer{
86-
src: sql,
87-
stateFn: rawState,
109+
query := &Query{}
110+
query.init(sql)
111+
112+
return query, nil
113+
}
114+
115+
var sqlLexerPool = &pool[*sqlLexer]{
116+
new: func() *sqlLexer {
117+
return &sqlLexer{}
118+
},
119+
reset: func(sl *sqlLexer) bool {
120+
*sl = sqlLexer{}
121+
return true
122+
},
123+
}
124+
125+
func (q *Query) init(sql string) {
126+
parts := q.Parts[:0]
127+
if parts == nil {
128+
// dirty, but fast heuristic to preallocate for ~90% usecases
129+
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
130+
parts = make([]Part, 0, n)
88131
}
89132

133+
l := sqlLexerPool.get()
134+
defer sqlLexerPool.put(l)
135+
136+
l.src = sql
137+
l.stateFn = rawState
138+
l.parts = parts
139+
90140
for l.stateFn != nil {
91141
l.stateFn = l.stateFn(l)
92142
}
93143

94-
query := &Query{Parts: l.parts}
95-
96-
return query, nil
144+
q.Parts = l.parts
97145
}
98146

99-
// QuoteString appends str to dst as a single-quoted SQL string literal and
// returns the extended slice. Every single quote inside str is doubled; no
// other byte (backslashes included) is altered.
func QuoteString(dst []byte, str string) []byte {
	const quote = '\''

	// Reserve the worst case up front (every byte is a quote, plus the two
	// delimiters) so none of the appends below can reallocate.
	dst = slices.Grow(dst, len(str)*2+2)

	// Opening quote.
	dst = append(dst, quote)

	// Copy each run of ordinary bytes in one go, doubling the quote that ends it.
	for {
		i := strings.IndexByte(str, quote)
		if i < 0 {
			break
		}
		dst = append(dst, str[:i]...)
		dst = append(dst, quote, quote)
		str = str[i+1:]
	}
	dst = append(dst, str...)

	// Closing quote.
	return append(dst, quote)
}
102170

103-
// QuoteBytes appends buf to dst as a quoted hex literal of the form '\x…'
// (lower-case hex digits, two per input byte) and returns the extended slice.
// An empty or nil buf produces '\x'.
func QuoteBytes(dst, buf []byte) []byte {
	const prefix = `'\x`

	hexLen := hex.EncodedLen(len(buf))

	// Reserve room for prefix, hex digits and closing quote in one step, so
	// the reslice below is guaranteed to stay within capacity.
	dst = slices.Grow(dst, len(prefix)+hexLen+1)
	dst = append(dst, prefix...)

	// Extend dst over the hex area and encode straight into it.
	start := len(dst)
	dst = dst[:start+hexLen]
	hex.Encode(dst[start:], buf)

	return append(dst, '\'')
}
106203

107204
type sqlLexer struct {
@@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
319416
}
320417
}
321418

419+
var queryPool = &pool[*Query]{
420+
new: func() *Query {
421+
return &Query{}
422+
},
423+
reset: func(q *Query) bool {
424+
n := len(q.Parts)
425+
q.Parts = q.Parts[:0]
426+
return n < 64 // drop too large queries
427+
},
428+
}
429+
322430
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
323431
// as necessary. This function is only safe when standard_conforming_strings is
324432
// on.
325433
func SanitizeSQL(sql string, args ...any) (string, error) {
326-
query, err := NewQuery(sql)
327-
if err != nil {
328-
return "", err
329-
}
434+
query := queryPool.get()
435+
query.init(sql)
436+
defer queryPool.put(query)
437+
330438
return query.Sanitize(args...)
331439
}
440+
441+
// pool is a typed wrapper around sync.Pool.
//
// new builds a fresh value whenever the underlying pool has nothing usable to
// offer. reset is called on every value passed to put; it prepares the value
// for reuse and reports whether the value should be kept at all.
type pool[E any] struct {
	p     sync.Pool
	new   func() E
	reset func(E) bool
}

// get returns a value from the pool, or a newly built one when the pool is
// empty. The embedded sync.Pool has no New function, so an empty pool yields
// nil, the type assertion fails, and new is called.
func (p *pool[E]) get() E {
	if v, ok := p.p.Get().(E); ok {
		return v
	}

	return p.new()
}

// put resets v and stores it for reuse, unless reset reports that v should be
// dropped.
func (p *pool[E]) put(v E) {
	if p.reset(v) {
		p.p.Put(v)
	}
}
+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// sanitize_benchmark_test.go
2+
package sanitize_test
3+
4+
import (
5+
"testing"
6+
"time"
7+
8+
"github.com/jackc/pgx/v5/internal/sanitize"
9+
)
10+
11+
// benchmarkSanitizeResult receives the output of BenchmarkSanitize so the
// sanitized string is stored somewhere outside the loop.
var benchmarkSanitizeResult string

// benchmarkQuery has seven placeholders; the trailing comment on each line
// names the Go type of the matching entry in benchmarkArgs.
const benchmarkQuery = `SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`

// benchmarkArgs supplies one argument per placeholder of benchmarkQuery, in
// placeholder order.
var benchmarkArgs = []any{
	int64(12345),                                 // $1
	nil,                                          // $2
	float64(500),                                 // $3
	true,                                         // $4
	[]byte("8BADF00D"),                           // $5
	"kombucha's han'dy awokowa",                  // $6
	time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), // $7
}
33+
34+
func BenchmarkSanitize(b *testing.B) {
35+
query, err := sanitize.NewQuery(benchmarkQuery)
36+
if err != nil {
37+
b.Fatalf("failed to create query: %v", err)
38+
}
39+
40+
b.ResetTimer()
41+
b.ReportAllocs()
42+
43+
for i := 0; i < b.N; i++ {
44+
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
45+
if err != nil {
46+
b.Fatalf("failed to sanitize query: %v", err)
47+
}
48+
}
49+
}
50+
51+
var benchmarkNewSQLResult string
52+
53+
func BenchmarkSanitizeSQL(b *testing.B) {
54+
b.ReportAllocs()
55+
var err error
56+
for i := 0; i < b.N; i++ {
57+
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
58+
if err != nil {
59+
b.Fatalf("failed to sanitize SQL: %v", err)
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)