Skip to content

Commit

Permalink
[ntcore] Queue current value on subscriber creation (wpilibsuite#4938)
Browse files Browse the repository at this point in the history
This fixes a potential race condition in code that only uses readQueue.
  • Loading branch information
PeterJohnson authored and Starlight220 committed Mar 2, 2023
1 parent b31c35d commit eea50ed
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
15 changes: 15 additions & 0 deletions ntcore/src/main/native/cpp/LocalStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ struct TopicData {
NT_Entry entry{0}; // cached entry for GetEntry()

bool onNetwork{false}; // true if there are any remote publishers
bool lastValueFromNetwork{false};

wpi::SmallVector<DataLoggerEntry, 1> datalogs;
NT_Type datalogType{NT_UNASSIGNED};
Expand Down Expand Up @@ -484,6 +485,7 @@ void LSImpl::CheckReset(TopicData* topic) {
}
topic->lastValue = {};
topic->lastValueNetwork = {};
topic->lastValueFromNetwork = false;
topic->type = NT_UNASSIGNED;
topic->typeStr.clear();
topic->flags = 0;
Expand All @@ -503,6 +505,7 @@ bool LSImpl::SetValue(TopicData* topic, const Value& value,
// TODO: notify option even if older value
topic->type = value.type();
topic->lastValue = value;
topic->lastValueFromNetwork = false;
NotifyValue(topic, eventFlags, isDuplicate, publisher);
}
if (!isDuplicate && topic->datalogType == value.type()) {
Expand Down Expand Up @@ -858,6 +861,17 @@ SubscriberData* LSImpl::AddLocalSubscriber(TopicData* topic,
DEBUG4("-> NetworkSubscribe({})", topic->name);
m_network->Subscribe(subscriber->handle, {{topic->name}}, config);
}

// queue current value
if (subscriber->active) {
if (!topic->lastValueFromNetwork && !config.disableLocal) {
subscriber->pollStorage.emplace_back(topic->lastValue);
subscriber->handle.Set();
} else if (topic->lastValueFromNetwork && !config.disableRemote) {
subscriber->pollStorage.emplace_back(topic->lastValueNetwork);
subscriber->handle.Set();
}
}
return subscriber;
}

Expand Down Expand Up @@ -1376,6 +1390,7 @@ void LocalStorage::NetworkSetValue(NT_Topic topicHandle, const Value& value) {
if (m_impl->SetValue(topic, value, NT_EVENT_VALUE_REMOTE,
value == topic->lastValue, nullptr)) {
topic->lastValueNetwork = value;
topic->lastValueFromNetwork = true;
}
}
}
Expand Down
47 changes: 44 additions & 3 deletions ntcore/src/test/native/cpp/LocalStorageTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,6 @@ TEST_F(LocalStorageTest, SubscribeNoTypeLocalPubPre) {
ASSERT_TRUE(value.IsBoolean());
EXPECT_EQ(value.GetBoolean(), true);
EXPECT_EQ(value.time(), 5);

auto vals = storage.ReadQueueValue(sub); // read queue won't get anything
ASSERT_TRUE(vals.empty());
}

TEST_F(LocalStorageTest, EntryNoTypeLocalSet) {
Expand Down Expand Up @@ -916,4 +913,48 @@ TEST_F(LocalStorageTest, EntryExcludeSelf) {
ElementsAre(TSEq<TimestampedDouble>(2.0, 60)));
}

TEST_F(LocalStorageTest, ReadQueueInitialLocal) {
EXPECT_CALL(network, Publish(_, _, _, _, _, _));
EXPECT_CALL(network, SetValue(_, _));
EXPECT_CALL(network, Subscribe(_, _, _)).Times(3);

auto pub = storage.Publish(fooTopic, NT_DOUBLE, "double", {}, {});
storage.SetEntryValue(pub, Value::MakeDouble(1.0, 50));

auto subBoth =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", kDefaultPubSubOptions);
auto subLocal =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableRemote = true});
auto subRemote =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true});

EXPECT_THAT(storage.ReadQueueDouble(subBoth),
ElementsAre(TSEq<TimestampedDouble>(1.0, 50)));
EXPECT_THAT(storage.ReadQueueDouble(subLocal),
ElementsAre(TSEq<TimestampedDouble>(1.0, 50)));
EXPECT_THAT(storage.ReadQueueDouble(subRemote), IsEmpty());
}

TEST_F(LocalStorageTest, ReadQueueInitialRemote) {
EXPECT_CALL(network, Subscribe(_, _, _)).Times(3);

auto remoteTopic =
storage.NetworkAnnounce("foo", "double", wpi::json::object(), 0);
storage.NetworkSetValue(remoteTopic, Value::MakeDouble(2.0, 60));

auto subBoth =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", kDefaultPubSubOptions);
auto subLocal =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableRemote = true});
auto subRemote =
storage.Subscribe(fooTopic, NT_DOUBLE, "double", {.disableLocal = true});

// network set
EXPECT_THAT(storage.ReadQueueDouble(subBoth),
ElementsAre(TSEq<TimestampedDouble>(2.0, 60)));
EXPECT_THAT(storage.ReadQueueDouble(subRemote),
ElementsAre(TSEq<TimestampedDouble>(2.0, 60)));
EXPECT_THAT(storage.ReadQueueDouble(subLocal), IsEmpty());
}

} // namespace nt

0 comments on commit eea50ed

Please sign in to comment.