Skip to content

Commit

Permalink
support watch command
Browse files Browse the repository at this point in the history
  • Loading branch information
HDT3213 committed Jun 20, 2021
1 parent 5d05e2e commit ae25076
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 40 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ MSET (10 keys): 65487.89 requests per second

## Todo List

+ [ ] `Multi` Command
+ [ ] `Watch` Command and CAS support
+ [x] `Multi` Command
+ [x] `Watch` Command and CAS support
+ [ ] Stream support
+ [ ] RDB file loader
+ [ ] Master-Slave mode
Expand Down
6 changes: 3 additions & 3 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
- 自动过期功能(TTL)
- 发布订阅
- 地理位置
- AOF 持久化及AOF重写
- AOF 持久化及 AOF 重写
- Multi 命令开启的事务具有`原子性``隔离性`. 若在执行过程中遇到错误, godis 会回滚已执行的命令
- 内置集群模式. 集群对客户端是透明的, 您可以像使用单机版 redis 一样使用 godis 集群
- `MSET`, `DEL` 命令在集群模式下原子性执行
Expand Down Expand Up @@ -105,8 +105,8 @@ MSET (10 keys): 65487.89 requests per second

## 开发计划

+ [ ] `Multi` 命令
+ [ ] `Watch` 命令和 CAS 支持
+ [x] `Multi` 命令
+ [x] `Watch` 命令和 CAS 支持
+ [ ] Stream 队列
+ [ ] 加载 RDB 文件
+ [ ] 主从模式
Expand Down
84 changes: 75 additions & 9 deletions cluster/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package cluster
import (
"github.com/hdt3213/godis"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply"
"strconv"
)

const relayMulti = "_multi"
const innerWatch = "_watch"

var relayMultiBytes = []byte(relayMulti)

Expand All @@ -25,9 +28,15 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R
keys = append(keys, wKeys...)
keys = append(keys, rKeys...)
}
watching := conn.GetWatching()
watchingKeys := make([]string, 0, len(watching))
for key := range watching {
watchingKeys = append(watchingKeys, key)
}
keys = append(keys, watchingKeys...)
if len(keys) == 0 {
// empty transaction or only `PING`s
return godis.ExecMulti(cluster.db, cmdLines)
return godis.ExecMulti(cluster.db, conn, watching, cmdLines)
}
groupMap := cluster.groupBy(keys)
if len(groupMap) > 1 {
Expand All @@ -41,24 +50,41 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R

// out parser not support reply.MultiRawReply, so we have to encode it
if peer == cluster.self {
return godis.ExecMulti(cluster.db, cmdLines)
return godis.ExecMulti(cluster.db, conn, watching, cmdLines)
}
return execMultiOnOtherNode(cluster, conn, peer, cmdLines)
return execMultiOnOtherNode(cluster, conn, peer, watching, cmdLines)
}

