@@ -3,18 +3,30 @@ package dbrules
3
3
import (
4
4
"database/sql"
5
5
"fmt"
6
+ "strconv"
7
+ "strings"
6
8
7
9
. "github.com/mjarkk/laravalidate"
8
10
)
9
11
10
- type DB struct { conn * sql.DB }
12
+ type DB struct {
13
+ conn * sql.DB
14
+ variableStyle QueryVariableStyle
15
+ }
16
+
17
+ type QueryVariableStyle uint8
18
+
19
+ const (
20
+ DefaultStyle QueryVariableStyle = iota // ?
21
+ PgStyle // $1, $2, ...
22
+ )
11
23
12
- func AddRules (conn * sql.DB ) {
24
+ func AddRules (conn * sql.DB , variableStyle QueryVariableStyle ) {
13
25
if conn == nil {
14
26
panic ("DB connection cannot be nil" )
15
27
}
16
28
17
- db := & DB {conn }
29
+ db := & DB {conn , variableStyle }
18
30
19
31
RegisterValidator ("exists" , db .Exists )
20
32
@@ -25,6 +37,21 @@ func AddRules(conn *sql.DB) {
25
37
LogValidatorsWithoutMessages ()
26
38
}
27
39
40
+ func (b * DB ) prepareQuery (in string ) string {
41
+ switch b .variableStyle {
42
+ case PgStyle :
43
+ for i := 1 ; i <= strings .Count (in , "?" ); i ++ {
44
+ iStr := "$" + strconv .Itoa (i )
45
+ in = strings .Replace (in , "?" , iStr , 1 )
46
+ }
47
+ }
48
+ return in
49
+ }
50
+
51
+ func (b * DB ) query (query string , args ... any ) (* sql.Rows , error ) {
52
+ return b .conn .Query (b .prepareQuery (query ), args ... )
53
+ }
54
+
28
55
func (b * DB ) Exists (ctx * ValidatorCtx ) (string , bool ) {
29
56
if len (ctx .Args ) == 0 {
30
57
return "args" , false
@@ -46,19 +73,25 @@ func (b *DB) Exists(ctx *ValidatorCtx) (string, bool) {
46
73
return "" , true
47
74
}
48
75
49
- row := b .conn .QueryRow (
50
- fmt .Sprintf ("SELECT %s FROM %s WHERE %s = ? LIMIT 1" , column , tableName , column ),
76
+ query := fmt .Sprintf ("SELECT %s FROM %s WHERE %s = ? LIMIT 1" , column , tableName , column )
77
+ result , err := b .query (
78
+ query ,
51
79
ctx .Value .Interface (),
52
80
)
53
- if row . Err () != nil {
81
+ if err != nil {
54
82
return "exists" , false
55
83
}
56
84
57
- resp := sql.RawBytes {}
58
- err := row .Scan (& resp )
59
- if err != nil {
60
- return "exists" , false
85
+ defer result .Close ()
86
+ for result .Next () {
87
+ resp := sql.RawBytes {}
88
+ err := result .Scan (& resp )
89
+ if err != nil {
90
+ return "exists" , false
91
+ }
92
+
93
+ return "" , true
61
94
}
62
95
63
- return "" , true
96
+ return "exists " , false
64
97
}
0 commit comments