Skip to content

Commit 2cc9320

Browse files
committed
Add redis lua migrations
1 parent 555501f commit 2cc9320

File tree

12 files changed

+741
-25
lines changed

12 files changed

+741
-25
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
SOURCE ?= file go_bindata github github_ee bitbucket aws_s3 google_cloud_storage godoc_vfs gitlab
2-
DATABASE ?= postgres mysql redshift cassandra spanner cockroachdb yugabytedb clickhouse mongodb sqlserver firebird neo4j pgx pgx5 rqlite
2+
DATABASE ?= postgres mysql redis redshift cassandra spanner cockroachdb yugabytedb clickhouse mongodb sqlserver firebird neo4j pgx pgx5 rqlite
33
DATABASE_TEST ?= $(DATABASE) sqlite sqlite3 sqlcipher
44
VERSION ?= $(shell git describe --tags 2>/dev/null | cut -c 2-)
55
TEST_FLAGS ?=

database/redis/README.md

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# redis
2+
3+
URL format:
4+
5+
- standalone connection:
6+
7+
`redis://<user>:<password>@<host>:<port>/<db_number>`
8+
9+
- failover connection:
10+
11+
`redis://<user>:<password>@/<db_number>?sentinel_addr=<sentinel_host>:<sentinel_port>`
12+
13+
- cluster connection:
14+
15+
`redis://<user>:<password>@<host>:<port>?addr=<host2>:<port2>&addr=<host3>:<port3>`
16+
17+
`rediss://<user>:<password>@<host>:<port>?addr=<host2>:<port2>&addr=<host3>:<port3>`
18+
19+
| URL Query | WithInstance Config | Description |
20+
|--------------------|---------------------|---------------------------------------------|
21+
| `x-mode` | - | The Mode that used to choose client type |
22+
| `x-migrations-key` | `MigrationsKey` | Specify the key where migrations are stored |
23+
| `x-lock-key` | `LockKey` | Specify the key where locks are stored |
24+
| `x-lock-timeout` | `LockTimeout` | Specify the timeout of lock |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
return redis.call("DEL", "test_key")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
return redis.call("SET", "test_key", "1")

database/redis/redis.go