func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, cmdLines []CmdLine) redis.Reply {
func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string, watching map[string]uint32, cmdLines []CmdLine) redis.Reply {
defer func() {
conn.ClearQueuedCmds()
conn.SetMultiState(false)
}()
relayCmdLine := [][]byte{ // relay it to executing node
relayMultiBytes,
}
// watching commands
var watchingCmdLine = utils.ToCmdLine(innerWatch)
for key, ver := range watching {
verStr := strconv.FormatUint(uint64(ver), 10)
watchingCmdLine = append(watchingCmdLine, []byte(key), []byte(verStr))
}
relayCmdLine = append(relayCmdLine, encodeCmdLine([]CmdLine{watchingCmdLine})...)
relayCmdLine = append(relayCmdLine, encodeCmdLine(cmdLines)...)
rawRelayResult := cluster.relay(peer, conn, relayCmdLine)
var rawRelayResult redis.Reply
if peer == cluster.self {
// this branch just for testing
rawRelayResult = execRelayedMulti(cluster, nil, relayCmdLine)
} else {
rawRelayResult = cluster.relay(peer, conn, relayCmdLine)
}
if reply.IsErrorReply(rawRelayResult) {
return rawRelayResult
}
_, ok := rawRelayResult.(*reply.EmptyMultiBulkReply)
if ok {
return rawRelayResult
}
relayResult, ok := rawRelayResult.(*reply.MultiBulkReply)
if !ok {
return reply.MakeErrReply("execute failed")
Expand All @@ -71,25 +97,65 @@ func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string,
}

// execRelayedMulti execute relayed multi commands transaction
// cmdLine format: _multi base64ed-cmdLine
// cmdLine format: _multi watch-cmdLine base64ed-cmdLine
// result format: base64ed-reply list
func execRelayedMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.Reply {
if len(cmdLine) < 2 {
return reply.MakeArgNumErrReply("_exec")
}
decoded, err := parseEncodedMultiRawReply(cmdLine[1:])
if err != nil {
return reply.MakeErrReply(err.Error())
}
var cmdLines []CmdLine
var txCmdLines []CmdLine
for _, rep := range decoded.Replies {
mbr, ok := rep.(*reply.MultiBulkReply)
if !ok {
return reply.MakeErrReply("exec failed")
}
cmdLines = append(cmdLines, mbr.Args)
txCmdLines = append(txCmdLines, mbr.Args)
}
watching := make(map[string]uint32)
watchCmdLine := txCmdLines[0] // format: _watch key1 ver1 key2 ver2...
for i := 2; i < len(watchCmdLine); i += 2 {
key := string(watchCmdLine[i-1])
verStr := string(watchCmdLine[i])
ver, err := strconv.ParseUint(verStr, 10, 64)
if err != nil {
return reply.MakeErrReply("watching command line failed")
}
watching[key] = uint32(ver)
}
rawResult := godis.ExecMulti(cluster.db, conn, watching, txCmdLines[1:])
_, ok := rawResult.(*reply.EmptyMultiBulkReply)
if ok {
return rawResult
}
rawResult := godis.ExecMulti(cluster.db, cmdLines)
resultMBR, ok := rawResult.(*reply.MultiRawReply)
if !ok {
return reply.MakeErrReply("exec failed")
}
return encodeMultiRawReply(resultMBR)
}

func execWatch(cluster *Cluster, conn redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 {
return reply.MakeArgNumErrReply("watch")
}
args = args[1:]
watching := conn.GetWatching()
for _, bkey := range args {
key := string(bkey)
peer := cluster.peerPicker.PickNode(key)
result := cluster.relay(peer, conn, utils.ToCmdLine("GetVer", key))
if reply.IsErrorReply(result) {
return result
}
intResult, ok := result.(*reply.IntReply)
if !ok {
return reply.MakeErrReply("get version failed")
}
watching[key] = uint32(intResult.Code)
}
return reply.MakeOkReply()
}
71 changes: 55 additions & 16 deletions cluster/multi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,63 @@ func TestMultiExecOnOthers(t *testing.T) {
testCluster.Exec(conn, utils.ToCmdLine("lrange", key, "0", "-1"))

cmdLines := conn.GetQueuedCmdLine()
relayCmdLine := [][]byte{ // relay it to executing node
relayMultiBytes,
}
relayCmdLine = append(relayCmdLine, encodeCmdLine(cmdLines)...)
rawRelayResult := execRelayedMulti(testCluster, conn, relayCmdLine)
if reply.IsErrorReply(rawRelayResult) {
t.Error()
}
relayResult, ok := rawRelayResult.(*reply.MultiBulkReply)
if !ok {
t.Error()
}
rep, err := parseEncodedMultiRawReply(relayResult.Args)
if err != nil {
t.Error()
}
rawResp := execMultiOnOtherNode(testCluster, conn, testCluster.self, nil, cmdLines)
rep := rawResp.(*reply.MultiRawReply)
if len(rep.Replies) != 2 {
t.Errorf("expect 2 replies actual %d", len(rep.Replies))
}
asserts.AssertMultiBulkReply(t, rep.Replies[1], []string{value})
}

func TestWatch(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn)
key := utils.RandString(10)
value := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("watch", key))
testCluster.Exec(conn, utils.ToCmdLine("set", key, value))
result := testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result)
key2 := utils.RandString(10)
value2 := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2))
result = testCluster.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertNotError(t, result)
result = testCluster.Exec(conn, utils.ToCmdLine("get", key2))
asserts.AssertNullBulk(t, result)

