Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/commands/cmd_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct CommandFunction : Commander {
}

std::string libname;
auto s = lua::FunctionLoad(conn, GET_OR_RET(parser.TakeStr()), true, replace, &libname);
auto s = lua::FunctionLoad(conn, &ctx, GET_OR_RET(parser.TakeStr()), true, replace, &libname);
if (!s) return s;

*output = SimpleString(libname);
Expand All @@ -55,21 +55,21 @@ struct CommandFunction : Commander {
with_code = true;
}

return lua::FunctionList(srv, conn, libname, with_code, output);
return lua::FunctionList(srv, conn, ctx, libname, with_code, output);
} else if (parser.EatEqICase("listfunc")) {
std::string funcname;
if (parser.EatEqICase("funcname")) {
funcname = GET_OR_RET(parser.TakeStr());
}

return lua::FunctionListFunc(srv, conn, funcname, output);
return lua::FunctionListFunc(srv, conn, ctx, funcname, output);
} else if (parser.EatEqICase("listlib")) {
auto libname = GET_OR_RET(parser.TakeStr().Prefixed("expect a library name"));

return lua::FunctionListLib(conn, libname, output);
} else if (parser.EatEqICase("delete")) {
auto libname = GET_OR_RET(parser.TakeStr());
if (!lua::FunctionIsLibExist(conn, libname)) {
if (!lua::FunctionIsLibExist(conn, &ctx, libname)) {
return {Status::NotOK, "no such library"};
}
auto s = lua::FunctionDelete(ctx, conn, libname);
Expand All @@ -94,7 +94,8 @@ struct CommandFCall : Commander {
return {Status::NotOK, "Number of keys can't be negative"};
}

return lua::FunctionCall(conn, args_[1], std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
return lua::FunctionCall(conn, &ctx, args_[1],
std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
std::vector<std::string>(args_.begin() + 3 + numkeys, args_.end()), output, read_only);
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/commands/cmd_script.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class CommandEvalImpl : public Commander {
}

return lua::EvalGenericCommand(
conn, args_[1], std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
conn, &ctx, args_[1], std::vector<std::string>(args_.begin() + 3, args_.begin() + 3 + numkeys),
std::vector<std::string>(args_.begin() + 3 + numkeys, args_.end()), evalsha, output, read_only);
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/server/redis_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ static bool IsCmdAllowedInStaleData(const std::string &cmd_name) {
void Connection::ExecuteCommands(std::deque<CommandTokens> *to_process_cmds) {
const Config *config = srv_->GetConfig();
std::string reply;
std::string password = config->requirepass;
const std::string &password = config->requirepass;

while (!to_process_cmds->empty()) {
CommandTokens cmd_tokens = std::move(to_process_cmds->front());
Expand Down
52 changes: 26 additions & 26 deletions src/storage/scripting.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ int RedisRegisterFunction(lua_State *lua) {
return 0;
}

Status FunctionLoad(redis::Connection *conn, const std::string &script, bool need_to_store, bool replace,
[[maybe_unused]] std::string *lib_name, bool read_only) {
Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const std::string &script, bool need_to_store,
bool replace, [[maybe_unused]] std::string *lib_name, bool read_only) {
std::string first_line, lua_code;
if (auto pos = script.find('\n'); pos != std::string::npos) {
first_line = script.substr(0, pos);
Expand All @@ -296,17 +296,17 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();

if (FunctionIsLibExist(conn, libname, need_to_store, read_only)) {
if (FunctionIsLibExist(conn, ctx, libname, need_to_store, read_only)) {
if (!replace) {
return {Status::NotOK, "library already exists, please specify REPLACE to force load"};
}
engine::Context ctx(srv->storage);
auto s = FunctionDelete(ctx, conn, libname);
auto s = FunctionDelete(*ctx, conn, libname);
if (!s) return s;
}

ScriptRunCtx script_run_ctx;
script_run_ctx.conn = conn;
script_run_ctx.ctx = ctx;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;

SaveOnRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME, &script_run_ctx);
Expand Down Expand Up @@ -339,14 +339,15 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee

RemoveFromRegistry(lua, REGISTRY_SCRIPT_RUN_CTX_NAME);

if (!FunctionIsLibExist(conn, libname, false, read_only)) {
if (!FunctionIsLibExist(conn, ctx, libname, false, read_only)) {
return {Status::NotOK, "Please register some function in FUNCTION LOAD"};
}

return need_to_store ? srv->FunctionSetCode(libname, script) : Status::OK();
}

bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, bool need_check_storage, bool read_only) {
bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const std::string &libname,
bool need_check_storage, bool read_only) {
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();

Expand All @@ -373,14 +374,15 @@ bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, boo
if (!s) return false;

std::string lib_name;
s = FunctionLoad(conn, code, false, false, &lib_name, read_only);
s = FunctionLoad(conn, ctx, code, false, false, &lib_name, read_only);
return static_cast<bool>(s);
}

// FunctionCall will firstly find the function in the lua runtime,
// if it is not found, it will try to load the library where the function is located from storage
Status FunctionCall(redis::Connection *conn, const std::string &name, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, std::string *output, bool read_only) {
Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const std::string &name,
const std::vector<std::string> &keys, const std::vector<std::string> &argv, std::string *output,
bool read_only) {
auto srv = conn->GetServer();
auto lua = conn->Owner()->Lua();

Expand All @@ -397,14 +399,15 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std:
std::string libcode;
s = srv->FunctionGetCode(libname, &libcode);
if (!s) return s;
s = FunctionLoad(conn, libcode, false, false, &libname, read_only);
s = FunctionLoad(conn, ctx, libcode, false, false, &libname, read_only);
if (!s) return s;

lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_PREFIX + name).c_str());
}

ScriptRunCtx script_run_ctx;
script_run_ctx.conn = conn;
script_run_ctx.ctx = ctx;
script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, (REDIS_LUA_REGISTER_FUNC_FLAGS_PREFIX + name).c_str());
if (!lua_isnil(lua, -1)) {
Expand Down Expand Up @@ -447,13 +450,12 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std:
}

// list all library names and their code (enabled via `with_code`)
Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code,
std::string *output) {
Status FunctionList(Server *srv, const redis::Connection *conn, engine::Context &ctx, const std::string &libname,
bool with_code, std::string *output) {
std::string start_key = engine::kLuaLibCodePrefix + libname;
std::string end_key = start_key;
end_key.back()++;

engine::Context ctx(srv->storage);
rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
rocksdb::Slice upper_bound(end_key);
read_options.iterate_upper_bound = &upper_bound;
Expand Down Expand Up @@ -487,12 +489,12 @@ Status FunctionList(Server *srv, const redis::Connection *conn, const std::strin

// extension to Redis Function
// list all function names and their corresponding library names
Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output) {
Status FunctionListFunc(Server *srv, const redis::Connection *conn, engine::Context &ctx, const std::string &funcname,
std::string *output) {
std::string start_key = engine::kLuaFuncLibPrefix + funcname;
std::string end_key = start_key;
end_key.back()++;

engine::Context ctx(srv->storage);
rocksdb::ReadOptions read_options = ctx.DefaultScanOptions();
rocksdb::Slice upper_bound(end_key);
read_options.iterate_upper_bound = &upper_bound;
Expand Down Expand Up @@ -603,8 +605,9 @@ Status FunctionDelete(engine::Context &ctx, redis::Connection *conn, const std::
return Status::OK();
}

Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sha, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, bool evalsha, std::string *output, bool read_only) {
Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const std::string &body_or_sha,
const std::vector<std::string> &keys, const std::vector<std::string> &argv, bool evalsha,
std::string *output, bool read_only) {
Server *srv = conn->GetServer();
// Use the worker's private Lua VM when entering the read-only mode
lua_State *lua = conn->Owner()->Lua();
Expand Down Expand Up @@ -652,6 +655,7 @@ Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sh

ScriptRunCtx current_script_run_ctx;
current_script_run_ctx.conn = conn;
current_script_run_ctx.ctx = ctx;
current_script_run_ctx.flags = read_only ? ScriptFlagType::kScriptNoWrites : 0;
lua_getglobal(lua, fmt::format(REDIS_LUA_FUNC_SHA_FLAGS, funcname + 2).c_str());
if (!lua_isnil(lua, -1)) {
Expand Down Expand Up @@ -820,14 +824,10 @@ int RedisGenericCommand(lua_State *lua, int raise_error) {
}

std::string output;
// TODO: make it possible for multiple redis commands in lua script to use the same txn context.
{
engine::Context ctx(srv->storage);
s = conn->ExecuteCommand(ctx, cmd_name, args, cmd.get(), &output);
if (!s) {
PushError(lua, s.Msg().data());
return raise_error ? RaiseError(lua) : 1;
}
s = conn->ExecuteCommand(*script_run_ctx->ctx, cmd_name, args, cmd.get(), &output);
if (!s) {
PushError(lua, s.Msg().data());
return raise_error ? RaiseError(lua) : 1;
}

srv->FeedMonitorConns(conn, args);
Expand Down
29 changes: 17 additions & 12 deletions src/storage/scripting.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "lua.hpp"
#include "server/redis_connection.h"
#include "status.h"
#include "storage/storage.h"

namespace engine {
struct Context;
Expand Down Expand Up @@ -62,23 +63,25 @@ int RedisSetResp(lua_State *lua);

Status CreateFunction(Server *srv, const std::string &body, std::string *sha, lua_State *lua, bool need_to_store);

Status EvalGenericCommand(redis::Connection *conn, const std::string &body_or_sha, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, bool evalsha, std::string *output,
bool read_only = false);
Status EvalGenericCommand(redis::Connection *conn, engine::Context *ctx, const std::string &body_or_sha,
const std::vector<std::string> &keys, const std::vector<std::string> &argv, bool evalsha,
std::string *output, bool read_only = false);

bool ScriptExists(lua_State *lua, const std::string &sha);

Status FunctionLoad(redis::Connection *conn, const std::string &script, bool need_to_store, bool replace,
std::string *lib_name, bool read_only = false);
Status FunctionCall(redis::Connection *conn, const std::string &name, const std::vector<std::string> &keys,
const std::vector<std::string> &argv, std::string *output, bool read_only = false);
Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code,
std::string *output);
Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output);
Status FunctionLoad(redis::Connection *conn, engine::Context *ctx, const std::string &script, bool need_to_store,
bool replace, std::string *lib_name, bool read_only = false);
Status FunctionCall(redis::Connection *conn, engine::Context *ctx, const std::string &name,
const std::vector<std::string> &keys, const std::vector<std::string> &argv, std::string *output,
bool read_only = false);
Status FunctionList(Server *srv, const redis::Connection *conn, engine::Context &ctx, const std::string &libname,
bool with_code, std::string *output);
Status FunctionListFunc(Server *srv, const redis::Connection *conn, engine::Context &ctx, const std::string &funcname,
std::string *output);
Status FunctionListLib(redis::Connection *conn, const std::string &libname, std::string *output);
Status FunctionDelete(engine::Context &ctx, redis::Connection *conn, const std::string &name);
bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, bool need_check_storage = true,
bool read_only = false);
bool FunctionIsLibExist(redis::Connection *conn, engine::Context *ctx, const std::string &libname,
bool need_check_storage = true, bool read_only = false);

const char *RedisProtocolToLuaType(lua_State *lua, const char *reply);
const char *RedisProtocolToLuaTypeInt(lua_State *lua, const char *reply);
Expand Down Expand Up @@ -150,6 +153,8 @@ struct ScriptRunCtx {
int current_slot = -1;
// the current connection
redis::Connection *conn = nullptr;
// the storage context
engine::Context *ctx = nullptr;
};

/// SaveOnRegistry saves user-defined data to lua REGISTRY
Expand Down
Loading