+228
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
package redis
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"github.com/golang-migrate/migrate/v4"
7+
"github.com/golang-migrate/migrate/v4/database"
8+
"github.com/redis/go-redis/v9"
9+
"go.uber.org/atomic"
10+
"io"
11+
neturl "net/url"
12+
"strconv"
13+
"strings"
14+
"time"
15+
)
16+
17+
func init() {
18+
db := Redis{}
19+
database.Register("redis", &db)
20+
database.Register("rediss", &db)
21+
}
22+
23+
var (
24+
DefaultMigrationsKey = "schema_migrations"
25+
DefaultLockKey = "lock:schema_migrations"
26+
DefaultLockTimeout = 15 * time.Second
27+
)
28+
29+
func convertVersionFromDB(result []interface{}) (int, bool, error) {
30+
if result[0] == nil || result[1] == nil {
31+
return database.NilVersion, false, nil
32+
}
33+
34+
version, err := strconv.Atoi(result[0].(string))
35+
if err != nil {
36+
return 0, false, fmt.Errorf("can't parse version: %w", err)
37+
}
38+
39+
dirty, err := strconv.ParseBool(result[1].(string))
40+
if err != nil {
41+
return 0, false, fmt.Errorf("can't parse dirty: %w", err)
42+
}
43+
44+
return version, dirty, nil
45+
}
46+
47+
type Mode int8
48+
49+
const (
50+
ModeUnspecified Mode = iota
51+
ModeStandalone
52+
ModeFailover
53+
ModeCluster
54+
)
55+
56+
var rawModeToMode = map[string]Mode{
57+
"": ModeUnspecified,
58+
"standalone": ModeStandalone,
59+
"failover": ModeFailover,
60+
"cluster": ModeCluster,
61+
}
62+
63+
func parseMode(rawMode string) (Mode, error) {
64+
mode, ok := rawModeToMode[strings.ToLower(rawMode)]
65+
if ok {
66+
return mode, nil
67+
}
68+
69+
return ModeUnspecified, fmt.Errorf("unexpected mode: %q", rawMode)
70+
}
71+
72+
type Config struct {
73+
MigrationsKey string
74+
LockKey string
75+
LockTimeout time.Duration
76+
}
77+
78+
func newClient(url string, mode Mode) (redis.UniversalClient, error) {
79+
if mode == ModeUnspecified {
80+
var err error
81+
82+
mode, err = determineMode(url)
83+
if err != nil {
84+
return nil, err
85+
}
86+
}
87+
88+
switch mode {
89+
case ModeStandalone:
90+
options, err := redis.ParseURL(url)
91+
if err != nil {
92+
return nil, err
93+
}
94+
95+
return redis.NewClient(options), nil
96+
case ModeFailover:
97+
options, err := parseFailoverURL(url)
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
return redis.NewFailoverClient(options), nil
103+
case ModeCluster:
104+
options, err := redis.ParseClusterURL(url)
105+
if err != nil {
106+
return nil, err
107+
}
108+
109+
return redis.NewClusterClient(options), nil
110+
default:
111+
return nil, fmt.Errorf("unexpected mode: %q", mode)
112+
}
113+
}
114+
115+
func WithInstance(client redis.UniversalClient, config *Config) (database.Driver, error) {
116+
if config.MigrationsKey == "" {
117+
config.MigrationsKey = DefaultMigrationsKey
118+
}
119+
120+
if config.LockKey == "" {
121+
config.LockKey = DefaultLockKey
122+
}
123+
124+
if config.LockTimeout == 0 {
125+
config.LockTimeout = DefaultLockTimeout
126+
}
127+
128+
return &Redis{
129+
client: client,
130+
config: config,
131+
}, nil
132+
}
133+
134+
type Redis struct {
135+
client redis.UniversalClient
136+
isLocked atomic.Bool
137+
config *Config
138+
}
139+
140+
func (r *Redis) Open(url string) (database.Driver, error) {
141+
purl, err := neturl.Parse(url)
142+
if err != nil {
143+
return nil, err
144+
}
145+
146+
query := purl.Query()
147+
148+
mode, err := parseMode(query.Get("x-mode"))
149+
if err != nil {
150+
return nil, err
151+
}
152+
153+
var lockTimeout time.Duration
154+
rawLockTimeout := query.Get("x-lock-timeout")
155+
if rawLockTimeout != "" {
156+
lockTimeout, err = time.ParseDuration(rawLockTimeout)
157+
if err != nil {
158+
return nil, fmt.Errorf("invalid x-lock-timeout: %w", err)
159+
}
160+
}
161+
162+
client, err := newClient(migrate.FilterCustomQuery(purl).String(), mode)
163+
if err != nil {
164+
return nil, fmt.Errorf("can't create client: %w", err)
165+
}
166+
167+
return WithInstance(
168+
client,
169+
&Config{
170+
MigrationsKey: query.Get("x-migrations-key"),
171+
LockKey: query.Get("x-lock-key"),
172+
LockTimeout: lockTimeout,
173+
},
174+
)
175+
}
176+
177+
func (r *Redis) Close() error {
178+
return r.client.Close()
179+
}
180+
181+
func (r *Redis) Lock() error {
182+
return database.CasRestoreOnErr(&r.isLocked, false, true, database.ErrLocked, func() error {
183+
return r.client.SetArgs(context.Background(), r.config.LockKey, 1, redis.SetArgs{
184+
Mode: "NX",
185+
TTL: r.config.LockTimeout,
186+
}).Err()
187+
})
188+
}
189+
190+
func (r *Redis) Unlock() error {
191+
return database.CasRestoreOnErr(&r.isLocked, true, false, database.ErrNotLocked, func() error {
192+
return r.client.Del(context.Background(), r.config.LockKey).Err()
193+
})
194+
}
195+
196+
func (r *Redis) Run(migration io.Reader) error {
197+
script, err := io.ReadAll(migration)
198+
if err != nil {
199+
return err
200+
}
201+
202+
if err = r.client.Eval(context.Background(), string(script), nil).Err(); err != nil {
203+
return fmt.Errorf("migration failed: %w", err)
204+
}
205+
206+
return nil
207+
}
208+
209+
func (r *Redis) SetVersion(version int, dirty bool) error {
210+
if version > 0 || (version == database.NilVersion && dirty) {
211+
return r.client.HMSet(context.Background(), r.config.MigrationsKey, "version", version, "dirty", dirty).Err()
212+
}
213+
214+
return r.client.Del(context.Background(), r.config.MigrationsKey).Err()
215+
}
216+
217+
func (r *Redis) Version() (version int, dirty bool, err error) {
218+
result, err := r.client.HMGet(context.Background(), r.config.MigrationsKey, "version", "dirty").Result()
219+
if err != nil {
220+
return 0, false, err
221+
}
222+
223+
return convertVersionFromDB(result)
224+
}
225+
226+
func (r *Redis) Drop() error {
227+
return r.client.FlushDB(context.Background()).Err()
228+
}

0 commit comments

Comments
 (0)