diff --git a/source/extensions/filters/common/lua/lua.cc b/source/extensions/filters/common/lua/lua.cc index 02a45f817ec57..c907fef9fd6a4 100644 --- a/source/extensions/filters/common/lua/lua.cc +++ b/source/extensions/filters/common/lua/lua.cc @@ -71,8 +71,8 @@ int ThreadLocalState::getGlobalRef(uint64_t slot) { } uint64_t ThreadLocalState::registerGlobal(const std::string& global) { - tls_slot_->runOnAllThreads([this, global]() { - LuaThreadLocal& tls = tls_slot_->getTyped(); + tls_slot_->runOnAllThreads([global](ThreadLocal::ThreadLocalObjectSharedPtr previous) { + LuaThreadLocal& tls = *std::dynamic_pointer_cast(previous); lua_getglobal(tls.state_.get(), global.c_str()); if (lua_isfunction(tls.state_.get(), -1)) { tls.global_slots_.push_back(luaL_ref(tls.state_.get(), LUA_REGISTRYINDEX)); @@ -81,6 +81,7 @@ uint64_t ThreadLocalState::registerGlobal(const std::string& global) { lua_pop(tls.state_.get(), 1); tls.global_slots_.push_back(LUA_REFNIL); } + return previous; }); return current_global_slot_++; diff --git a/source/extensions/filters/common/lua/lua.h b/source/extensions/filters/common/lua/lua.h index b9bb7caa157ea..7071b375303fa 100644 --- a/source/extensions/filters/common/lua/lua.h +++ b/source/extensions/filters/common/lua/lua.h @@ -386,8 +386,11 @@ class ThreadLocalState : Logger::Loggable { * all threaded workers. */ template void registerType() { - tls_slot_->runOnAllThreads( - [this]() { T::registerType(tls_slot_->getTyped().state_.get()); }); + tls_slot_->runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr previous) { + LuaThreadLocal& tls = *std::dynamic_pointer_cast(previous); + T::registerType(tls.state_.get()); + return previous; + }); } /** diff --git a/test/extensions/filters/common/lua/BUILD b/test/extensions/filters/common/lua/BUILD index b6d7bfecd6d54..88d42f01aab0b 100644 --- a/test/extensions/filters/common/lua/BUILD +++ b/test/extensions/filters/common/lua/BUILD @@ -14,6 +14,7 @@ envoy_cc_test( srcs = ["lua_test.cc"], tags = ["skip_on_windows"], deps = [ + "//source/common/thread_local:thread_local_lib", "//source/extensions/filters/common/lua:lua_lib", "//test/mocks:common_lib", "//test/mocks/thread_local:thread_local_mocks", diff --git a/test/extensions/filters/common/lua/lua_test.cc b/test/extensions/filters/common/lua/lua_test.cc index b5770a0b20d79..5f4462e7d3c4f 100644 --- a/test/extensions/filters/common/lua/lua_test.cc +++ b/test/extensions/filters/common/lua/lua_test.cc @@ -1,5 +1,7 @@ #include +#include "common/thread_local/thread_local_impl.h" + #include "extensions/filters/common/lua/lua.h" #include "test/mocks/common.h" @@ -157,6 +159,55 @@ TEST_F(LuaTest, MarkDead) { lua_gc(cr1->luaState(), LUA_GCCOLLECT, 0); } +class ThreadSafeTest : public testing::Test { +public: + ThreadSafeTest() + : api_(Api::createApiForTest()), main_dispatcher_(api_->allocateDispatcher("main")), + worker_dispatcher_(api_->allocateDispatcher("worker")) {} + + // Use real dispatchers to verify that callback functions can be executed correctly. + Api::ApiPtr api_; + Event::DispatcherPtr main_dispatcher_; + Event::DispatcherPtr worker_dispatcher_; + ThreadLocal::InstanceImpl tls_; + + std::unique_ptr state_; +}; + +// Test whether ThreadLocalState can be safely released. +TEST_F(ThreadSafeTest, StateDestructedBeforeWorkerRun) { + const std::string SCRIPT{R"EOF( + function HelloWorld() + print("Hello World!") + end + )EOF"}; + + tls_.registerThread(*main_dispatcher_, true); + EXPECT_EQ(main_dispatcher_.get(), &tls_.dispatcher()); + tls_.registerThread(*worker_dispatcher_, false); + + // Some callback functions waiting to be executed will be added to the dispatcher of the Worker + // thread. The callback functions in the main thread will be executed directly. + state_ = std::make_unique(SCRIPT, tls_); + state_->registerType(); + + main_dispatcher_->run(Event::Dispatcher::RunType::Block); + + // Destroy state_. + state_.reset(nullptr); + + // Start a new worker thread to execute the callback functions in the worker dispatcher. + Thread::ThreadPtr thread = Thread::threadFactoryForTest().createThread([this]() { + worker_dispatcher_->run(Event::Dispatcher::RunType::Block); + // Verify we have the expected dispatcher for the new worker thread. + EXPECT_EQ(worker_dispatcher_.get(), &tls_.dispatcher()); + }); + thread->join(); + + tls_.shutdownGlobalThreading(); + tls_.shutdownThread(); +} + } // namespace } // namespace Lua } // namespace Common