Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the support of the LCS command #2116

Merged
merged 19 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 84 additions & 2 deletions src/commands/cmd_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "commands/command_parser.h"
#include "error_constants.h"
#include "server/redis_reply.h"
#include "server/redis_request.h"
#include "server/server.h"
#include "storage/redis_db.h"
#include "time_util.h"
Expand Down Expand Up @@ -620,6 +621,88 @@ class CommandCAD : public Commander {
}
};

class CommandLCS : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
CommandParser parser(args, 3);
bool get_idx = false;
bool get_len = false;
while (parser.Good()) {
if (parser.EatEqICase("IDX")) {
get_idx = true;
} else if (parser.EatEqICase("LEN")) {
get_len = true;
} else if (parser.EatEqICase("WITHMATCHLEN")) {
with_match_len_ = true;
} else if (parser.EatEqICase("MINMATCHLEN")) {
min_match_len_ = GET_OR_RET(parser.TakeInt<int64_t>());
if (min_match_len_ < 0) {
min_match_len_ = 0;
}
} else {
return parser.InvalidSyntax();
}
}

// Complain if the user passed ambiguous parameters.
if (get_idx && get_len) {
return {Status::RedisParseErr,
"If you want both the length and indexes, "
"please just use IDX."};
}

if (get_len) {
type_ = StringLCSType::LEN;
} else if (get_idx) {
type_ = StringLCSType::IDX;
}

return Status::OK();
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::String string_db(srv->storage, conn->GetNamespace());

StringLCSResult rst;
auto s = string_db.LCS(args_[1], args_[2], {type_, min_match_len_}, &rst);
if (!s.ok()) {
return {Status::RedisExecErr, s.ToString()};
}

// Build output by the rst type.
if (auto lcs = std::get_if<std::string>(&rst)) {
*output = redis::BulkString(*lcs);
} else if (auto len = std::get_if<uint32_t>(&rst)) {
*output = redis::Integer(*len);
} else if (auto result = std::get_if<StringLCSIdxResult>(&rst)) {
*output = conn->HeaderOfMap(2);
*output += redis::BulkString("matches");
*output += redis::MultiLen(result->matches.size());
for (const auto &match : result->matches) {
*output += redis::MultiLen(with_match_len_ ? 3 : 2);
*output += redis::MultiLen(2);
*output += redis::Integer(match.a.start);
*output += redis::Integer(match.a.end);
*output += redis::MultiLen(2);
*output += redis::Integer(match.b.start);
*output += redis::Integer(match.b.end);
if (with_match_len_) {
*output += redis::Integer(match.match_len);
}
}
*output += redis::BulkString("len");
*output += redis::Integer(result->len);
}

return Status::OK();
}

private:
StringLCSType type_ = StringLCSType::NONE;
bool with_match_len_ = false;
int64_t min_match_len_ = 0;
};

REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandGet>("get", 2, "read-only", 1, 1, 1), MakeCmdAttr<CommandGetEx>("getex", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandStrlen>("strlen", 2, "read-only", 1, 1, 1),
Expand All @@ -637,6 +720,5 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandIncrByFloat>("incrbyfloat", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncr>("incr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandDecrBy>("decrby", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandDecr>("decr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandCAS>("cas", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), )

MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), MakeCmdAttr<CommandLCS>("lcs", -3, "read-only", 1, 2, 1), )
} // namespace redis
4 changes: 0 additions & 4 deletions src/server/redis_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@

