Skip to content

Commit

Permalink
Support for the JSON.ARRTRIM command (#1881)
Browse files Browse the repository at this point in the history
  • Loading branch information
jihuayu authored Nov 12, 2023
1 parent 21f7997 commit 21c8168
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/commands/cmd_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,47 @@ class CommandJsonArrPop : public Commander {
int64_t index_ = -1;
};

class CommandJsonArrTrim : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
path_ = args_[2];
start_ = GET_OR_RET(ParseInt<int64_t>(args_[3], 10));
stop_ = GET_OR_RET(ParseInt<int64_t>(args_[4], 10));

return Status::OK();
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::Json json(srv->storage, conn->GetNamespace());

std::vector<std::optional<uint64_t>> results;

auto s = json.ArrTrim(args_[1], path_, start_, stop_, results);

if (s.IsNotFound()) {
*output = redis::NilString();
return Status::OK();
}
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};

*output = redis::MultiLen(results.size());
for (const auto &len : results) {
if (len.has_value()) {
*output += redis::Integer(len.value());
} else {
*output += redis::NilString();
}
}

return Status::OK();
}

private:
std::string path_;
int64_t start_ = 0;
int64_t stop_ = 0;
};

class CommanderJsonArrIndex : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
Expand Down Expand Up @@ -399,6 +440,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandJsonSet>("json.set", 4, "write", 1, 1
MakeCmdAttr<CommandJsonInfo>("json.info", 2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandJsonType>("json.type", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandJsonArrAppend>("json.arrappend", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonArrTrim>("json.arrtrim", 5, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonClear>("json.clear", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonToggle>("json.toggle", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonArrLen>("json.arrlen", -2, "read-only", 1, 1, 1),
Expand Down
33 changes: 33 additions & 0 deletions src/types/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#pragma once

#include <algorithm>
#include <jsoncons/json.hpp>
#include <jsoncons/json_error.hpp>
#include <jsoncons/json_options.hpp>
Expand Down Expand Up @@ -413,6 +414,38 @@ struct JsonValue {
return popped_values;
}

Status ArrTrim(std::string_view path, int64_t start, int64_t stop, std::vector<std::optional<uint64_t>> &results) {
try {
jsoncons::jsonpath::json_replace(
value, path, [&results, start, stop](const std::string & /*path*/, jsoncons::json &val) {
if (val.is_array()) {
auto len = static_cast<int64_t>(val.size());
auto begin_index = start < 0 ? std::max(len + start, static_cast<int64_t>(0)) : start;
auto end_index = std::min(stop < 0 ? std::max(len + stop, static_cast<int64_t>(0)) : stop, len - 1);

if (begin_index >= len || begin_index > end_index) {
val = jsoncons::json::array();
results.emplace_back(0);
return;
}

auto n_val = jsoncons::json::array();
auto begin_iter = val.array_range().begin();

n_val.insert(n_val.end(), begin_iter + begin_index, begin_iter + end_index + 1);
val = n_val;
results.emplace_back(static_cast<int64_t>(n_val.size()));
} else {
results.emplace_back(std::nullopt);
}
});
} catch (const jsoncons::jsonpath::jsonpath_error &e) {
return {Status::NotOK, e.what()};
}

return Status::OK();
}

JsonValue(const JsonValue &) = default;
JsonValue(JsonValue &&) = default;

Expand Down
19 changes: 19 additions & 0 deletions src/types/redis_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,4 +339,23 @@ rocksdb::Status Json::ObjKeys(const std::string &user_key, const std::string &pa
return rocksdb::Status::OK();
}

rocksdb::Status Json::ArrTrim(const std::string &user_key, const std::string &path, int64_t start, int64_t stop,
std::vector<std::optional<uint64_t>> &results) {
auto ns_key = AppendNamespacePrefix(user_key);

LockGuard guard(storage_->GetLockManager(), ns_key);

JsonMetadata metadata;
JsonValue json_val;
auto s = read(ns_key, &metadata, &json_val);
if (!s.ok()) return s;

auto len_res = json_val.ArrTrim(path, start, stop, results);
if (!len_res) return rocksdb::Status::InvalidArgument(len_res.Msg());
bool is_write =
std::any_of(results.begin(), results.end(), [](const std::optional<uint64_t> &val) { return val.has_value(); });
if (!is_write) return rocksdb::Status::OK();
return write(ns_key, &metadata, json_val);
}

} // namespace redis
3 changes: 3 additions & 0 deletions src/types/redis_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class Json : public Database {
rocksdb::Status ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle,
ssize_t start, ssize_t end, std::vector<ssize_t> *result);

rocksdb::Status ArrTrim(const std::string &user_key, const std::string &path, int64_t start, int64_t stop,
std::vector<std::optional<uint64_t>> &results);

private:
rocksdb::Status write(Slice ns_key, JsonMetadata *metadata, const JsonValue &json_val);
rocksdb::Status read(const Slice &ns_key, JsonMetadata *metadata, JsonValue *value);
Expand Down
84 changes: 84 additions & 0 deletions tests/gocase/unit/type/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,90 @@ func TestJson(t *testing.T) {
require.ErrorContains(t, rdb.Do(ctx, "JSON.ARRPOP", "a", "$", "0", "1").Err(), "wrong number of arguments")
})

t.Run("JSON.ARRTRIM basics", func(t *testing.T) {
require.NoError(t, rdb.Del(ctx, "a").Err())
// key no exists
require.EqualError(t, rdb.Do(ctx, "JSON.ARRTRIM", "not_exists", "$", 0, 0).Err(), redis.Nil.Error())
// key not json
require.NoError(t, rdb.Do(ctx, "SET", "no_json", "1").Err())
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "no_json", "$", 0, 0).Err())
// json path no exists
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"a1":{}}`).Err())
require.EqualValues(t, []interface{}{}, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.not_exists", 0, 0).Val())
// json path not array
require.EqualValues(t, []interface{}{nil}, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 0, 0).Val())
// json path has one array
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"a1":[1,2,3,4,5,6,7,8,9]}`).Err())
require.EqualValues(t, []interface{}{int64(5)}, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 2, 6).Val())
// json path has many array
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"a":{"a1":[1,2,3,4,5,6]},"b":{"a1":["a",{},"b"]},"c":{"a1":[7,8,9,10,11]}}`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(3), int64(2), int64(3)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$..a1", 1, 3).Val())
require.EqualValues(t, "[{\"a\":{\"a1\":[2,3,4]},\"b\":{\"a1\":[{},\"b\"]},\"c\":{\"a1\":[8,9,10]}}]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// json path has many array and one is not array
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"a":{"a1":[1,2,3,4,5,6]},"b":{"a1":{"b":1,"c":1}},"c":{"a1":[7,8,9,10]}}`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(3), interface{}(nil), int64(3)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$..a1", 1, 3).Val())
require.EqualValues(t, "[{\"a\":{\"a1\":[2,3,4]},\"b\":{\"a1\":{\"b\":1,\"c\":1}},\"c\":{\"a1\":[8,9,10]}}]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start not a integer
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", "no", 1).Err())
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 1.1, 1).Err())
// stop not a integer
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 1, 1.1).Err())
// args size != 5
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 0).Err())
require.Error(t, rdb.Do(ctx, "JSON.ARRTRIM", "a", "$.a1", 0, 2, 3).Err())
})

