-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathpgmock_test.go
93 lines (79 loc) · 2.4 KB
/
pgmock_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package pgmock_test
import (
"context"
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgmock"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestScript drives a scripted pgmock server with a real pgconn client.
//
// The script accepts an unauthenticated connection and expects the simple
// query "select 42". It answers with a single one-column row containing "42",
// a "SELECT 1" command tag and ReadyForQuery. Finally it expects the client's
// Terminate message. Any deviation on either side fails the test.
func TestScript(t *testing.T) {
	script := &pgmock.Script{
		Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
	}
	script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
	script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
		Fields: []pgproto3.FieldDescription{
			{
				Name:                 []byte("?column?"),
				TableOID:             0,
				TableAttributeNumber: 0,
				DataTypeOID:          23, // PostgreSQL int4
				DataTypeSize:         4,
				TypeModifier:         -1,
				Format:               0, // text format
			},
		},
	}))
	script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.DataRow{
		Values: [][]byte{[]byte("42")},
	}))
	script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}))
	script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
	// pgConn.Close below is what produces this final message.
	script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Terminate{}))

	// Port omitted so the OS assigns a free one.
	ln, err := net.Listen("tcp", "127.0.0.1:")
	require.NoError(t, err)
	defer ln.Close()

	// Buffered so the server goroutine never blocks on send; closed on exit so
	// a successful run yields a nil error to the receiver.
	serverErrChan := make(chan error, 1)
	go func() {
		defer close(serverErrChan)

		conn, err := ln.Accept()
		if err != nil {
			serverErrChan <- err
			return
		}
		defer conn.Close()

		// Bound the scripted exchange so a stalled client cannot hang the test.
		err = conn.SetDeadline(time.Now().Add(time.Second))
		if err != nil {
			serverErrChan <- err
			return
		}

		err = script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
		if err != nil {
			serverErrChan <- err
			return
		}
	}()

	// The listener is bound to an IPv4 literal, so its address is "host:port".
	parts := strings.Split(ln.Addr().String(), ":")
	require.Len(t, parts, 2)
	host, port := parts[0], parts[1]
	connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	pgConn, err := pgconn.Connect(ctx, connStr)
	require.NoError(t, err)

	results, err := pgConn.Exec(ctx, "select 42").ReadAll()
	require.NoError(t, err)
	// require (not assert) on lengths: the index expressions below would
	// otherwise panic instead of reporting a clean failure.
	require.Len(t, results, 1)
	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", string(results[0].CommandTag))
	require.Len(t, results[0].Rows, 1)
	require.Len(t, results[0].Rows[0], 1)
	assert.Equal(t, "42", string(results[0].Rows[0][0]))

	// Closing delivers the Terminate the script waits for; its error was
	// previously discarded.
	assert.NoError(t, pgConn.Close(ctx))
	assert.NoError(t, <-serverErrChan)
}