Skip to content

Commit 3302893

Browse files
authored
Implement Hash API support in TestExecutor (zio#344)
1 parent d340342 commit 3302893

File tree

2 files changed

+220
-7
lines changed

2 files changed

+220
-7
lines changed

redis/src/main/scala/zio/redis/TestExecutor.scala

+218-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ private[redis] final class TestExecutor private (
1515
sets: TMap[String, Set[String]],
1616
strings: TMap[String, String],
1717
randomPick: Int => USTM[Int],
18-
hyperLogLogs: TMap[String, Set[String]]
18+
hyperLogLogs: TMap[String, Set[String]],
19+
hashes: TMap[String, Map[String, String]]
1920
) extends RedisExecutor.Service {
2021

2122
def execute(command: Chunk[RespValue.BulkString]): IO[RedisError, RespValue] =
@@ -273,11 +274,12 @@ private[redis] final class TestExecutor private (
273274
}
274275

275276
val key = input.head.asString
277+
276278
orWrongType(isSet(key))(
277279
{
278280
val start = input(1).asString.toInt
279281
val maybeRegex = if (input.size > 2) input(2).asString match {
280-
case "MATCH" => Some(input(3).asString.r)
282+
case "MATCH" => Some(input(3).asString.replace("*", ".*").r)
281283
case _ => None
282284
}
283285
else None
@@ -769,6 +771,203 @@ private[redis] final class TestExecutor private (
769771
)
770772
)
771773

774+
case api.Hashes.HDel.name =>
775+
val key = input(0).asString
776+
val values = input.tail.map(_.asString)
777+
778+
orWrongType(isHash(key))(
779+
for {
780+
hash <- hashes.getOrElse(key, Map.empty)
781+
countExists = hash.keys count values.contains
782+
newHash = hash -- values
783+
_ <- if (newHash.isEmpty) hashes.delete(key) else hashes.put(key, newHash)
784+
} yield RespValue.Integer(countExists.toLong)
785+
)
786+
787+
case api.Hashes.HExists.name =>
788+
val key = input(0).asString
789+
val field = input(1).asString
790+
791+
orWrongType(isHash(key))(
792+
for {
793+
hash <- hashes.getOrElse(key, Map.empty)
794+
exists = hash.keys.exists(_ == field)
795+
} yield if (exists) RespValue.Integer(1L) else RespValue.Integer(0L)
796+
)
797+
798+
case api.Hashes.HGet.name =>
799+
val key = input(0).asString
800+
val field = input(1).asString
801+
802+
orWrongType(isHash(key))(
803+
for {
804+
hash <- hashes.getOrElse(key, Map.empty)
805+
value = hash.get(field)
806+
} yield value.fold[RespValue](RespValue.NullBulkString)(result => RespValue.bulkString(result))
807+
)
808+
809+
case api.Hashes.HGetAll.name =>
810+
val key = input(0).asString
811+
812+
orWrongType(isHash(key))(
813+
for {
814+
hash <- hashes.getOrElse(key, Map.empty)
815+
results = hash.flatMap { case (k, v) => Iterable.apply(k, v) } map RespValue.bulkString
816+
} yield RespValue.Array(Chunk.fromIterable(results))
817+
)
818+
819+
case api.Hashes.HIncrBy.name =>
820+
val key = input(0).asString
821+
val field = input(1).asString
822+
val incr = input(2).asString.toLong
823+
824+
orWrongType(isHash(key))(
825+
(for {
826+
hash <- hashes.getOrElse(key, Map.empty)
827+
newValue <- STM.fromTry(Try(hash.getOrElse(field, "0").toLong + incr))
828+
newMap = hash + (field -> newValue.toString)
829+
_ <- hashes.put(key, newMap)
830+
} yield newValue).fold(_ => Replies.Error, result => RespValue.Integer(result))
831+
)
832+
833+
case api.Hashes.HIncrByFloat.name =>
834+
val key = input(0).asString
835+
val field = input(1).asString
836+
val incr = input(2).asString.toDouble
837+
838+
orWrongType(isHash(key))(
839+
(for {
840+
hash <- hashes.getOrElse(key, Map.empty)
841+
newValue <- STM.fromTry(Try(hash.getOrElse(field, "0").toDouble + incr))
842+
newHash = hash + (field -> newValue.toString)
843+
_ <- hashes.put(key, newHash)
844+
} yield newValue).fold(_ => Replies.Error, result => RespValue.bulkString(result.toString))
845+
)
846+
847+
case api.Hashes.HKeys.name =>
848+
val key = input(0).asString
849+
850+
orWrongType(isHash(key))(
851+
for {
852+
hash <- hashes.getOrElse(key, Map.empty)
853+
} yield RespValue.Array(Chunk.fromIterable(hash.keys map RespValue.bulkString))
854+
)
855+
856+
case api.Hashes.HLen.name =>
857+
val key = input(0).asString
858+
859+
orWrongType(isHash(key))(
860+
for {
861+
hash <- hashes.getOrElse(key, Map.empty)
862+
} yield RespValue.Integer(hash.size.toLong)
863+
)
864+
865+
case api.Hashes.HmGet.name =>
866+
val key = input(0).asString
867+
val fields = input.tail.map(_.asString)
868+
869+
orWrongType(isHash(key))(
870+
for {
871+
hash <- hashes.getOrElse(key, Map.empty)
872+
result = fields.map(hash.get)
873+
} yield RespValue.Array(result.map {
874+
case None => RespValue.NullBulkString
875+
case Some(value) => RespValue.bulkString(value)
876+
})
877+
)
878+
879+
case api.Hashes.HmSet.name =>
880+
val key = input(0).asString
881+
val values = input.tail.map(_.asString)
882+
883+
orWrongType(isHash(key))(
884+
for {
885+
hash <- hashes.getOrElse(key, Map.empty)
886+
newMap = hash ++ values.grouped(2).map(g => (g(0), g(1)))
887+
_ <- hashes.put(key, newMap)
888+
} yield Replies.Ok
889+
)
890+
891+
case api.Hashes.HScan.name =>
892+
def maybeGetCount(key: RespValue.BulkString, value: RespValue.BulkString): Option[Int] =
893+
key.asString match {
894+
case "COUNT" => Some(value.asString.toInt)
895+
case _ => None
896+
}
897+
898+
val key = input.head.asString
899+
900+
orWrongType(isHash(key))(
901+
{
902+
val start = input(1).asString.toInt
903+
val maybeRegex = if (input.size > 2) input(2).asString match {
904+
case "MATCH" => Some(input(3).asString.replace("*", ".*").r)
905+
case _ => None
906+
}
907+
else None
908+
val maybeCount =
909+
if (input.size > 4) maybeGetCount(input(4), input(5))
910+
else if (input.size > 2) maybeGetCount(input(2), input(3))
911+
else None
912+
val end = start + maybeCount.getOrElse(10)
913+
for {
914+
set <- hashes.getOrElse(key, Map.empty)
915+
filtered =
916+
maybeRegex.map(regex => set.filter { case (k, _) => regex.pattern.matcher(k).matches }).getOrElse(set)
917+
resultSet = filtered.slice(start, end)
918+
nextIndex = if (filtered.size <= end) 0 else end
919+
results = Replies.array(resultSet.flatMap { case (k, v) => Iterable(k, v) })
920+
} yield RespValue.array(RespValue.bulkString(nextIndex.toString), results)
921+
}
922+
)
923+
924+
case api.Hashes.HSet.name =>
925+
val key = input(0).asString
926+
val values = input.tail.map(_.asString)
927+
928+
orWrongType(isHash(key))(
929+
for {
930+
hash <- hashes.getOrElse(key, Map.empty)
931+
newHash = hash ++ values.grouped(2).map(g => (g(0), g(1)))
932+
_ <- hashes.put(key, newHash)
933+
} yield RespValue.Integer(newHash.size.toLong - hash.size.toLong)
934+
)
935+
936+
case api.Hashes.HSetNx.name =>
937+
val key = input(0).asString
938+
val field = input(1).asString
939+
val value = input(2).asString
940+
941+
orWrongType(isHash(key))(
942+
for {
943+
hash <- hashes.getOrElse(key, Map.empty)
944+
contains = hash.contains(field)
945+
newHash = hash ++ (if (contains) Map.empty else Map(field -> value))
946+
_ <- hashes.put(key, newHash)
947+
} yield RespValue.Integer(if (contains) 0L else 1L)
948+
)
949+
950+
case api.Hashes.HStrLen.name =>
951+
val key = input(0).asString
952+
val field = input(1).asString
953+
954+
orWrongType(isHash(key))(
955+
for {
956+
hash <- hashes.getOrElse(key, Map.empty)
957+
len = hash.get(field).map(_.length.toLong).getOrElse(0L)
958+
} yield RespValue.Integer(len)
959+
)
960+
961+
case api.Hashes.HVals.name =>
962+
val key = input(0).asString
963+
964+
orWrongType(isHash(key))(
965+
for {
966+
hash <- hashes.getOrElse(key, Map.empty)
967+
values = hash.values map RespValue.bulkString
968+
} yield RespValue.Array(Chunk.fromIterable(values))
969+
)
970+
772971
case _ => STM.succeedNow(RespValue.Error("ERR unknown command"))
773972
}
774973
}
@@ -784,23 +983,35 @@ private[redis] final class TestExecutor private (
784983
isString <- strings.contains(name)
785984
isList <- lists.contains(name)
786985
isHyper <- hyperLogLogs.contains(name)
787-
} yield !isString && !isList && !isHyper
986+
isHash <- hashes.contains(name)
987+
} yield !isString && !isList && !isHyper && !isHash
788988

