Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the support of the LCS command #2116

Merged
merged 19 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 197 additions & 2 deletions src/commands/cmd_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "commands/command_parser.h"
#include "error_constants.h"
#include "server/redis_reply.h"
#include "server/redis_request.h"
#include "server/server.h"
#include "storage/redis_db.h"
#include "time_util.h"
Expand Down Expand Up @@ -620,6 +621,201 @@ class CommandCAD : public Commander {
}
};

class CommandLCS : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
CommandParser parser(args, 3);
while (parser.Good()) {
if (parser.EatEqICase("IDX")) {
get_idx_ = true;
} else if (parser.EatEqICase("LEN")) {
get_len_ = true;
} else if (parser.EatEqICase("WITHMATCHLEN")) {
with_match_len_ = true;
} else if (parser.EatEqICase("MINMATCHLEN")) {
min_match_len_ = GET_OR_RET(parser.TakeInt<int64_t>());
if (min_match_len_ < 0) {
min_match_len_ = 0;
}
} else {
return parser.InvalidSyntax();
}
}

// Complain if the user passed ambiguous parameters.
if (get_idx_ && get_len_) {
return {Status::RedisParseErr,
"If you want both the length and indexes, "
"please just use IDX."};
}

return Status::OK();
}

Status Execute(Server *srv, Connection *conn, std::string *output) override {
redis::String string_db(srv->storage, conn->GetNamespace());

std::string a;
std::string b;
auto s1 = string_db.Get(args_[1], &a);
auto s2 = string_db.Get(args_[2], &b);

if (!s1.ok() && !s1.IsNotFound()) {
return {Status::RedisExecErr, s1.ToString()};
}
if (!s2.ok() && !s2.IsNotFound()) {
return {Status::RedisExecErr, s2.ToString()};
}
if (s1.IsNotFound()) a = "";
if (s2.IsNotFound()) b = "";

// Detect string truncation or later overflows.
if (a.length() >= UINT32_MAX - 1 || b.length() >= UINT32_MAX - 1) {
return {Status::RedisExecErr, "String too long for LCS"};
}

// Compute the LCS using the vanilla dynamic programming technique of
// building a table of LCS(x, y) substrings.
auto alen = static_cast<uint32_t>(a.length());
auto blen = static_cast<uint32_t>(b.length());

// Allocate the LCS table.
uint64_t dp_size = (alen + 1) * (blen + 1);
uint64_t bulk_size = dp_size * sizeof(uint32_t);
JoverZhang marked this conversation as resolved.
Show resolved Hide resolved
if (bulk_size >= SIZE_MAX || bulk_size / dp_size != sizeof(uint32_t)) {
return {Status::RedisExecErr, "Insufficient memory, failed allocating transient memory for LCS"};
}
if (bulk_size > PROTO_BULK_MAX_SIZE) {
jihuayu marked this conversation as resolved.
Show resolved Hide resolved
return {Status::RedisExecErr, "Insufficient memory, transient memory for LCS exceeds proto-max-bulk-len"};
}
std::vector<uint32_t> dp(dp_size, 0);
auto lcs = [&](const uint32_t i, const uint32_t j) -> uint32_t & { return dp[i * (blen + 1) + j]; };

// Start building the LCS table.
for (uint32_t i = 1; i <= alen; i++) {
for (uint32_t j = 1; j <= blen; j++) {
if (a[i - 1] == b[j - 1]) {
// The len LCS (and the LCS itself) of two
// sequences with the same final character, is the
// LCS of the two sequences without the last char
// plus that last char.
lcs(i, j) = lcs(i - 1, j - 1) + 1;
} else {
// If the last character is different, take the longest
// between the LCS of the first string and the second
// minus the last char, and the reverse.
lcs(i, j) = std::max(lcs(i - 1, j), lcs(i, j - 1));
}
}
}

// Store the actual LCS string if needed.
std::string result;
uint32_t idx = lcs(alen, blen);

// Allocate when we need to compute the actual LCS string.
bool compute_lcs = get_idx_ || !get_len_;
if (compute_lcs) result.resize(idx);

// Build a array if we have to emit the matched ranges.
std::string matches;
uint32_t matches_len = 0;

uint32_t i = alen;
uint32_t j = blen;
uint32_t arange_start = alen; // alen signals that values are not set.
uint32_t arange_end = 0;
uint32_t brange_start = 0;
uint32_t brange_end = 0;
while (compute_lcs && i > 0 && j > 0) {
JoverZhang marked this conversation as resolved.
Show resolved Hide resolved
bool emit_range = false;
if (a[i - 1] == b[j - 1]) {
// If there is a match, store the character and reduce
// the indexes to look for a new match.
result.at(idx - 1) = a[i - 1];

// Track the current range.
if (arange_start == alen) {
arange_start = i - 1;
arange_end = i - 1;
brange_start = j - 1;
brange_end = j - 1;
}
// Let's see if we can extend the range backward since
// it is contiguous.
else if (arange_start == i && brange_start == j) {
arange_start--;
brange_start--;
} else {
emit_range = true;
}

// Emit the range if we matched with the first byte of
// one of the two strings. We'll exit the loop ASAP.
if (arange_start == 0 || brange_start == 0) {
emit_range = true;
}
idx--;
i--;
j--;
} else {
// Otherwise reduce i and j depending on the largest
// LCS between, to understand what direction we need to go.
uint32_t lcs1 = lcs(i - 1, j);
uint32_t lcs2 = lcs(i, j - 1);
if (lcs1 > lcs2)
i--;
else
j--;
if (arange_start != alen) emit_range = true;
}

// Emit the current range if needed.
uint32_t match_len = arange_end - arange_start + 1;
if (emit_range) {
if (get_idx_ && (min_match_len_ == 0 || match_len >= min_match_len_)) {
JoverZhang marked this conversation as resolved.
Show resolved Hide resolved
matches += redis::MultiLen(with_match_len_ ? 3 : 2);
matches += redis::MultiLen(2);
matches += redis::Integer(arange_start);
matches += redis::Integer(arange_end);
matches += redis::MultiLen(2);
matches += redis::Integer(brange_start);
matches += redis::Integer(brange_end);
if (with_match_len_) {
matches += redis::Integer(match_len);
}
matches_len++;
}

// Restart at the next match.
arange_start = alen;
}
}

// Build output by the given options.
if (get_idx_) {
*output = conn->HeaderOfMap(2);
*output += redis::BulkString("matches");
*output += redis::MultiLen(matches_len);
*output += matches;
*output += redis::BulkString("len");
*output += redis::Integer(lcs(alen, blen));
} else if (get_len_) {
*output = redis::Integer(lcs(alen, blen));
} else {
*output = redis::BulkString(result);
}

return Status::OK();
}

