From 6011d0387b99fb6de9b12f464ff2aa2c102fc331 Mon Sep 17 00:00:00 2001 From: Pier Fiedorowicz <117680821+fiedorowicz1@users.noreply.github.com> Date: Wed, 14 Feb 2024 13:57:27 -0800 Subject: [PATCH] Fix CosmoFlow Double-Reading (#2425) * Fix double-read when not using datastore * clang-format --- .../readers/data_reader_hdf5_legacy.hpp | 7 +- .../readers/data_reader_hdf5_legacy.cpp | 81 ++++++++++--------- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/include/lbann/data_ingestion/readers/data_reader_hdf5_legacy.hpp b/include/lbann/data_ingestion/readers/data_reader_hdf5_legacy.hpp index a2095f7e954..c4fa06d7334 100644 --- a/include/lbann/data_ingestion/readers/data_reader_hdf5_legacy.hpp +++ b/include/lbann/data_ingestion/readers/data_reader_hdf5_legacy.hpp @@ -143,11 +143,14 @@ class hdf5_reader : public generic_data_reader TensorDataType* sample); void read_hdf5_sample(uint64_t data_id, TensorDataType* sample, - TensorDataType* labels); + TensorDataType* labels, + bool get_sample = true, + bool get_labels = true, + bool get_responses = true); // void set_defaults() override; void load_sample(conduit::Node& node, uint64_t data_id); bool fetch_datum(CPUMat& X, uint64_t data_id, uint64_t mb_idx) override; - void fetch_datum_conduit(Mat& X, uint64_t data_id); + void fetch_sample_conduit(Mat& X, uint64_t data_id, std::string field); bool fetch_data_field(data_field_type data_field, CPUMat& Y, uint64_t data_id, diff --git a/src/data_ingestion/readers/data_reader_hdf5_legacy.cpp b/src/data_ingestion/readers/data_reader_hdf5_legacy.cpp index 91249505413..e0a3807ea7f 100644 --- a/src/data_ingestion/readers/data_reader_hdf5_legacy.cpp +++ b/src/data_ingestion/readers/data_reader_hdf5_legacy.cpp @@ -151,26 +151,31 @@ void hdf5_reader::read_hdf5_hyperslab(hsize_t h_data, template void hdf5_reader::read_hdf5_sample(uint64_t data_id, TensorDataType* sample, - TensorDataType* labels) + TensorDataType* labels, + bool get_sample, + bool get_labels, + bool get_responses) { int world_rank = get_comm()->get_rank_in_trainer(); auto file = m_file_paths[data_id]; hid_t h_file = CHECK_HDF5(H5Fopen(file.c_str(), H5F_ACC_RDONLY, m_fapl)); - // load in dataset - hid_t h_data = CHECK_HDF5(H5Dopen(h_file, m_key_data.c_str(), H5P_DEFAULT)); - hid_t filespace = CHECK_HDF5(H5Dget_space(h_data)); - // get the number of dimensions from the dataset - int rank1 = H5Sget_simple_extent_ndims(filespace); - hsize_t dims[rank1]; - // read in what the dimensions are - CHECK_HDF5(H5Sget_simple_extent_dims(filespace, dims, NULL)); - - read_hdf5_hyperslab(h_data, filespace, world_rank, sample); - // close data set - CHECK_HDF5(H5Dclose(h_data)); + if (get_sample) { + // load in dataset + hid_t h_data = CHECK_HDF5(H5Dopen(h_file, m_key_data.c_str(), H5P_DEFAULT)); + hid_t filespace = CHECK_HDF5(H5Dget_space(h_data)); + // get the number of dimensions from the dataset + int rank1 = H5Sget_simple_extent_ndims(filespace); + hsize_t dims[rank1]; + // read in what the dimensions are + CHECK_HDF5(H5Sget_simple_extent_dims(filespace, dims, NULL)); + + read_hdf5_hyperslab(h_data, filespace, world_rank, sample); + // close data set + CHECK_HDF5(H5Dclose(h_data)); + } - if (this->has_labels() && labels != nullptr) { + if (this->has_labels() && labels != nullptr && get_labels) { assert_always(m_hyperslab_labels); hid_t h_labels = CHECK_HDF5(H5Dopen(h_file, m_key_labels.c_str(), H5P_DEFAULT)); @@ -178,9 +183,10 @@ void hdf5_reader::read_hdf5_sample(uint64_t data_id, read_hdf5_hyperslab(h_labels, filespace_labels, world_rank, labels); CHECK_HDF5(H5Dclose(h_labels)); } - else if (this->has_responses()) { + else if (this->has_responses() && get_responses) { assert_always(labels == nullptr); - h_data = CHECK_HDF5(H5Dopen(h_file, m_key_responses.c_str(), H5P_DEFAULT)); + hid_t h_data = + CHECK_HDF5(H5Dopen(h_file, m_key_responses.c_str(), H5P_DEFAULT)); CHECK_HDF5(H5Dread(h_data, H5T_NATIVE_FLOAT, H5S_ALL, @@ -390,17 +396,24 @@ bool hdf5_reader::fetch_datum(Mat& X, auto X_v = create_datum_view(X, mb_idx); if (m_use_data_store) { - fetch_datum_conduit(X_v, data_id); + fetch_sample_conduit(X_v, data_id, "slab"); } else { - read_hdf5_sample(data_id, (TensorDataType*)X_v.Buffer(), nullptr); + read_hdf5_sample(data_id, + (TensorDataType*)X_v.Buffer(), + nullptr, + true, + false, + false); } prof_region_end("fetch_datum", false); return true; } template -void hdf5_reader::fetch_datum_conduit(Mat& X, uint64_t data_id) +void hdf5_reader::fetch_sample_conduit(Mat& X, + uint64_t data_id, + std::string field) { const std::string conduit_key = LBANN_DATA_ID_STR(data_id); // Create a node to hold all of the data @@ -417,7 +430,7 @@ void hdf5_reader::fetch_datum_conduit(Mat& X, uint64_t data_id) } prof_region_begin("set_external", prof_colors[0], false); conduit::Node slab; - slab.set_external(node[conduit_key + "/slab"]); + slab.set_external(node[conduit_key + "/" + field]); prof_region_end("set_external", false); TensorDataType* data = slab.value(); prof_region_begin("copy_to_buffer", prof_colors[0], false); @@ -438,35 +451,23 @@ bool hdf5_reader::fetch_response(Mat& Y, } prof_region_begin("fetch_response", prof_colors[0], false); - float* buf = nullptr; if (m_hyperslab_labels) { assert_eq((unsigned long)Y.Height(), m_num_features); - const std::string conduit_key = LBANN_DATA_ID_STR(data_id); - conduit::Node node; - const conduit::Node& ds_node = m_data_store->get_conduit_node(data_id); - node.set_external(ds_node); - conduit::Node slab; - slab.set_external(node[conduit_key + "/responses_slab"]); - prof_region_end("set_external", false); - buf = slab.value(); auto Y_v = create_datum_view(Y, mb_idx); - std::memcpy(Y_v.Buffer(), buf, m_num_features * sizeof(TensorDataType)); + fetch_sample_conduit(Y_v, data_id, "responses_slab"); } else { assert_eq((unsigned long)Y.Height(), m_all_responses.size()); - conduit::Node node; - if (m_use_data_store && - (data_store_active() || m_data_store->has_conduit_node(data_id))) { - const conduit::Node& ds_node = m_data_store->get_conduit_node(data_id); - node.set_external(ds_node); + auto Y_v = create_datum_view(Y, mb_idx); + if (m_use_data_store) { + fetch_sample_conduit(Y_v, data_id, "responses"); } else { - load_sample(node, data_id); + read_hdf5_sample(data_id, nullptr, nullptr, false, false, true); + std::memcpy(Y_v.Buffer(), + &m_all_responses[0], + m_all_responses.size() * sizeof(DataType)); } - const std::string conduit_obj = LBANN_DATA_ID_STR(data_id); - buf = node[conduit_obj + "/responses"].value(); - auto Y_v = create_datum_view(Y, mb_idx); - std::memcpy(Y_v.Buffer(), buf, m_all_responses.size() * sizeof(DataType)); } prof_region_end("fetch_response", false); return true;