789989
// check whether the key is a list or unused.
790990
private[this] def isList(name: String): STM[Nothing, Boolean] =
791991
for {
792992
isString <- strings.contains(name)
793993
isSet <- sets.contains(name)
794994
isHyper <- hyperLogLogs.contains(name)
795-
} yield !isString && !isSet && !isHyper
995+
isHash <- hashes.contains(name)
996+
} yield !isString && !isSet && !isHyper && !isHash
796997

797998
//check whether the key is a hyperLogLog or unused.
798999
private[this] def isHyperLogLog(name: String): ZSTM[Any, Nothing, Boolean] =
7991000
for {
8001001
isString <- strings.contains(name)
8011002
isSet <- sets.contains(name)
8021003
isList <- lists.contains(name)
803-
} yield !isString && !isSet && !isList
1004+
isHash <- hashes.contains(name)
1005+
} yield !isString && !isSet && !isList && !isHash
1006+
1007+
//check whether the key is a hash or unused.
1008+
private[this] def isHash(name: String): ZSTM[Any, Nothing, Boolean] =
1009+
for {
1010+
isString <- strings.contains(name)
1011+
isSet <- sets.contains(name)
1012+
isList <- lists.contains(name)
1013+
isHyper <- hyperLogLogs.contains(name)
1014+
} yield !isString && !isSet && !isList && !isHyper
8041015

8051016
@tailrec
8061017
private[this] def dropWhileLimit[A](xs: Chunk[A])(p: A => Boolean, k: Int): Chunk[A] =
@@ -903,7 +1114,8 @@ private[redis] object TestExecutor {
9031114
strings <- TMap.empty[String, String].commit
9041115
hyperLogLogs <- TMap.empty[String, Set[String]].commit
9051116
lists <- TMap.empty[String, Chunk[String]].commit
906-
} yield new TestExecutor(lists, sets, strings, randomPick, hyperLogLogs)
1117+
hashes <- TMap.empty[String, Map[String, String]].commit
1118+
} yield new TestExecutor(lists, sets, strings, randomPick, hyperLogLogs, hashes)
9071119

9081120
executor.toLayer
9091121
}

redis/src/test/scala/zio/redis/ApiSpec.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ object ApiSpec
3636
connectionSuite,
3737
setsSuite,
3838
hyperLogLogSuite,
39-
listSuite
39+
listSuite,
40+
hashSuite
4041
).filterAnnotations(TestAnnotation.tagged)(t => !t.contains(TestExecutorUnsupportedTag))
4142
.get
4243
.provideCustomLayerShared(RedisExecutor.test ++ Clock.live)

0 commit comments

Comments
 (0)