2
2
3
3
#include < optional>
4
4
#include < string_view>
5
+ #include < unordered_map>
5
6
6
7
#include < sol/sol.hpp>
7
8
@@ -17,7 +18,58 @@ namespace devilution {
17
18
18
19
namespace {
19
20
20
- std::optional<sol::state> luaState;
21
+ struct LuaState {
22
+ sol::state sol;
23
+ std::unordered_map<std::string, sol::bytecode> compiledScripts;
24
+ };
25
+
26
+ std::optional<LuaState> CurrentLuaState;
27
+
28
+ // A Lua function that we use to generate a `require` implementation.
29
+ constexpr std::string_view RequireGenSrc = R"(
30
+ function requireGen(loaded, loadFn)
31
+ return function(packageName)
32
+ local p = loaded[packageName]
33
+ if p == nil then
34
+ local loader = loadFn(packageName)
35
+ if type(loader) == "string" then
36
+ error(loader)
37
+ end
38
+ p = loader(packageName)
39
+ loaded[packageName] = p
40
+ end
41
+ return p
42
+ end
43
+ end
44
+ )" ;
45
+
46
+ sol::object LuaLoadScriptFromAssets (std::string_view packageName)
47
+ {
48
+ LuaState &luaState = *CurrentLuaState;
49
+ std::string path { packageName };
50
+ std::replace (path.begin (), path.end (), ' .' , ' \\ ' );
51
+ path.append (" .lua" );
52
+
53
+ auto iter = luaState.compiledScripts .find (path);
54
+ if (iter != luaState.compiledScripts .end ()) {
55
+ return luaState.sol .load (iter->second .as_string_view (), path, sol::load_mode::binary);
56
+ }
57
+
58
+ tl::expected<AssetData, std::string> assetData = LoadAsset (path);
59
+ if (!assetData.has_value ()) {
60
+ sol::stack::push (luaState.sol .lua_state (), assetData.error ());
61
+ return sol::stack_object (luaState.sol .lua_state (), -1 );
62
+ }
63
+ sol::load_result result = luaState.sol .load (std::string_view (*assetData), path, sol::load_mode::text);
64
+ if (!result.valid ()) {
65
+ sol::stack::push (luaState.sol .lua_state (),
66
+ StrCat (" Lua error when loading " , path, " : " , result.get <std::string>()));
67
+ return sol::stack_object (luaState.sol .lua_state (), -1 );
68
+ }
69
+ const sol::function fn = result;
70
+ luaState.compiledScripts [path] = fn.dump ();
71
+ return result;
72
+ }
21
73
22
74
int LuaPrint (lua_State *state)
23
75
{
@@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional)
50
102
51
103
void RunScript (std::string_view path, bool optional)
52
104
{
53
- AssetRef ref = FindAsset (path);
54
- if (!ref.ok ()) {
55
- if (!optional)
56
- app_fatal (StrCat (" Asset not found: " , path));
57
- return ;
58
- }
105
+ tl::expected<AssetData, std::string> assetData = LoadAsset (path);
59
106
60
- const size_t size = ref.size ();
61
- std::unique_ptr<char []> luaScript { new char [size] };
62
-
63
- AssetHandle handle = OpenAsset (std::move (ref));
64
- if (!handle.ok ()) {
65
- app_fatal (StrCat (" Failed to open asset: " , path, " \n " , handle.error ()));
66
- return ;
67
- }
68
-
69
- if (size > 0 && !handle.read (luaScript.get (), size)) {
70
- app_fatal (StrCat (" Read failed: " , path, " \n " , handle.error ()));
107
+ if (!assetData.has_value ()) {
108
+ if (!optional)
109
+ app_fatal (assetData.error ());
71
110
return ;
72
111
}
73
112
74
- const std::string_view luaScriptStr (luaScript.get (), size);
75
- CheckResult (luaState->safe_script (luaScriptStr), optional);
113
+ CheckResult (CurrentLuaState->sol .safe_script (std::string_view (*assetData)), optional);
76
114
}
77
115
78
116
void LuaPanic (sol::optional<std::string> message)
@@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)
95
133
96
134
void LuaInitialize ()
97
135
{
98
- luaState.emplace (sol::c_call<decltype (&LuaPanic), &LuaPanic>);
99
- sol::state &lua = *luaState;
136
+ CurrentLuaState.emplace (LuaState {
137
+ .sol = { sol::c_call<decltype (&LuaPanic), &LuaPanic> },
138
+ .compiledScripts = {},
139
+ });
140
+ sol::state &lua = CurrentLuaState->sol ;
100
141
lua.open_libraries (
101
142
sol::lib::base,
102
143
sol::lib::package,
@@ -116,11 +157,12 @@ void LuaInitialize()
116
157
" _VERSION" , LUA_VERSION);
117
158
118
159
// Registering devilutionx object table
119
- lua.create_named_table (
120
- " devilutionx" ,
121
- " log" , LuaLogModule (lua),
122
- " render" , LuaRenderModule (lua),
123
- " message" , [](std::string_view text) { EventPlrMsg (text, UiFlags::ColorRed); });
160
+ CheckResult (lua.safe_script (RequireGenSrc), /* optional=*/ false );
161
+ const sol::table loaded = lua.create_table_with (
162
+ " devilutionx.log" , LuaLogModule (lua),
163
+ " devilutionx.render" , LuaRenderModule (lua),
164
+ " devilutionx.message" , [](std::string_view text) { EventPlrMsg (text, UiFlags::ColorRed); });
165
+ lua[" require" ] = lua[" requireGen" ](loaded, LuaLoadScriptFromAssets);
124
166
125
167
RunScript (" lua\\ init.lua" , /* optional=*/ false );
126
168
RunScript (" lua\\ user.lua" , /* optional=*/ true );
@@ -130,12 +172,12 @@ void LuaInitialize()
130
172
131
173
void LuaShutdown ()
132
174
{
133
- luaState = std::nullopt;
175
+ CurrentLuaState = std::nullopt;
134
176
}
135
177
136
178
void LuaEvent (std::string_view name)
137
179
{
138
- const sol::state &lua = *luaState ;
180
+ const sol::state &lua = CurrentLuaState-> sol ;
139
181
const auto trigger = lua.traverse_get <std::optional<sol::object>>(" Events" , name, " Trigger" );
140
182
if (!trigger.has_value () || !trigger->is <sol::protected_function>()) {
141
183
LogError (" Events.{}.Trigger is not a function" , name);
@@ -145,9 +187,9 @@ void LuaEvent(std::string_view name)
145
187
CheckResult (fn (), /* optional=*/ true );
146
188
}
147
189
148
- sol::state &LuaState ()
190
+ sol::state &GetLuaState ()
149
191
{
150
- return *luaState ;
192
+ return CurrentLuaState-> sol ;
151
193
}
152
194
153
195
} // namespace devilution
0 commit comments