Skip to content

Commit

Permalink
Add support of the command ZDIFF and ZDIFFSTORE (#2021)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaheshMadushan authored Jan 17, 2024
1 parent 5af0b3b commit 69b054e
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 1 deletion.
97 changes: 96 additions & 1 deletion src/commands/cmd_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,99 @@ class CommandZRandMember : public Commander {
bool no_parameters_ = true;
};

class CommandZDiff : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
auto parse_result = ParseInt<int>(args[1], 10);
if (!parse_result) return {Status::RedisParseErr, errValueNotInteger};

numkeys_ = *parse_result;
if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax};

size_t j = 0;
while (j < numkeys_) {
keys_.emplace_back(args[j + 2]);
j++;
}

if (auto i = 2 + numkeys_; i < args.size()) {
if (util::ToLower(args[i]) == "withscores") {
with_scores_ = true;
}
}

return Commander::Parse(args);
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::ZSet zset_db(srv->storage, conn->GetNamespace());

std::vector<MemberScore> members_with_scores;
auto s = zset_db.Diff(keys_, &members_with_scores);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}

output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 2 : 1)));
for (const auto &ms : members_with_scores) {
output->append(redis::BulkString(ms.member));
if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score)));
}

return Status::OK();
}

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
return {2, 2 + num_key, 1};
}

protected:
size_t numkeys_ = 0;
std::vector<rocksdb::Slice> keys_;
bool with_scores_ = false;
};

class CommandZDiffStore : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
auto parse_result = ParseInt<int>(args[2], 10);
if (!parse_result) return {Status::RedisParseErr, errValueNotInteger};

numkeys_ = *parse_result;
if (numkeys_ > args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax};

size_t j = 0;
while (j < numkeys_) {
keys_.emplace_back(args[j + 3]);
j++;
}

return Commander::Parse(args);
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::ZSet zset_db(srv->storage, conn->GetNamespace());

uint64_t stored_count = 0;
auto s = zset_db.DiffStore(args_[1], keys_, &stored_count);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}
*output = redis::Integer(stored_count);
return Status::OK();
}

static CommandKeyRange Range(const std::vector<std::string> &args) {
int num_key = *ParseInt<int>(args[1], 10);
return {3, 2 + num_key, 1};
}

protected:
size_t numkeys_ = 0;
std::vector<rocksdb::Slice> keys_;
};

REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandZCard>("zcard", 2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZCount>("zcount", 4, "read-only", 1, 1, 1),
Expand Down Expand Up @@ -1451,6 +1544,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandZAdd>("zadd", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandZScan>("zscan", -3, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZUnionStore>("zunionstore", -4, "write", CommandZUnionStore::Range),
MakeCmdAttr<CommandZUnion>("zunion", -3, "read-only", CommandZUnion::Range),
MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1))
MakeCmdAttr<CommandZRandMember>("zrandmember", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandZDiff>("zdiff", -3, "read-only", CommandZDiff::Range),
MakeCmdAttr<CommandZDiffStore>("zdiffstore", -3, "read-only", CommandZDiffStore::Range), )

} // namespace redis
38 changes: 38 additions & 0 deletions src/types/redis_zset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -931,4 +931,42 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count,
return rocksdb::Status::OK();
}

rocksdb::Status ZSet::Diff(const std::vector<Slice> &keys, MemberScores *members) {
members->clear();
MemberScores source_member_scores;
RangeScoreSpec spec;
uint64_t size = 0;
auto s = RangeByScore(keys[0], spec, &source_member_scores, &size);
if (!s.ok()) return s;

if (size == 0) {
return rocksdb::Status::OK();
}

std::map<std::string, bool> exclude_members;
MemberScores target_member_scores;
for (size_t i = 1; i < keys.size(); i++) {
uint64_t size = 0;
s = RangeByScore(keys[i], spec, &target_member_scores, &size);
if (!s.ok()) return s;
for (const auto &member_score : target_member_scores) {
exclude_members[member_score.member] = true;
}
}
for (const auto &member_score : source_member_scores) {
if (exclude_members.find(member_score.member) == exclude_members.end()) {
members->push_back(member_score);
}
}
return rocksdb::Status::OK();
}

