Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/envoy/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,16 @@ class Loader {
/**
* @return const Snapshot& the current snapshot. This reference is safe to use for the duration of
* the calling routine, but may be overwritten on a future event loop cycle so should be
* fetched again when needed.
* fetched again when needed. This may only be called from worker threads.
*/
virtual const Snapshot& snapshot() PURE;

/**
* @return shared_ptr<const Snapshot> the current snapshot. This function may safely be called
* from non-worker theads.
*/
virtual std::shared_ptr<const Snapshot> threadsafeSnapshot() PURE;

/**
* Merge the given map of key-value pairs into the runtime's state. To remove a previous merge for
* a key, use an empty string as the value.
Expand Down
10 changes: 10 additions & 0 deletions include/envoy/thread_local/thread_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ class Slot {
public:
virtual ~Slot() = default;

/**
* Returns if there is thread local data for this thread.
*
* This should return true for Envoy worker threads and false for threads which do not have thread
* local storage allocated.
*
* @return true if registerThread has been called for this thread, false otherwise.
*/
virtual bool currentThreadRegistered() PURE;

/**
* @return ThreadLocalObjectSharedPtr a thread local object stored in the slot.
*/
Expand Down
30 changes: 25 additions & 5 deletions source/common/runtime/runtime_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ namespace Runtime {
bool runtimeFeatureEnabled(absl::string_view feature) {
ASSERT(absl::StartsWith(feature, "envoy.reloadable_features"));
if (Runtime::LoaderSingleton::getExisting()) {
return Runtime::LoaderSingleton::getExisting()->snapshot().runtimeFeatureEnabled(feature);
return Runtime::LoaderSingleton::getExisting()->threadsafeSnapshot()->runtimeFeatureEnabled(
feature);
}
ENVOY_LOG_TO_LOGGER(Envoy::Logger::Registry::getLog(Envoy::Logger::Id::runtime), warn,
"Unable to use runtime singleton for feature {}", feature);
Expand Down Expand Up @@ -551,13 +552,32 @@ void RtdsSubscription::validateUpdateSize(uint32_t num_resources) {
}

void LoaderImpl::loadNewSnapshot() {
ThreadLocal::ThreadLocalObjectSharedPtr ptr = createNewSnapshot();
tls_->set([ptr = std::move(ptr)](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return ptr;
std::shared_ptr<SnapshotImpl> ptr = createNewSnapshot();
tls_->set([ptr](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::static_pointer_cast<ThreadLocal::ThreadLocalObject>(ptr);
});

{
absl::MutexLock lock(&snapshot_mutex_);
thread_safe_snapshot_ = ptr;
}
}

const Snapshot& LoaderImpl::snapshot() {
ASSERT(tls_->currentThreadRegistered(), "snapshot can only be called from a worker thread");
return tls_->getTyped<Snapshot>();
}

const Snapshot& LoaderImpl::snapshot() { return tls_->getTyped<Snapshot>(); }
std::shared_ptr<const Snapshot> LoaderImpl::threadsafeSnapshot() {
if (tls_->currentThreadRegistered()) {
return std::dynamic_pointer_cast<const Snapshot>(tls_->get());
}

{
absl::ReaderMutexLock lock(&snapshot_mutex_);
return thread_safe_snapshot_;
}
}

void LoaderImpl::mergeValues(const std::unordered_map<std::string, std::string>& values) {
if (admin_layer_ == nullptr) {
Expand Down
4 changes: 4 additions & 0 deletions source/common/runtime/runtime_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class LoaderImpl : public Loader, Logger::Loggable<Logger::Id::runtime> {
// Runtime::Loader
void initialize(Upstream::ClusterManager& cm) override;
const Snapshot& snapshot() override;
std::shared_ptr<const Snapshot> threadsafeSnapshot() override;
void mergeValues(const std::unordered_map<std::string, std::string>& values) override;

private:
Expand All @@ -265,6 +266,9 @@ class LoaderImpl : public Loader, Logger::Loggable<Logger::Id::runtime> {
Api::Api& api_;
std::vector<RtdsSubscriptionPtr> subscriptions_;
Upstream::ClusterManager* cm_{};

absl::Mutex snapshot_mutex_;
std::shared_ptr<const Snapshot> thread_safe_snapshot_ GUARDED_BY(snapshot_mutex_);
};

} // namespace Runtime
Expand Down
6 changes: 5 additions & 1 deletion source/common/thread_local/thread_local_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ SlotPtr InstanceImpl::allocateSlot() {
return slot;
}

bool InstanceImpl::SlotImpl::currentThreadRegistered() {
return thread_local_data_.data_.size() > index_;
}

ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() {
ASSERT(thread_local_data_.data_.size() > index_);
ASSERT(currentThreadRegistered());
return thread_local_data_.data_[index_];
}

Expand Down
1 change: 1 addition & 0 deletions source/common/thread_local/thread_local_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public Instance {

// ThreadLocal::Slot
ThreadLocalObjectSharedPtr get() override;
bool currentThreadRegistered() override;
void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); }
void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override {
parent_.runOnAllThreads(cb, main_callback);
Expand Down
40 changes: 40 additions & 0 deletions test/common/runtime/runtime_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,46 @@ TEST_F(StaticLoaderImplTest, ProtoParsing) {
EXPECT_EQ(2, store_.gauge("runtime.num_layers", Stats::Gauge::ImportMode::NeverImport).value());
}

TEST_F(StaticLoaderImplTest, RuntimeFromNonWorkerThreads) {
// Force the thread to be considered a non-worker thread.
tls_.registered_ = false;
setup();

// Set up foo -> bar
loader_->mergeValues({{"foo", "bar"}});
EXPECT_EQ("bar", loader_->threadsafeSnapshot()->get("foo"));
const Snapshot* original_snapshot_pointer = loader_->threadsafeSnapshot().get();

// Now set up a test thread which verifies foo -> bar
//
// Then change foo and make sure the test thread picks up the change.
Thread::MutexBasicLockable mutex;
Thread::CondVar foo_read;
Thread::CondVar foo_changed;
const Snapshot* original_thread_snapshot_pointer = nullptr;
auto thread = Thread::threadFactoryForTest().createThread([&]() {
Thread::LockGuard lock(mutex);
EXPECT_EQ("bar", loader_->threadsafeSnapshot()->get("foo"));
original_thread_snapshot_pointer = loader_->threadsafeSnapshot().get();
EXPECT_EQ(original_thread_snapshot_pointer, loader_->threadsafeSnapshot().get());
foo_read.notifyOne();

foo_changed.wait(mutex);
EXPECT_EQ("eep", loader_->threadsafeSnapshot()->get("foo"));
});

{
Thread::LockGuard lock(mutex);
foo_read.wait(mutex);
loader_->mergeValues({{"foo", "eep"}});
foo_changed.notifyOne();
EXPECT_EQ("eep", loader_->threadsafeSnapshot()->get("foo"));
}

thread->join();
EXPECT_EQ(original_thread_snapshot_pointer, original_snapshot_pointer);
}

class DiskLayerTest : public testing::Test {
protected:
DiskLayerTest() : api_(Api::createApiForTest()) {}
Expand Down
1 change: 1 addition & 0 deletions test/mocks/runtime/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class MockLoader : public Loader {

MOCK_METHOD1(initialize, void(Upstream::ClusterManager& cm));
MOCK_METHOD0(snapshot, const Snapshot&());
MOCK_METHOD0(threadsafeSnapshot, std::shared_ptr<const Snapshot>());
MOCK_METHOD1(mergeValues, void(const std::unordered_map<std::string, std::string>&));

testing::NiceMock<MockSnapshot> snapshot_;
Expand Down
2 changes: 2 additions & 0 deletions test/mocks/thread_local/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class MockInstance : public Instance {

// ThreadLocal::Slot
ThreadLocalObjectSharedPtr get() override { return parent_.data_[index_]; }
bool currentThreadRegistered() override { return parent_.registered_; }
void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); }
void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override {
parent_.runOnAllThreads(cb, main_callback);
Expand All @@ -72,6 +73,7 @@ class MockInstance : public Instance {
testing::NiceMock<Event::MockDispatcher> dispatcher_;
std::vector<ThreadLocalObjectSharedPtr> data_;
bool shutdown_{};
bool registered_{true};
};

} // namespace ThreadLocal
Expand Down