t.Run("JSON.ARRTRIM special <start> and <stop> args", func(t *testing.T) {
// start < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(4)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -5, 8).Val())
require.EqualValues(t, "[[6,7,8,9]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start + len < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(6)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -20, 5).Val())
require.EqualValues(t, "[[1,2,3,4,5,6]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start > len
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(0)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 15, 25).Val())
require.EqualValues(t, "[[]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start = 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(9)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 0, 8).Val())
require.EqualValues(t, "[[1,2,3,4,5,6,7,8,9]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// stop = 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(1)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -12, 0).Val())
require.EqualValues(t, "[[1]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// stop < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(9)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 0, -2).Val())
require.EqualValues(t, "[[1,2,3,4,5,6,7,8,9]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// len + stop < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(1)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 0, -20).Val())
require.EqualValues(t, "[[1]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// stop > len
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(10)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 0, 20).Val())
require.EqualValues(t, "[[1,2,3,4,5,6,7,8,9,10]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start > stop
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(0)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", 8, 5).Val())
require.EqualValues(t, "[[]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start < 0 and stop < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(4)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -8, -5).Val())
require.EqualValues(t, "[[3,4,5,6]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start < 0 , stop < 0 and start > end
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(0)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -5, -8).Val())
require.EqualValues(t, "[[]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
// start + len < 0 , stop + len < 0
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `[1,2,3,4,5,6,7,8,9,10]`).Err())
require.EqualValues(t, []interface{}([]interface{}{int64(1)}), rdb.Do(ctx, "JSON.ARRTRIM", "a", "$", -30, -20).Val())
require.EqualValues(t, "[[1]]", rdb.Do(ctx, "JSON.GET", "a", "$").Val())
})

t.Run("JSON.TOGGLE basics", func(t *testing.T) {
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `true`).Err())
require.EqualValues(t, []interface{}{int64(0)}, rdb.Do(ctx, "JSON.TOGGLE", "a", "$").Val())
Expand Down

0 comments on commit 21c8168

Please sign in to comment.