namespace redis {

const size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
const size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
const size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

Status Request::Tokenize(evbuffer *input) {
size_t pipeline_size = 0;

Expand Down
4 changes: 4 additions & 0 deletions src/server/redis_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class Server;

namespace redis {

constexpr size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
constexpr size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
constexpr size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

using CommandTokens = std::vector<std::string>;

class Connection;
Expand Down
153 changes: 153 additions & 0 deletions src/types/redis_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <string>

#include "parse_util.h"
#include "server/redis_request.h"
#include "storage/redis_metadata.h"
#include "time_util.h"

Expand Down Expand Up @@ -530,4 +531,156 @@ rocksdb::Status String::CAD(const std::string &user_key, const std::string &valu
return rocksdb::Status::OK();
}

rocksdb::Status String::LCS(const std::string &user_key1, const std::string &user_key2, StringLCSArgs args,
StringLCSResult *rst) {
if (args.type == StringLCSType::LEN) {
*rst = static_cast<uint32_t>(0);
} else if (args.type == StringLCSType::IDX) {
*rst = StringLCSIdxResult{{}, 0};
} else {
*rst = std::string{};
}

std::string a;
std::string b;
std::string ns_key1 = AppendNamespacePrefix(user_key1);
std::string ns_key2 = AppendNamespacePrefix(user_key2);
auto s1 = getValue(ns_key1, &a);
auto s2 = getValue(ns_key2, &b);

if (!s1.ok() && !s1.IsNotFound()) {
return s1;
}
if (!s2.ok() && !s2.IsNotFound()) {
return s2;
}
if (s1.IsNotFound()) a = "";
if (s2.IsNotFound()) b = "";

// Detect string truncation or later overflows.
if (a.length() >= UINT32_MAX - 1 || b.length() >= UINT32_MAX - 1) {
return rocksdb::Status::InvalidArgument("String too long for LCS");
}

// Compute the LCS using the vanilla dynamic programming technique of
// building a table of LCS(x, y) substrings.
auto alen = static_cast<uint32_t>(a.length());
auto blen = static_cast<uint32_t>(b.length());

// Allocate the LCS table.
uint64_t dp_size = (alen + 1) * (blen + 1);
uint64_t bulk_size = dp_size * sizeof(uint32_t);
if (bulk_size > PROTO_BULK_MAX_SIZE || bulk_size / dp_size != sizeof(uint32_t)) {
return rocksdb::Status::Aborted("Insufficient memory, transient memory for LCS exceeds proto-max-bulk-len");
}
std::vector<uint32_t> dp(dp_size, 0);
auto lcs = [&dp, blen](const uint32_t i, const uint32_t j) -> uint32_t & { return dp[i * (blen + 1) + j]; };

// Start building the LCS table.
for (uint32_t i = 1; i <= alen; i++) {
for (uint32_t j = 1; j <= blen; j++) {
if (a[i - 1] == b[j - 1]) {
// The len LCS (and the LCS itself) of two
// sequences with the same final character, is the
// LCS of the two sequences without the last char
// plus that last char.
lcs(i, j) = lcs(i - 1, j - 1) + 1;
} else {
// If the last character is different, take the longest
// between the LCS of the first string and the second
// minus the last char, and the reverse.
lcs(i, j) = std::max(lcs(i - 1, j), lcs(i, j - 1));
}
}
}

uint32_t idx = lcs(alen, blen);

// Only compute the length of LCS.
if (auto result = std::get_if<uint32_t>(rst)) {
*result = idx;
return rocksdb::Status::OK();
}

// Store the length of the LCS first if needed.
if (auto result = std::get_if<StringLCSIdxResult>(rst)) {
result->len = idx;
}

// Allocate when we need to compute the actual LCS string.
if (auto result = std::get_if<std::string>(rst)) {
result->resize(idx);
}

uint32_t i = alen;
uint32_t j = blen;
uint32_t arange_start = alen; // alen signals that values are not set.
uint32_t arange_end = 0;
uint32_t brange_start = 0;
uint32_t brange_end = 0;
while (i > 0 && j > 0) {
bool emit_range = false;
if (a[i - 1] == b[j - 1]) {
// If there is a match, store the character if needed.
// And reduce the indexes to look for a new match.
if (auto result = std::get_if<std::string>(rst)) {
result->at(idx - 1) = a[i - 1];
}

// Track the current range.
if (arange_start == alen) {
arange_start = i - 1;
arange_end = i - 1;
brange_start = j - 1;
brange_end = j - 1;
}
// Let's see if we can extend the range backward since
// it is contiguous.
else if (arange_start == i && brange_start == j) {
arange_start--;
brange_start--;
} else {
emit_range = true;
}

// Emit the range if we matched with the first byte of
// one of the two strings. We'll exit the loop ASAP.
if (arange_start == 0 || brange_start == 0) {
emit_range = true;
}
idx--;
i--;
j--;
} else {
// Otherwise reduce i and j depending on the largest
// LCS between, to understand what direction we need to go.
uint32_t lcs1 = lcs(i - 1, j);
uint32_t lcs2 = lcs(i, j - 1);
if (lcs1 > lcs2)
i--;
else
j--;
if (arange_start != alen) emit_range = true;
}

// Emit the current range if needed.
if (emit_range) {
if (auto result = std::get_if<StringLCSIdxResult>(rst)) {
uint32_t match_len = arange_end - arange_start + 1;

// Always emit the range when the `min_match_len` is not set.
if (args.min_match_len == 0 || match_len >= args.min_match_len) {
result->matches.emplace_back(StringLCSRange{arange_start, arange_end},
StringLCSRange{brange_start, brange_end}, match_len);
}
}

// Restart at the next match.
arange_start = alen;
}
}

return rocksdb::Status::OK();
}

} // namespace redis
33 changes: 32 additions & 1 deletion src/types/redis_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

#include "storage/redis_db.h"
Expand All @@ -42,8 +43,36 @@ struct StringSetArgs {
bool keep_ttl;
};

namespace redis {
enum class StringLCSType { NONE, LEN, IDX };

struct StringLCSArgs {
StringLCSType type;
int64_t min_match_len;
};

struct StringLCSRange {
uint32_t start;
uint32_t end;
};

struct StringLCSMatchedRange {
StringLCSRange a;
StringLCSRange b;
uint32_t match_len;

StringLCSMatchedRange(StringLCSRange ra, StringLCSRange rb, uint32_t len) : a(ra), b(rb), match_len(len) {}
};

struct StringLCSIdxResult {
// Matched ranges.
std::vector<StringLCSMatchedRange> matches;
// LCS length.
uint32_t len;
};

using StringLCSResult = std::variant<std::string, uint32_t, StringLCSIdxResult>;

namespace redis {
class String : public Database {
public:
explicit String(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {}
Expand All @@ -68,6 +97,8 @@ class String : public Database {
rocksdb::Status CAS(const std::string &user_key, const std::string &old_value, const std::string &new_value,
uint64_t ttl, int *flag);
rocksdb::Status CAD(const std::string &user_key, const std::string &value, int *flag);
rocksdb::Status LCS(const std::string &user_key1, const std::string &user_key2, StringLCSArgs args,
StringLCSResult *rst);

private:
rocksdb::Status getValue(const std::string &ns_key, std::string *value);
Expand Down
Loading
Loading