Skip to content

Commit 5d9d5c6

Browse files
committed
Lua: require supports loads from assets
Implements a `require` function that supports built-in modules like so: ```lua local log = require('devilutionx.log') ``` It falls back to reading from assets, so this loads `lua/user.lua`: ```lua local user = require('lua.user') ``` The bytecode for the asset scripts is cached, in case we want to later support multiple isolated environments. There may be a simpler or better way to do this. It's good enough for now until someone more knowledgeable about Lua comes along.
1 parent 026907e commit 5d9d5c6

File tree

5 files changed

+116
-34
lines changed

5 files changed

+116
-34
lines changed

Source/engine/assets.cpp

+27
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ bool FindMpqFile(std::string_view filename, MpqArchive **archive, uint32_t *file
7575
AssetRef FindAsset(std::string_view filename)
7676
{
7777
AssetRef result;
78+
if (filename.empty() || filename.back() == '\\')
79+
return result;
7880
result.path[0] = '\0';
7981

8082
char pathBuf[AssetRef::PathBufSize];
@@ -113,6 +115,9 @@ AssetRef FindAsset(std::string_view filename)
113115
AssetRef FindAsset(std::string_view filename)
114116
{
115117
AssetRef result;
118+
if (filename.empty() || filename.back() == '\\')
119+
return result;
120+
116121
std::string relativePath { filename };
117122
#ifndef _WIN32
118123
std::replace(relativePath.begin(), relativePath.end(), '\\', '/');
@@ -206,4 +211,26 @@ SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe)
206211
#endif
207212
}
208213

214+
tl::expected<AssetData, std::string> LoadAsset(std::string_view path)
215+
{
216+
AssetRef ref = FindAsset(path);
217+
if (!ref.ok()) {
218+
return tl::make_unexpected(StrCat("Asset not found: ", path));
219+
}
220+
221+
const size_t size = ref.size();
222+
std::unique_ptr<char[]> data { new char[size] };
223+
224+
AssetHandle handle = OpenAsset(std::move(ref));
225+
if (!handle.ok()) {
226+
return tl::make_unexpected(StrCat("Failed to open asset: ", path, "\n", handle.error()));
227+
}
228+
229+
if (size > 0 && !handle.read(data.get(), size)) {
230+
return tl::make_unexpected(StrCat("Read failed: ", path, "\n", handle.error()));
231+
}
232+
233+
return AssetData { std::move(data), size };
234+
}
235+
209236
} // namespace devilution

Source/engine/assets.hpp

+13
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <string_view>
88

99
#include <SDL.h>
10+
#include <expected.hpp>
1011

1112
#include "appfat.h"
1213
#include "diablo.h"
@@ -246,4 +247,16 @@ AssetHandle OpenAsset(std::string_view filename, size_t &fileSize, bool threadsa
246247

247248
SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe = false);
248249

250+
struct AssetData {
251+
std::unique_ptr<char[]> data;
252+
size_t size;
253+
254+
explicit operator std::string_view() const
255+
{
256+
return std::string_view(data.get(), size);
257+
}
258+
};
259+
260+
tl::expected<AssetData, std::string> LoadAsset(std::string_view path);
261+
249262
} // namespace devilution

Source/lua/lua.cpp

+73-31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <optional>
44
#include <string_view>
5+
#include <unordered_map>
56

67
#include <sol/sol.hpp>
78

@@ -17,7 +18,58 @@ namespace devilution {
1718

1819
namespace {
1920

20-
std::optional<sol::state> luaState;
21+
struct LuaState {
22+
sol::state sol;
23+
std::unordered_map<std::string, sol::bytecode> compiledScripts;
24+
};
25+
26+
std::optional<LuaState> CurrentLuaState;
27+
28+
// A Lua function that we use to generate a `require` implementation.
29+
constexpr std::string_view RequireGenSrc = R"(
30+
function requireGen(loaded, loadFn)
31+
return function(packageName)
32+
local p = loaded[packageName]
33+
if p == nil then
34+
local loader = loadFn(packageName)
35+
if type(loader) == "string" then
36+
error(loader)
37+
end
38+
p = loader(packageName)
39+
loaded[packageName] = p
40+
end
41+
return p
42+
end
43+
end
44+
)";
45+
46+
sol::object LuaLoadScriptFromAssets(std::string_view packageName)
47+
{
48+
LuaState &luaState = *CurrentLuaState;
49+
std::string path { packageName };
50+
std::replace(path.begin(), path.end(), '.', '\\');
51+
path.append(".lua");
52+
53+
auto iter = luaState.compiledScripts.find(path);
54+
if (iter != luaState.compiledScripts.end()) {
55+
return luaState.sol.load(iter->second.as_string_view(), path, sol::load_mode::binary);
56+
}
57+
58+
tl::expected<AssetData, std::string> assetData = LoadAsset(path);
59+
if (!assetData.has_value()) {
60+
sol::stack::push(luaState.sol.lua_state(), assetData.error());
61+
return sol::stack_object(luaState.sol.lua_state(), -1);
62+
}
63+
sol::load_result result = luaState.sol.load(std::string_view(*assetData), path, sol::load_mode::text);
64+
if (!result.valid()) {
65+
sol::stack::push(luaState.sol.lua_state(),
66+
StrCat("Lua error when loading ", path, ": ", result.get<std::string>()));
67+
return sol::stack_object(luaState.sol.lua_state(), -1);
68+
}
69+
const sol::function fn = result;
70+
luaState.compiledScripts[path] = fn.dump();
71+
return result;
72+
}
2173

2274
int LuaPrint(lua_State *state)
2375
{
@@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional)
50102

51103
void RunScript(std::string_view path, bool optional)
52104
{
53-
AssetRef ref = FindAsset(path);
54-
if (!ref.ok()) {
55-
if (!optional)
56-
app_fatal(StrCat("Asset not found: ", path));
57-
return;
58-
}
105+
tl::expected<AssetData, std::string> assetData = LoadAsset(path);
59106

60-
const size_t size = ref.size();
61-
std::unique_ptr<char[]> luaScript { new char[size] };
62-
63-
AssetHandle handle = OpenAsset(std::move(ref));
64-
if (!handle.ok()) {
65-
app_fatal(StrCat("Failed to open asset: ", path, "\n", handle.error()));
66-
return;
67-
}
68-
69-
if (size > 0 && !handle.read(luaScript.get(), size)) {
70-
app_fatal(StrCat("Read failed: ", path, "\n", handle.error()));
107+
if (!assetData.has_value()) {
108+
if (!optional)
109+
app_fatal(assetData.error());
71110
return;
72111
}
73112

74-
const std::string_view luaScriptStr(luaScript.get(), size);
75-
CheckResult(luaState->safe_script(luaScriptStr), optional);
113+
CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional);
76114
}
77115

78116
void LuaPanic(sol::optional<std::string> message)
@@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)
95133

