-
Notifications
You must be signed in to change notification settings - Fork 28
/
database_test.go
116 lines (101 loc) · 2.28 KB
/
database_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package sqlingo
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"testing"
"time"
)
// Prepare implements driver.Conn for the mock connection.
//
// When m.prepareError is set it is returned unchanged and nothing is recorded.
// Otherwise the query text is remembered in m.lastSql and a new mockStmt is
// returned carrying this connection's columnCount and rowCount.
func (m *mockConn) Prepare(query string) (driver.Stmt, error) {
	if err := m.prepareError; err != nil {
		return nil, err
	}
	m.lastSql = query
	stmt := &mockStmt{
		rowCount:    m.rowCount,
		columnCount: m.columnCount,
	}
	return stmt, nil
}
// Close implements driver.Conn. The mock holds nothing to release, so it
// always reports success.
//
// It uses a pointer receiver like every other mockConn method: the method set
// stays consistent, and the (possibly growing) struct is not copied per call.
func (m *mockConn) Close() error {
	return nil
}
// Begin implements driver.Conn for the mock connection.
//
// If m.beginTxError is set, it is returned and no transaction is created.
// Otherwise a fresh mockTx is stored in m.mockTx, which lets tests inspect it
// later, and that same transaction is returned.
func (m *mockConn) Begin() (driver.Tx, error) {
	if m.beginTxError == nil {
		tx := new(mockTx)
		m.mockTx = tx
		return m.mockTx, nil
	}
	return nil, m.beginTxError
}
// sharedMockConn is the one connection that mockDriver.Open hands out for
// every data source name, so all mock databases in this package share it.
// Tests that set its error fields (e.g. prepareError) to inject failures
// should reset them afterwards.
// NOTE(review): rowCount/columnCount presumably size the rows the mock
// statements produce; confirm against mockStmt.
var sharedMockConn = &mockConn{
	rowCount:    10,
	columnCount: 11,
}
// Open implements driver.Driver. The data source name is ignored; every call
// succeeds and returns sharedMockConn.
func (mockDriver) Open(_ string) (driver.Conn, error) {
	return sharedMockConn, nil
}
// newMockDatabase opens a Database on the "sqlingo-mock" driver and forces
// the MySQL dialect on it.
//
// It panics if Open fails, since there is no useful way to continue test
// setup without a database.
func newMockDatabase() Database {
	db, err := Open("sqlingo-mock", "dummy")
	if err != nil {
		panic(err)
	}
	impl := db.(*database)
	impl.dialect = dialectMySQL
	return db
}
// init registers mockDriver with database/sql under the name "sqlingo-mock",
// which newMockDatabase then passes to Open.
func init() {
	drv := new(mockDriver)
	sql.Register("sqlingo-mock", drv)
}
// TestDatabase checks the basic Database wiring:
//   - Open fails for a driver name that was never registered;
//   - GetDB exposes the underlying handle;
//   - Query runs through the interceptor and the logger, and Execute through
//     the logger, each seeing the original SQL text;
//   - an error returned by the interceptor is propagated by Query.
func TestDatabase(t *testing.T) {
	if _, err := Open("unknowndb", "unknown"); err == nil {
		t.Error("expected an error when opening an unregistered driver")
	}

	db := newMockDatabase()
	if db.GetDB() == nil {
		t.Error("GetDB returned nil")
	}

	interceptorExecutedCount := 0
	loggerExecutedCount := 0
	// The SQL parameters are named query (not sql) so the imported
	// database/sql package is not shadowed inside the closures.
	db.SetInterceptor(func(ctx context.Context, query string, invoker InvokerFunc) error {
		if query != "SELECT 1" {
			t.Error(query)
		}
		interceptorExecutedCount++
		return invoker(ctx, query)
	})
	db.SetLogger(func(query string, _ time.Duration, _, _ bool) {
		if query != "SELECT 1" {
			t.Error(query)
		}
		loggerExecutedCount++
	})

	// Close the cursor so its rows, and the pooled connection behind them,
	// are released instead of leaking for the rest of the test binary.
	if cursor, err := db.Query("SELECT 1"); err == nil {
		_ = cursor.Close()
	}
	if interceptorExecutedCount != 1 || loggerExecutedCount != 1 {
		t.Error(interceptorExecutedCount, loggerExecutedCount)
	}

	_, _ = db.Execute("SELECT 1")
	if loggerExecutedCount != 2 {
		t.Error(loggerExecutedCount)
	}

	db.SetInterceptor(func(context.Context, string, InvokerFunc) error {
		return errors.New("error")
	})
	if cursor, err := db.Query("SELECT 1"); err == nil {
		_ = cursor.Close()
		t.Error("should get error here")
	}
}
// TestDatabaseRetry checks that a failing Query is retried for as long as the
// retry policy returns true, and that the error is returned once it says stop.
func TestDatabaseRetry(t *testing.T) {
	const maxAttempts = 10

	db := newMockDatabase()
	retryCount := 0
	db.SetRetryPolicy(func(error) bool {
		retryCount++
		return retryCount < maxAttempts
	})

	// sharedMockConn is shared by every mock database in the package, so the
	// injected failure is cleared on every exit path, not only the happy one.
	sharedMockConn.prepareError = errors.New("error")
	defer func() { sharedMockConn.prepareError = nil }()

	if cursor, err := db.Query("SELECT 1"); err == nil {
		_ = cursor.Close()
		t.Error("should get error here")
	}
	if retryCount != maxAttempts {
		t.Error(retryCount)
	}
}