Skip to content

Commit ec96099

Browse files
committed
fix postgresql arguments in dbrules
1 parent 64af40b commit ec96099

File tree

2 files changed

+45
-12
lines changed

2 files changed

+45
-12
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func main() {
7878
db, err := sql.Open(driverName, dns)
7979
// Check err
8080

81-
dbrules.AddRules(db)
81+
dbrules.AddRules(db, dbrules.DefaultStyle)
8282

8383
// Now you can add translations!
8484
// translations.RegisterNlTranslations()

dbrules/dbrules.go

+44-11
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,30 @@ package dbrules
33
import (
44
"database/sql"
55
"fmt"
6+
"strconv"
7+
"strings"
68

79
. "github.com/mjarkk/laravalidate"
810
)
911

10-
type DB struct{ conn *sql.DB }
12+
type DB struct {
13+
conn *sql.DB
14+
variableStyle QueryVariableStyle
15+
}
16+
17+
type QueryVariableStyle uint8
18+
19+
const (
20+
DefaultStyle QueryVariableStyle = iota // ?
21+
PgStyle // $1, $2, ...
22+
)
1123

12-
func AddRules(conn *sql.DB) {
24+
func AddRules(conn *sql.DB, variableStyle QueryVariableStyle) {
1325
if conn == nil {
1426
panic("DB connection cannot be nil")
1527
}
1628

17-
db := &DB{conn}
29+
db := &DB{conn, variableStyle}
1830

1931
RegisterValidator("exists", db.Exists)
2032

@@ -25,6 +37,21 @@ func AddRules(conn *sql.DB) {
2537
LogValidatorsWithoutMessages()
2638
}
2739

40+
func (b *DB) prepareQuery(in string) string {
41+
switch b.variableStyle {
42+
case PgStyle:
43+
for i := 1; i <= strings.Count(in, "?"); i++ {
44+
iStr := "$" + strconv.Itoa(i)
45+
in = strings.Replace(in, "?", iStr, 1)
46+
}
47+
}
48+
return in
49+
}
50+
51+
func (b *DB) query(query string, args ...any) (*sql.Rows, error) {
52+
return b.conn.Query(b.prepareQuery(query), args...)
53+
}
54+
2855
func (b *DB) Exists(ctx *ValidatorCtx) (string, bool) {
2956
if len(ctx.Args) == 0 {
3057
return "args", false
@@ -46,19 +73,25 @@ func (b *DB) Exists(ctx *ValidatorCtx) (string, bool) {
4673
return "", true
4774
}
4875

49-
row := b.conn.QueryRow(
50-
fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", column, tableName, column),
76+
query := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ? LIMIT 1", column, tableName, column)
77+
result, err := b.query(
78+
query,
5179
ctx.Value.Interface(),
5280
)
53-
if row.Err() != nil {
81+
if err != nil {
5482
return "exists", false
5583
}
5684

57-
resp := sql.RawBytes{}
58-
err := row.Scan(&resp)
59-
if err != nil {
60-
return "exists", false
85+
defer result.Close()
86+
for result.Next() {
87+
resp := sql.RawBytes{}
88+
err := result.Scan(&resp)
89+
if err != nil {
90+
return "exists", false
91+
}
92+
93+
return "", true
6194
}
6295

63-
return "", true
96+
return "exists", false
6497
}

0 commit comments

Comments
 (0)