96134
void LuaInitialize()
97135
{
98-
luaState.emplace(sol::c_call<decltype(&LuaPanic), &LuaPanic>);
99-
sol::state &lua = *luaState;
136+
CurrentLuaState.emplace(LuaState {
137+
.sol = { sol::c_call<decltype(&LuaPanic), &LuaPanic> },
138+
.compiledScripts = {},
139+
});
140+
sol::state &lua = CurrentLuaState->sol;
100141
lua.open_libraries(
101142
sol::lib::base,
102143
sol::lib::package,
@@ -116,11 +157,12 @@ void LuaInitialize()
116157
"_VERSION", LUA_VERSION);
117158

118159
// Registering devilutionx object table
119-
lua.create_named_table(
120-
"devilutionx",
121-
"log", LuaLogModule(lua),
122-
"render", LuaRenderModule(lua),
123-
"message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
160+
CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false);
161+
const sol::table loaded = lua.create_table_with(
162+
"devilutionx.log", LuaLogModule(lua),
163+
"devilutionx.render", LuaRenderModule(lua),
164+
"devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
165+
lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets);
124166

125167
RunScript("lua\\init.lua", /*optional=*/false);
126168
RunScript("lua\\user.lua", /*optional=*/true);
@@ -130,12 +172,12 @@ void LuaInitialize()
130172

131173
void LuaShutdown()
132174
{
133-
luaState = std::nullopt;
175+
CurrentLuaState = std::nullopt;
134176
}
135177

136178
void LuaEvent(std::string_view name)
137179
{
138-
const sol::state &lua = *luaState;
180+
const sol::state &lua = CurrentLuaState->sol;
139181
const auto trigger = lua.traverse_get<std::optional<sol::object>>("Events", name, "Trigger");
140182
if (!trigger.has_value() || !trigger->is<sol::protected_function>()) {
141183
LogError("Events.{}.Trigger is not a function", name);
@@ -145,9 +187,9 @@ void LuaEvent(std::string_view name)
145187
CheckResult(fn(), /*optional=*/true);
146188
}
147189

148-
sol::state &LuaState()
190+
sol::state &GetLuaState()
149191
{
150-
return *luaState;
192+
return CurrentLuaState->sol;
151193
}
152194

153195
} // namespace devilution

Source/lua/lua.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ namespace devilution {
1313
void LuaInitialize();
1414
void LuaShutdown();
1515
void LuaEvent(std::string_view name);
16-
sol::state &LuaState();
16+
sol::state &GetLuaState();
1717

1818
} // namespace devilution

Source/lua/repl.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ int LuaPrintToConsole(lua_State *state)
3838

3939
void CreateReplEnvironment()
4040
{
41-
sol::state &lua = LuaState();
41+
sol::state &lua = GetLuaState();
4242
replEnv.emplace(lua, sol::create, lua.globals());
4343
replEnv->set("print", LuaPrintToConsole);
4444
}
@@ -53,7 +53,7 @@ sol::environment &ReplEnvironment()
5353
sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code)
5454
{
5555
// Try to compile as an expression first. This also how the `lua` repl is implemented.
56-
sol::state &lua = LuaState();
56+
sol::state &lua = GetLuaState();
5757
std::string expression = StrCat("return ", code, ";");
5858
sol::detail::typical_chunk_name_t basechunkname = {};
5959
sol::load_status status = static_cast<sol::load_status>(

0 commit comments

Comments
 (0)