private:
bool get_idx_ = false;
bool get_len_ = false;
bool with_match_len_ = false;
int64_t min_match_len_ = 0;
};

REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandGet>("get", 2, "read-only", 1, 1, 1), MakeCmdAttr<CommandGetEx>("getex", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandStrlen>("strlen", 2, "read-only", 1, 1, 1),
Expand All @@ -637,6 +833,5 @@ REDIS_REGISTER_COMMANDS(
MakeCmdAttr<CommandIncrByFloat>("incrbyfloat", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandIncr>("incr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandDecrBy>("decrby", 3, "write", 1, 1, 1),
MakeCmdAttr<CommandDecr>("decr", 2, "write", 1, 1, 1), MakeCmdAttr<CommandCAS>("cas", -4, "write", 1, 1, 1),
MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), )

MakeCmdAttr<CommandCAD>("cad", 3, "write", 1, 1, 1), MakeCmdAttr<CommandLCS>("lcs", -3, "read-only", 1, 2, 1), )
} // namespace redis
4 changes: 0 additions & 4 deletions src/server/redis_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,6 @@

namespace redis {

const size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
const size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
const size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

Status Request::Tokenize(evbuffer *input) {
size_t pipeline_size = 0;

Expand Down
4 changes: 4 additions & 0 deletions src/server/redis_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class Server;

namespace redis {

const size_t PROTO_INLINE_MAX_SIZE = 16 * 1024L;
JoverZhang marked this conversation as resolved.
Show resolved Hide resolved
const size_t PROTO_BULK_MAX_SIZE = 512 * 1024L * 1024L;
const size_t PROTO_MULTI_MAX_SIZE = 1024 * 1024L;

using CommandTokens = std::vector<std::string>;

class Connection;
Expand Down
104 changes: 104 additions & 0 deletions tests/gocase/unit/type/strings/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -895,4 +895,108 @@ func TestString(t *testing.T) {
require.ErrorContains(t, rdb.Do(ctx, "CAD", "cad_key").Err(), "ERR wrong number of arguments")
require.ErrorContains(t, rdb.Do(ctx, "CAD", "cad_key", "123", "234").Err(), "ERR wrong number of arguments")
})

rna1 := "CACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTTCGTCCGGGTGTG"
rna2 := "ATTAAAGGTTTATACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTT"
rnalcs := "ACCTTCCCAGGTAACAAACCAACCAACTTTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTGGCTGTCACTCGGCTGCATGCTTAGTGCACTCACGCAGTATAATTAATAACTAATTACTGTCGTTGACAGGACACGAGTAACTCGTCTATCTTCTGCAGGCTGCTTACGGTTTCGTCCGTGTTGCAGCCGATCATCAGCACATCTAGGTTT"

t.Run("LCS basic", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
require.Equal(t, rnalcs, rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2"}).Val().MatchString)
})

t.Run("LCS len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
require.Equal(t, int64(len(rnalcs)), rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Len: true}).Val().Len)
})