rocksdb::Status ZSet::DiffStore(const Slice &dst, const std::vector<Slice> &keys, uint64_t *stored_count) {
MemberScores mscores;
auto s = Diff(keys, &mscores);
if (!s.ok()) return s;
*stored_count = mscores.size();
return Overwrite(dst, mscores);
}

} // namespace redis
2 changes: 2 additions & 0 deletions src/types/redis_zset.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class ZSet : public SubKeyScanner {
AggregateMethod aggregate_method, uint64_t *saved_cnt);
rocksdb::Status Union(const std::vector<KeyWeight> &keys_weights, AggregateMethod aggregate_method,
std::vector<MemberScore> *members);
rocksdb::Status Diff(const std::vector<Slice> &keys, MemberScores *members);
rocksdb::Status DiffStore(const Slice &dst, const std::vector<Slice> &keys, uint64_t *stored_count);
rocksdb::Status MGet(const Slice &user_key, const std::vector<Slice> &members, std::map<std::string, double> *scores);
rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata);

Expand Down
78 changes: 78 additions & 0 deletions tests/cppunit/types/zset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,81 @@ TEST_F(RedisZSetTest, RandMember) {
auto s = zset_->Del(key_);
EXPECT_TRUE(s.ok());
}

TEST_F(RedisZSetTest, Diff) {
uint64_t ret = 0;

std::string k1 = "key1";
std::vector<MemberScore> k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}};

std::string k2 = "key2";
std::vector<MemberScore> k2_mscores = {{"c", -150.1}};

std::string k3 = "key3";
std::vector<MemberScore> k3_mscores = {{"a", -1000.1}, {"c", -100.1}, {"e", 8000.9}};

auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret);
EXPECT_EQ(ret, 4);
zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret);
EXPECT_EQ(ret, 1);
zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret);
EXPECT_EQ(ret, 3);

std::vector<MemberScore> mscores;
zset_->Diff({k1, k2, k3}, &mscores);

EXPECT_EQ(2, mscores.size());
std::vector<MemberScore> expected_mscores = {{"b", -100.1}, {"d", 1.234}};
int index = 0;
for (const auto &mscore : expected_mscores) {
EXPECT_EQ(mscore.member, mscores[index].member);
EXPECT_EQ(mscore.score, mscores[index].score);
index++;
}

s = zset_->Del(k1);
EXPECT_TRUE(s.ok());
s = zset_->Del(k2);
EXPECT_TRUE(s.ok());
s = zset_->Del(k3);
EXPECT_TRUE(s.ok());
}

TEST_F(RedisZSetTest, DiffStore) {
uint64_t ret = 0;

std::string k1 = "key1";
std::vector<MemberScore> k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}};

std::string k2 = "key2";
std::vector<MemberScore> k2_mscores = {{"c", -150.1}};

auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret);
EXPECT_EQ(ret, 4);
zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret);
EXPECT_EQ(ret, 1);

uint64_t stored_count = 0;
zset_->DiffStore("zsetdiff", {k1, k2}, &stored_count);
EXPECT_EQ(stored_count, 3);

RangeScoreSpec spec;
std::vector<MemberScore> mscores;
zset_->RangeByScore("zsetdiff", spec, &mscores, nullptr);
EXPECT_EQ(mscores.size(), 3);

std::vector<MemberScore> expected_mscores = {{"a", -100.1}, {"b", -100.1}, {"d", 1.234}};
int index = 0;
for (const auto &mscore : expected_mscores) {
EXPECT_EQ(mscore.member, mscores[index].member);
EXPECT_EQ(mscore.score, mscores[index].score);
index++;
}

s = zset_->Del(k1);
EXPECT_TRUE(s.ok());
s = zset_->Del(k2);
EXPECT_TRUE(s.ok());
s = zset_->Del("zsetdiff");
EXPECT_TRUE(s.ok());
}
Loading

0 comments on commit 69b054e

Please sign in to comment.