Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions source/extensions/filters/common/lua/lua.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ int ThreadLocalState::getGlobalRef(uint64_t slot) {
}

uint64_t ThreadLocalState::registerGlobal(const std::string& global) {
tls_slot_->runOnAllThreads([this, global]() {
LuaThreadLocal& tls = tls_slot_->getTyped<LuaThreadLocal>();
tls_slot_->runOnAllThreads([global](ThreadLocal::ThreadLocalObjectSharedPtr previous) {
LuaThreadLocal& tls = *std::dynamic_pointer_cast<LuaThreadLocal>(previous);
lua_getglobal(tls.state_.get(), global.c_str());
if (lua_isfunction(tls.state_.get(), -1)) {
tls.global_slots_.push_back(luaL_ref(tls.state_.get(), LUA_REGISTRYINDEX));
Expand All @@ -81,6 +81,7 @@ uint64_t ThreadLocalState::registerGlobal(const std::string& global) {
lua_pop(tls.state_.get(), 1);
tls.global_slots_.push_back(LUA_REFNIL);
}
return previous;
});

return current_global_slot_++;
Expand Down
7 changes: 5 additions & 2 deletions source/extensions/filters/common/lua/lua.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,11 @@ class ThreadLocalState : Logger::Loggable<Logger::Id::lua> {
* all threaded workers.
*/
template <class T> void registerType() {
tls_slot_->runOnAllThreads(
[this]() { T::registerType(tls_slot_->getTyped<LuaThreadLocal>().state_.get()); });
tls_slot_->runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr previous) {
LuaThreadLocal& tls = *std::dynamic_pointer_cast<LuaThreadLocal>(previous);
T::registerType(tls.state_.get());
return previous;
});
}

/**
Expand Down
1 change: 1 addition & 0 deletions test/extensions/filters/common/lua/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ envoy_cc_test(
srcs = ["lua_test.cc"],
tags = ["skip_on_windows"],
deps = [
"//source/common/thread_local:thread_local_lib",
"//source/extensions/filters/common/lua:lua_lib",
"//test/mocks:common_lib",
"//test/mocks/thread_local:thread_local_mocks",
Expand Down
51 changes: 51 additions & 0 deletions test/extensions/filters/common/lua/lua_test.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include <memory>

#include "common/thread_local/thread_local_impl.h"

#include "extensions/filters/common/lua/lua.h"

#include "test/mocks/common.h"
Expand Down Expand Up @@ -157,6 +159,55 @@ TEST_F(LuaTest, MarkDead) {
lua_gc(cr1->luaState(), LUA_GCCOLLECT, 0);
}

class ThreadSafeTest : public testing::Test {
public:
ThreadSafeTest()
: api_(Api::createApiForTest()), main_dispatcher_(api_->allocateDispatcher("main")),
worker_dispatcher_(api_->allocateDispatcher("worker")) {}

// Use real dispatchers to verify that callback functions can be executed correctly.
Api::ApiPtr api_;
Event::DispatcherPtr main_dispatcher_;
Event::DispatcherPtr worker_dispatcher_;
ThreadLocal::InstanceImpl tls_;

std::unique_ptr<ThreadLocalState> state_;
};

// Test whether ThreadLocalState can be safely released.
TEST_F(ThreadSafeTest, StateDestructedBeforeWorkerRun) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this test fail without the code change in this PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it will be a crash.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for verifying!

const std::string SCRIPT{R"EOF(
function HelloWorld()
print("Hello World!")
end
)EOF"};

tls_.registerThread(*main_dispatcher_, true);
EXPECT_EQ(main_dispatcher_.get(), &tls_.dispatcher());
tls_.registerThread(*worker_dispatcher_, false);

// Some callback functions waiting to be executed will be added to the dispatcher of the Worker
// thread. The callback functions in the main thread will be executed directly.
state_ = std::make_unique<ThreadLocalState>(SCRIPT, tls_);
state_->registerType<TestObject>();

main_dispatcher_->run(Event::Dispatcher::RunType::Block);

// Destroy state_.
state_.reset(nullptr);

// Start a new worker thread to execute the callback functions in the worker dispatcher.
Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([this]() {
worker_dispatcher_->run(Event::Dispatcher::RunType::Block);
// Verify we have the expected dispatcher for the new worker thread.
EXPECT_EQ(worker_dispatcher_.get(), &tls_.dispatcher());
});
thread->join();

tls_.shutdownGlobalThreading();
tls_.shutdownThread();
}

} // namespace
} // namespace Lua
} // namespace Common
Expand Down