testCluster.Exec(conn, utils.ToCmdLine("watch", key))
result = testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result)
testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2))
result = testCluster.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertNotError(t, result)
result = testCluster.Exec(conn, utils.ToCmdLine("get", key2))
asserts.AssertBulkReply(t, result, value2)
}

func TestWatch2(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn)
key := utils.RandString(10)
value := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("watch", key))
testCluster.Exec(conn, utils.ToCmdLine("set", key, value))
result := testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result)
key2 := utils.RandString(10)
value2 := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2))
cmdLines := conn.GetQueuedCmdLine()
execMultiOnOtherNode(testCluster, conn, testCluster.self, conn.GetWatching(), cmdLines)
result = testCluster.Exec(conn, utils.ToCmdLine("get", key2))
asserts.AssertNullBulk(t, result)

testCluster.Exec(conn, utils.ToCmdLine("watch", key))
result = testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result)
testCluster.Exec(conn, utils.ToCmdLine("set", key2, value2))
execMultiOnOtherNode(testCluster, conn, testCluster.self, conn.GetWatching(), cmdLines)
result = testCluster.Exec(conn, utils.ToCmdLine("get", key2))
asserts.AssertBulkReply(t, result, value2)
}
2 changes: 2 additions & 0 deletions cluster/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ func makeRouter() map[string]CmdFunc {
routerMap["flushdb"] = FlushDB
routerMap["flushall"] = FlushAll
routerMap[relayMulti] = execRelayedMulti
routerMap["getver"] = defaultFunc
routerMap["watch"] = execWatch

return routerMap
}
Expand Down
28 changes: 24 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type DB struct {
data dict.Dict
// key -> expireTime (time.Time)
ttlMap dict.Dict
// key -> version(uint32)
versionMap dict.Dict

// dict.Dict will ensure concurrent-safety of its method
// use this mutex for complicated command only, eg. rpush, incr ...
Expand Down Expand Up @@ -72,10 +74,11 @@ type UndoFunc func(db *DB, args [][]byte) []CmdLine
// MakeDB create DB instance and start it
func MakeDB() *DB {
db := &DB{
data: dict.MakeConcurrent(dataDictSize),
ttlMap: dict.MakeConcurrent(ttlDictSize),
locker: lock.Make(lockerSize),
hub: pubsub.MakeHub(),
data: dict.MakeConcurrent(dataDictSize),
ttlMap: dict.MakeConcurrent(ttlDictSize),
versionMap: dict.MakeConcurrent(dataDictSize),
locker: lock.Make(lockerSize),
hub: pubsub.MakeHub(),
}

// aof
Expand Down Expand Up @@ -249,6 +252,23 @@ func (db *DB) IsExpired(key string) bool {
return expired
}

/* --- add version --- */

func (db *DB) addVersion(keys ...string) {
for _, key := range keys {
versionCode := db.GetVersion(key)
db.versionMap.Put(key, versionCode+1)
}
}

func (db *DB) GetVersion(key string) uint32 {
entity, ok := db.versionMap.Get(key)
if !ok {
return 0
}
return entity.(uint32)
}

/* ---- Subscribe Functions ---- */

// AfterClientClose does some clean after client close connection
Expand Down
5 changes: 5 additions & 0 deletions exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func execSpecialCmd(c redis.Connection, cmdLine [][]byte, cmdName string, db *DB
return reply.MakeArgNumErrReply(cmdName), true
}
return execMulti(db, c), true
} else if cmdName == "watch" {
if !validateArity(-2, cmdLine) {
return reply.MakeArgNumErrReply(cmdName), true
}
return Watch(db, c, cmdLine[1:]), true
}
return nil, false
}
1 change: 1 addition & 0 deletions exec_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func execNormalCommand(db *DB, cmdArgs [][]byte) redis.Reply {

prepare := cmd.prepare
write, read := prepare(cmdArgs[1:])
db.addVersion(write...)
db.RWLocks(write, read)
defer db.RWUnLocks(write, read)
fun := cmd.executor
Expand Down
1 change: 1 addition & 0 deletions interface/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ type Connection interface {
GetQueuedCmdLine() [][][]byte
EnqueueCmd([][]byte)
ClearQueuedCmds()
GetWatching() map[string]uint32
}
Loading

0 comments on commit ae25076

Please sign in to comment.