t.Run("LCS indexes", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 238, End: 238},
Key2: redis.LCSPosition{Start: 239, End: 239},
},
{
Key1: redis.LCSPosition{Start: 236, End: 236},
Key2: redis.LCSPosition{Start: 238, End: 238},
},
{
Key1: redis.LCSPosition{Start: 229, End: 230},
Key2: redis.LCSPosition{Start: 236, End: 237},
},
{
Key1: redis.LCSPosition{Start: 224, End: 224},
Key2: redis.LCSPosition{Start: 235, End: 235},
},
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
},
}, matches)
})

t.Run("LCS indexes with match len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true, WithMatchLen: true}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 238, End: 238},
Key2: redis.LCSPosition{Start: 239, End: 239},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 236, End: 236},
Key2: redis.LCSPosition{Start: 238, End: 238},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 229, End: 230},
Key2: redis.LCSPosition{Start: 236, End: 237},
MatchLen: 2,
},
{
Key1: redis.LCSPosition{Start: 224, End: 224},
Key2: redis.LCSPosition{Start: 235, End: 235},
MatchLen: 1,
},
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
MatchLen: 222,
},
}, matches)
})

t.Run("LCS indexes with match len and minimum match len", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", rna2, 0).Err())
matches := rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true, WithMatchLen: true, MinMatchLen: 5}).Val().Matches
require.Equal(t, []redis.LCSMatchedPosition{
{
Key1: redis.LCSPosition{Start: 1, End: 222},
Key2: redis.LCSPosition{Start: 13, End: 234},
MatchLen: 222,
},
}, matches)
})
JoverZhang marked this conversation as resolved.
Show resolved Hide resolved

t.Run("LCS empty", func(t *testing.T) {
require.NoError(t, rdb.Set(ctx, "virus1", rna1, 0).Err())
require.NoError(t, rdb.Set(ctx, "virus2", "", 0).Err())

require.Equal(t, rna1, rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus1"}).Val().MatchString)
require.Equal(t, "", rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2"}).Val().MatchString)
require.Equal(t, "", rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus2", Key2: "virus1"}).Val().MatchString)
require.Equal(t, "", rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus2", Key2: "virus2"}).Val().MatchString)

require.Equal(t, int64(0), rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2"}).Val().Len)
require.Equal(t, []redis.LCSMatchedPosition{}, rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true}).Val().Matches)
require.Equal(t, []redis.LCSMatchedPosition{}, rdb.LCS(ctx, &redis.LCSQuery{Key1: "virus1", Key2: "virus2", Idx: true, WithMatchLen: true}).Val().Matches)
})
}
Loading