
Commit

Fixed multiple bugs in data store (#2325)
* Fixed a bug in the data store where the call to get_empty_node was not
guarding the m_data structure with the proper mutex.

Added additional locking to avoid thread race conditions.

Fixed the shared memory segment cleanup so that only rank 0 on a node
tries to unlink it.

* Fixed the order of memory unmapping and unlinking: the mapping is now
released with munmap before the segment name is unlinked (see the
shared-memory cleanup sketch after this message).

* Applied clang-format.

* Fixed the locking within the exchange_local_cache function, using more
fine-grained locks to avoid nested locking (see the locking sketch after
this message).
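
A minimal sketch of the locking pattern applied here, using a simplified
stand-in class rather than the real data_store_conduit interface: accessors
that touch the shared map take a std::lock_guard on the guarding mutex, and
longer functions scope the guard to the smallest block that needs it, so a
lock is never held while calling another member that locks.

#include <mutex>
#include <unordered_map>

// Simplified stand-in for the data store; names are illustrative only.
class locked_store {
public:
  // Guarded accessor, in the spirit of the patched get_data_size().
  int size() const
  {
    std::lock_guard<std::mutex> lock(m_mutex);
    return static_cast<int>(m_data.size());
  }

  void insert(int id, int value)
  {
    {
      // Fine-grained scope: hold the lock only while touching m_data ...
      std::lock_guard<std::mutex> lock(m_mutex);
      m_data[id] = value;
    } // ... and release it before calling anything that locks again.
    int n = size(); // safe: no lock is held here, so no self-deadlock
    (void)n;
  }

private:
  mutable std::mutex m_mutex; // mutable so const accessors can lock it
  std::unordered_map<int, int> m_data;
};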
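
A sketch of the corrected shared-memory teardown, with simplified error
handling and a placeholder rank query: every rank unmaps its own mapping
first, and only the node-local rank 0 unlinks the segment name, so the name
is removed exactly once and only after the mapping has been released.

#include <cstddef>      // size_t
#include <cstdio>       // perror
#include <sys/mman.h>   // munmap, shm_unlink

// Tear down a POSIX shared-memory segment shared by the ranks on a node.
// 'addr' and 'length' describe this rank's mapping; 'rank_in_node' stands
// in for the node-local rank lookup used in LBANN.
void cleanup_shared_segment(void* addr, std::size_t length,
                            const char* seg_name, int rank_in_node)
{
  // 1. Every rank releases its own mapping first.
  if (munmap(addr, length) != 0) {
    perror("munmap");
  }
  // 2. Only one rank per node removes the name, after unmapping.
  if (rank_in_node == 0) {
    if (shm_unlink(seg_name) != 0) {
      perror("shm_unlink");
    }
  }
}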
bvanessen authored Sep 22, 2023
1 parent 572f119 commit 73ef72b
Showing 2 changed files with 137 additions and 115 deletions.
10 changes: 8 additions & 2 deletions include/lbann/data_store/data_store_conduit.hpp
@@ -261,7 +261,11 @@ class data_store_conduit
std::ofstream* m_profile = nullptr;

/// for use during development and debugging
int get_data_size() { return m_data.size(); }
int get_data_size()
{
std::lock_guard<std::mutex> lock(m_mutex);
return m_data.size();
}

/// made public for debugging during development
void copy_members(const data_store_conduit& rhs);
@@ -362,7 +366,9 @@ class data_store_conduit
map_ii_t m_spilled_nodes;

/// used in set_conduit_node(...)
std::mutex m_mutex;
// Guards m_data, m_sample_sizes, and m_image_offsets
mutable std::mutex m_mutex;
// Guards m_compact_sample_size
std::mutex m_mutex_2;

/// for use in local cache mode
242 changes: 129 additions & 113 deletions src/data_store/data_store_conduit.cpp
@@ -132,15 +132,15 @@ data_store_conduit::~data_store_conduit()
m_profile->close();
}
if (m_is_local_cache && m_mem_seg) {
int sanity = shm_unlink(m_seg_name.c_str());
int sanity = munmap(reinterpret_cast<void*>(m_mem_seg), m_mem_seg_length);
if (sanity != 0) {
std::cout << "\nWARNING: shm_unlink failed in "
"data_store_conduit::~data_store_conduit()\n";
LBANN_WARNING("munmap failed in for m_seg_name ", m_seg_name);
}
sanity = munmap(reinterpret_cast<void*>(m_mem_seg), m_mem_seg_length);
if (sanity != 0) {
std::cout << "\nWARNING: munmap failed in "
"data_store_conduit::~data_store_conduit()\n";
if (m_comm->get_rank_in_node() == 0) {
sanity = shm_unlink(m_seg_name.c_str());
if (sanity != 0) {
LBANN_WARNING("shm_unlink failed for m_seg_name ", m_seg_name);
}
}
}
}
@@ -280,7 +280,7 @@ void data_store_conduit::spill_preloaded_conduit_node(int data_id,
// note: at this point m_data[data_id] = node
conduit::Node n3 = node;
{
std::lock_guard<std::mutex> lock(m_mutex);
// BVE 9-20-2023 std::lock_guard<std::mutex> lock(m_mutex);
build_node_for_sending(node, n3);
}
if (!m_node_sizes_vary) {
@@ -329,13 +329,13 @@ void data_store_conduit::set_preloaded_conduit_node(int data_id,
conduit::Node n2 = node; // node == m_data[data_id]
std::lock_guard<std::mutex> lock(m_mutex);
build_node_for_sending(n2, m_data[data_id]);
}
if (!m_node_sizes_vary) {
error_check_compacted_node(m_data[data_id], data_id);
}
else {
std::lock_guard<std::mutex> lock(m_mutex);
m_sample_sizes[data_id] = m_data[data_id].total_bytes_compact();

if (!m_node_sizes_vary) {
error_check_compacted_node(m_data[data_id], data_id);
}
else {
m_sample_sizes[data_id] = m_data[data_id].total_bytes_compact();
}
}
}

@@ -476,6 +476,7 @@ const conduit::Node& data_store_conduit::get_conduit_node(int data_id) const
{
using iterator_t = std::unordered_map<int, conduit::Node>::const_iterator;
if (is_local_cache()) {
std::lock_guard<std::mutex> lock(m_mutex);
iterator_t t3 = m_data.find(data_id);
if (t3 == m_data.end()) {
LBANN_ERROR("(local cache) failed to find data_id: ",
@@ -490,6 +491,7 @@ const conduit::Node& data_store_conduit::get_conduit_node(int data_id) const
// if not preloaded, and get_label() or get_response() is called,
// we need to check m_data
if (t2 == m_minibatch_data.end()) {
std::lock_guard<std::mutex> lock(m_mutex);
iterator_t t3 = m_data.find(data_id);
if (t3 != m_data.end()) {
return t3->second["data"];
@@ -593,113 +595,119 @@ void data_store_conduit::start_exchange_data_by_sample(size_t current_pos,
// part 2: exchange the actual data

// start sends for outgoing data
size_t ss = 0;
for (int p = 0; p < m_np_in_trainer; p++) {
const std::unordered_set<int>& indices = m_indices_to_send[p];
for (auto index : indices) {
if (m_data.find(index) == m_data.end()) {
LBANN_ERROR("failed to find data_id: ",
index,
" to be sent to ",
p,
" in m_data");
}
const conduit::Node& n = m_data[index];
const El::byte* s = reinterpret_cast<const El::byte*>(n.data_ptr());
if (!n.is_contiguous()) {
LBANN_ERROR("data_id: ", index, " does not have a contiguous layout");
}
if (n.data_ptr() == nullptr) {
LBANN_ERROR("data_id: ", index, " does not have a valid data pointer");
}
if (n.contiguous_data_ptr() == nullptr) {
LBANN_ERROR("data_id: ",
index,
" does not have a valid contiguous data pointer");
}
{
std::lock_guard<std::mutex> lock(m_mutex);
size_t ss = 0;
for (int p = 0; p < m_np_in_trainer; p++) {
const std::unordered_set<int>& indices = m_indices_to_send[p];
for (auto index : indices) {
if (m_data.find(index) == m_data.end()) {
LBANN_ERROR("failed to find data_id: ",
index,
" to be sent to ",
p,
" in m_data");
}
const conduit::Node& n = m_data[index];
const El::byte* s = reinterpret_cast<const El::byte*>(n.data_ptr());
if (!n.is_contiguous()) {
LBANN_ERROR("data_id: ", index, " does not have a contiguous layout");
}
if (n.data_ptr() == nullptr) {
LBANN_ERROR("data_id: ",
index,
" does not have a valid data pointer");
}
if (n.contiguous_data_ptr() == nullptr) {
LBANN_ERROR("data_id: ",
index,
" does not have a valid contiguous data pointer");
}

size_t sz = m_compacted_sample_size;
size_t sz = m_compacted_sample_size;

if (m_node_sizes_vary) {
if (m_sample_sizes.find(index) == m_sample_sizes.end()) {
LBANN_ERROR(
"m_sample_sizes.find(index) == m_sample_sizes.end() for index: ",
index,
"; m_sample_sizes.size: ",
m_sample_sizes.size());
if (m_node_sizes_vary) {
if (m_sample_sizes.find(index) == m_sample_sizes.end()) {
LBANN_ERROR(
"m_sample_sizes.find(index) == m_sample_sizes.end() for index: ",
index,
"; m_sample_sizes.size: ",
m_sample_sizes.size());
}
sz = m_sample_sizes[index];
}
sz = m_sample_sizes[index];
}

m_comm->nb_tagged_send<El::byte>(s,
sz,
p,
index,
m_send_requests[ss++],
m_comm->get_trainer_comm());
m_comm->nb_tagged_send<El::byte>(s,
sz,
p,
index,
m_send_requests[ss++],
m_comm->get_trainer_comm());
}
}
}

// sanity checks
if (ss != m_send_requests.size()) {
LBANN_ERROR("ss != m_send_requests.size; ss: ",
ss,
" m_send_requests.size: ",
m_send_requests.size());
}

// start recvs for incoming data
ss = 0;

for (int p = 0; p < m_np_in_trainer; p++) {
const std::unordered_set<int>& indices = m_indices_to_recv[p];
int sanity = 0;
for (auto index : indices) {
++sanity;
int sz = m_compacted_sample_size;
if (m_node_sizes_vary) {
if (m_sample_sizes.find(index) == m_sample_sizes.end()) {
LBANN_ERROR(
"m_sample_sizes.find(index) == m_sample_sizes.end() for index: ",
index,
"; m_sample_sizes.size(): ",
m_sample_sizes.size(),
" role: ",
m_reader->get_role(),
" for index: ",
sanity,
" of ",
indices.size());
// sanity checks
if (ss != m_send_requests.size()) {
LBANN_ERROR("ss != m_send_requests.size; ss: ",
ss,
" m_send_requests.size: ",
m_send_requests.size());
}

// start recvs for incoming data
ss = 0;

for (int p = 0; p < m_np_in_trainer; p++) {
const std::unordered_set<int>& indices = m_indices_to_recv[p];
int sanity = 0;
for (auto index : indices) {
++sanity;
int sz = m_compacted_sample_size;
if (m_node_sizes_vary) {
if (m_sample_sizes.find(index) == m_sample_sizes.end()) {
LBANN_ERROR(
"m_sample_sizes.find(index) == m_sample_sizes.end() for index: ",
index,
"; m_sample_sizes.size(): ",
m_sample_sizes.size(),
" role: ",
m_reader->get_role(),
" for index: ",
sanity,
" of ",
indices.size());
}
sz = m_sample_sizes[index];
}
sz = m_sample_sizes[index];
}

m_recv_buffer[ss].set(conduit::DataType::uint8(sz));
El::byte* r = reinterpret_cast<El::byte*>(m_recv_buffer[ss].data_ptr());
m_comm->nb_tagged_recv<El::byte>(r,
sz,
p,
index,
m_recv_requests[ss],
m_comm->get_trainer_comm());
m_recv_data_ids[ss] = index;
++ss;
m_recv_buffer[ss].set(conduit::DataType::uint8(sz));
El::byte* r = reinterpret_cast<El::byte*>(m_recv_buffer[ss].data_ptr());
m_comm->nb_tagged_recv<El::byte>(r,
sz,
p,
index,
m_recv_requests[ss],
m_comm->get_trainer_comm());
m_recv_data_ids[ss] = index;
++ss;
}
}
}

// sanity checks
if (ss != m_recv_buffer.size()) {
LBANN_ERROR("ss != m_recv_buffer.size; ss: ",
ss,
" m_recv_buffer.size: ",
m_recv_buffer.size());
}
if (m_recv_requests.size() != m_recv_buffer.size()) {
LBANN_ERROR("m_recv_requests.size != m_recv_buffer.size; m_recv_requests: ",
m_recv_requests.size(),
" m_recv_buffer.size: ",
m_recv_buffer.size());
}
// sanity checks
if (ss != m_recv_buffer.size()) {
LBANN_ERROR("ss != m_recv_buffer.size; ss: ",
ss,
" m_recv_buffer.size: ",
m_recv_buffer.size());
}
if (m_recv_requests.size() != m_recv_buffer.size()) {
LBANN_ERROR(
"m_recv_requests.size != m_recv_buffer.size; m_recv_requests: ",
m_recv_requests.size(),
" m_recv_buffer.size: ",
m_recv_buffer.size());
}
} // End scope std::lock_guard<std::mutex> lock(m_mutex);

m_start_snd_rcv_time += (get_time() - tm5);
}
@@ -737,6 +745,7 @@ void data_store_conduit::finish_exchange_data_by_sample()
m_rebuild_time += (get_time() - tm5);

if (m_spill) {
std::lock_guard<std::mutex> lock(m_mutex);
// TODO
m_data.clear();
}
@@ -839,6 +848,7 @@ void data_store_conduit::build_preloaded_owner_map(

const conduit::Node& data_store_conduit::get_random_node() const
{
std::lock_guard<std::mutex> lock(m_mutex);
size_t sz = m_data.size();

// Deal with edge case
@@ -871,6 +881,7 @@ conduit::Node& data_store_conduit::get_empty_node(int data_id)

void data_store_conduit::compact_nodes()
{
std::lock_guard<std::mutex> lock(m_mutex);
for (auto&& j : *m_shuffled_indices) {
if (m_data.find(j) != m_data.end()) {
if (!(m_data[j].is_contiguous() && m_data[j].is_compact())) {
@@ -1068,6 +1079,7 @@ void data_store_conduit::check_mem_capacity(lbann_comm* comm,

bool data_store_conduit::has_conduit_node(int data_id) const
{
std::lock_guard<std::mutex> lock(m_mutex);
std::unordered_map<int, conduit::Node>::const_iterator t =
m_data.find(data_id);
return t != m_data.end();
@@ -1185,6 +1197,7 @@ void data_store_conduit::get_image_sizes(map_is_t& file_sizes,
// this block fires if we're exchanging cache data at the end
// of the first epoch, and the data store was not preloaded
if (is_explicitly_loading()) {
std::lock_guard<std::mutex> lock(m_mutex);
for (const auto& t : m_data) {
int data_id = t.first;
my_image_sizes.push_back(data_id);
@@ -1256,6 +1269,7 @@ void data_store_conduit::compute_image_offsets(
map_is_t& sizes,
std::vector<std::vector<int>>& indices)
{
std::lock_guard<std::mutex> lock(m_mutex);
size_t offset = 0;
for (size_t p = 0; p < indices.size(); p++) {
for (auto idx : indices[p]) {
@@ -1319,7 +1333,7 @@ void data_store_conduit::allocate_shared_segment(
if (node_id == 0) {
shm_fd = shm_open(m_seg_name.c_str(), O_CREAT | O_RDWR | O_EXCL, 0666);
if (shm_fd == -1) {
LBANN_ERROR("shm_open failed");
LBANN_ERROR("shm_open failed: ", m_seg_name);
}
int v = ftruncate(shm_fd, size);
if (v != 0) {
@@ -1382,6 +1396,7 @@ void data_store_conduit::exchange_local_caches()
PROFILE(" get_image_sizes time: ", (get_time() - tm1));

tm1 = get_time();
// BVE 9-21-2023 It seems like this should only be done once
allocate_shared_segment(m_sample_sizes, indices);
PROFILE(" allocate_shared_segment time: ", (get_time() - tm1));

@@ -1447,6 +1462,7 @@ void data_store_conduit::read_files(std::vector<char>& work,
void data_store_conduit::build_conduit_nodes(map_is_t& sizes)
{
image_data_reader* image_reader = dynamic_cast<image_data_reader*>(m_reader);
std::lock_guard<std::mutex> lock(m_mutex);
for (auto t : sizes) {
int data_id = t.first;
const auto sample = image_reader->get_sample(static_cast<size_t>(data_id));
