Commit f894c6a

Fix double-read when not using datastore
1 parent: 3d61161 · commit: f894c6a

File tree: 2 files changed (+36, -42 lines)
Diff for: include/lbann/data_ingestion/readers/data_reader_hdf5_legacy.hpp (+5, -2)

@@ -143,11 +143,14 @@ class hdf5_reader : public generic_data_reader
                         TensorDataType* sample);
   void read_hdf5_sample(uint64_t data_id,
                         TensorDataType* sample,
-                        TensorDataType* labels);
+                        TensorDataType* labels,
+                        bool get_sample = true,
+                        bool get_labels = true,
+                        bool get_responses = true);
   // void set_defaults() override;
   void load_sample(conduit::Node& node, uint64_t data_id);
   bool fetch_datum(CPUMat& X, uint64_t data_id, uint64_t mb_idx) override;
-  void fetch_datum_conduit(Mat& X, uint64_t data_id);
+  void fetch_sample_conduit(Mat& X, uint64_t data_id, std::string field);
   bool fetch_data_field(data_field_type data_field,
                         CPUMat& Y,
                         uint64_t data_id,
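Note that the three new flags default to true, so existing callers of read_hdf5_sample() keep the old "read everything" behavior; only call sites that opt out of a field change. A minimal, self-contained sketch of that pattern (toy names, not LBANN code):

#include <cstdio>

// Toy stand-in for read_hdf5_sample(): default-true flags preserve the old
// behavior, while new call sites request only what they need.
void read_sample(int data_id, float* sample, float* labels,
                 bool get_sample = true, bool get_labels = true, bool get_responses = true)
{
  if (get_sample && sample != nullptr)  std::printf("id %d: read sample slab\n", data_id);
  if (get_labels && labels != nullptr)  std::printf("id %d: read labels\n", data_id);
  if (get_responses)                    std::printf("id %d: read responses\n", data_id);
}

int main()
{
  float X[4] = {};
  read_sample(0, X, nullptr);                           // old-style call, defaults apply
  read_sample(1, X, nullptr, true, false, false);       // sample only (cf. fetch_datum below)
  read_sample(2, nullptr, nullptr, false, false, true); // responses only (cf. fetch_response below)
}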

Diff for: src/data_ingestion/readers/data_reader_hdf5_legacy.cpp (+31, -40)
@@ -151,36 +151,41 @@ void hdf5_reader<TensorDataType>::read_hdf5_hyperslab(hsize_t h_data,
 template <typename TensorDataType>
 void hdf5_reader<TensorDataType>::read_hdf5_sample(uint64_t data_id,
                                                    TensorDataType* sample,
-                                                   TensorDataType* labels)
+                                                   TensorDataType* labels,
+                                                   bool get_sample,
+                                                   bool get_labels,
+                                                   bool get_responses)
 {
   int world_rank = get_comm()->get_rank_in_trainer();
   auto file = m_file_paths[data_id];
   hid_t h_file = CHECK_HDF5(H5Fopen(file.c_str(), H5F_ACC_RDONLY, m_fapl));

-  // load in dataset
-  hid_t h_data = CHECK_HDF5(H5Dopen(h_file, m_key_data.c_str(), H5P_DEFAULT));
-  hid_t filespace = CHECK_HDF5(H5Dget_space(h_data));
-  // get the number of dimensions from the dataset
-  int rank1 = H5Sget_simple_extent_ndims(filespace);
-  hsize_t dims[rank1];
-  // read in what the dimensions are
-  CHECK_HDF5(H5Sget_simple_extent_dims(filespace, dims, NULL));
-
-  read_hdf5_hyperslab(h_data, filespace, world_rank, sample);
-  // close data set
-  CHECK_HDF5(H5Dclose(h_data));
+  if (get_sample) {
+    // load in dataset
+    hid_t h_data = CHECK_HDF5(H5Dopen(h_file, m_key_data.c_str(), H5P_DEFAULT));
+    hid_t filespace = CHECK_HDF5(H5Dget_space(h_data));
+    // get the number of dimensions from the dataset
+    int rank1 = H5Sget_simple_extent_ndims(filespace);
+    hsize_t dims[rank1];
+    // read in what the dimensions are
+    CHECK_HDF5(H5Sget_simple_extent_dims(filespace, dims, NULL));
+
+    read_hdf5_hyperslab(h_data, filespace, world_rank, sample);
+    // close data set
+    CHECK_HDF5(H5Dclose(h_data));
+  }

-  if (this->has_labels() && labels != nullptr) {
+  if (this->has_labels() && labels != nullptr && get_labels) {
     assert_always(m_hyperslab_labels);
     hid_t h_labels =
       CHECK_HDF5(H5Dopen(h_file, m_key_labels.c_str(), H5P_DEFAULT));
     hid_t filespace_labels = CHECK_HDF5(H5Dget_space(h_labels));
     read_hdf5_hyperslab(h_labels, filespace_labels, world_rank, labels);
     CHECK_HDF5(H5Dclose(h_labels));
   }
-  else if (this->has_responses()) {
+  else if (this->has_responses() && get_responses) {
     assert_always(labels == nullptr);
-    h_data = CHECK_HDF5(H5Dopen(h_file, m_key_responses.c_str(), H5P_DEFAULT));
+    hid_t h_data = CHECK_HDF5(H5Dopen(h_file, m_key_responses.c_str(), H5P_DEFAULT));
     CHECK_HDF5(H5Dread(h_data,
                        H5T_NATIVE_FLOAT,
                        H5S_ALL,
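For readers less familiar with the HDF5 C API, here is a self-contained sketch of the gated-read idea the rewritten function follows: open the file once, then touch only the datasets that were requested. This is not LBANN code; the dataset names "samples" and "responses" and the helper read_fields() are invented for illustration.

#include <hdf5.h>
#include <cstdio>
#include <vector>

// Read only the requested datasets from one HDF5 file; pass nullptr to skip a field.
void read_fields(const char* path,
                 std::vector<float>* sample,
                 std::vector<float>* responses)
{
  hid_t file = H5Fopen(path, H5F_ACC_RDONLY, H5P_DEFAULT);
  if (file < 0) { std::fprintf(stderr, "cannot open %s\n", path); return; }

  if (sample != nullptr) {
    hid_t dset = H5Dopen(file, "samples", H5P_DEFAULT);
    hid_t space = H5Dget_space(dset);
    hssize_t n = H5Sget_simple_extent_npoints(space);
    sample->resize(static_cast<size_t>(n));
    H5Dread(dset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT, sample->data());
    H5Sclose(space);
    H5Dclose(dset);
  }

  if (responses != nullptr) {
    hid_t dset = H5Dopen(file, "responses", H5P_DEFAULT);
    hid_t space = H5Dget_space(dset);
    hssize_t n = H5Sget_simple_extent_npoints(space);
    responses->resize(static_cast<size_t>(n));
    H5Dread(dset, H5T_NATIVE_FLOAT, H5S_ALL, H5S_ALL, H5P_DEFAULT, responses->data());
    H5Sclose(space);
    H5Dclose(dset);
  }

  H5Fclose(file);
}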
@@ -390,17 +395,17 @@ bool hdf5_reader<TensorDataType>::fetch_datum(Mat& X,

   auto X_v = create_datum_view(X, mb_idx);
   if (m_use_data_store) {
-    fetch_datum_conduit(X_v, data_id);
+    fetch_sample_conduit(X_v, data_id, "slab");
   }
   else {
-    read_hdf5_sample(data_id, (TensorDataType*)X_v.Buffer(), nullptr);
+    read_hdf5_sample(data_id, (TensorDataType*)X_v.Buffer(), nullptr, true, false, false);
   }
   prof_region_end("fetch_datum", false);
   return true;
 }

 template <typename TensorDataType>
-void hdf5_reader<TensorDataType>::fetch_datum_conduit(Mat& X, uint64_t data_id)
+void hdf5_reader<TensorDataType>::fetch_sample_conduit(Mat& X, uint64_t data_id, std::string field)
 {
   const std::string conduit_key = LBANN_DATA_ID_STR(data_id);
   // Create a node to hold all of the data
@@ -417,7 +422,7 @@ void hdf5_reader<TensorDataType>::fetch_datum_conduit(Mat& X, uint64_t data_id)
   }
   prof_region_begin("set_external", prof_colors[0], false);
   conduit::Node slab;
-  slab.set_external(node[conduit_key + "/slab"]);
+  slab.set_external(node[conduit_key + "/" + field]);
   prof_region_end("set_external", false);
   TensorDataType* data = slab.value();
   prof_region_begin("copy_to_buffer", prof_colors[0], false);
@@ -438,35 +443,21 @@ bool hdf5_reader<TensorDataType>::fetch_response(Mat& Y,
   }

   prof_region_begin("fetch_response", prof_colors[0], false);
-  float* buf = nullptr;
   if (m_hyperslab_labels) {
     assert_eq((unsigned long)Y.Height(), m_num_features);
-    const std::string conduit_key = LBANN_DATA_ID_STR(data_id);
-    conduit::Node node;
-    const conduit::Node& ds_node = m_data_store->get_conduit_node(data_id);
-    node.set_external(ds_node);
-    conduit::Node slab;
-    slab.set_external(node[conduit_key + "/responses_slab"]);
-    prof_region_end("set_external", false);
-    buf = slab.value();
     auto Y_v = create_datum_view(Y, mb_idx);
-    std::memcpy(Y_v.Buffer(), buf, m_num_features * sizeof(TensorDataType));
+    fetch_sample_conduit(Y_v, data_id, "responses_slab");
   }
   else {
     assert_eq((unsigned long)Y.Height(), m_all_responses.size());
-    conduit::Node node;
-    if (m_use_data_store &&
-        (data_store_active() || m_data_store->has_conduit_node(data_id))) {
-      const conduit::Node& ds_node = m_data_store->get_conduit_node(data_id);
-      node.set_external(ds_node);
+    auto Y_v = create_datum_view(Y, mb_idx);
+    if (m_use_data_store) {
+      fetch_sample_conduit(Y_v, data_id, "responses");
     }
     else {
-      load_sample(node, data_id);
+      read_hdf5_sample(data_id, nullptr, nullptr, false, false, true);
+      std::memcpy(Y_v.Buffer(), &m_all_responses[0], m_all_responses.size() * sizeof(DataType));
     }
-    const std::string conduit_obj = LBANN_DATA_ID_STR(data_id);
-    buf = node[conduit_obj + "/responses"].value();
-    auto Y_v = create_datum_view(Y, mb_idx);
-    std::memcpy(Y_v.Buffer(), buf, m_all_responses.size() * sizeof(DataType));
   }
   prof_region_end("fetch_response", false);
   return true;
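This last hunk is where the double read named in the commit message is actually eliminated: without the data store, fetch_response() used to call load_sample(), which re-read the sample from HDF5 just to extract the responses, even though fetch_datum() had already read the sample slab for the same data_id. It now asks read_hdf5_sample() for the responses alone. A condensed before/after, paraphrased from the hunk above (not a standalone snippet; that read_hdf5_sample() fills m_all_responses in its responses branch is an assumption inferred from the new memcpy):

// Before (no data store): rebuild a Conduit node for the sample, re-reading HDF5.
load_sample(node, data_id);
buf = node[conduit_obj + "/responses"].value();
std::memcpy(Y_v.Buffer(), buf, m_all_responses.size() * sizeof(DataType));

// After: open only the responses dataset; copy from m_all_responses,
// which read_hdf5_sample() populates in its responses branch (assumed).
read_hdf5_sample(data_id, nullptr, nullptr,
                 /*get_sample=*/false, /*get_labels=*/false, /*get_responses=*/true);
std::memcpy(Y_v.Buffer(), &m_all_responses[0], m_all_responses.size() * sizeof(DataType));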
