@@ -151,36 +151,41 @@ void hdf5_reader<TensorDataType>::read_hdf5_hyperslab(hsize_t h_data,
151
151
template <typename TensorDataType>
152
152
void hdf5_reader<TensorDataType>::read_hdf5_sample(uint64_t data_id,
153
153
TensorDataType* sample,
154
- TensorDataType* labels)
154
+ TensorDataType* labels,
155
+ bool get_sample,
156
+ bool get_labels,
157
+ bool get_responses)
155
158
{
156
159
int world_rank = get_comm ()->get_rank_in_trainer ();
157
160
auto file = m_file_paths[data_id];
158
161
hid_t h_file = CHECK_HDF5 (H5Fopen (file.c_str (), H5F_ACC_RDONLY, m_fapl));
159
162
160
- // load in dataset
161
- hid_t h_data = CHECK_HDF5 (H5Dopen (h_file, m_key_data.c_str (), H5P_DEFAULT));
162
- hid_t filespace = CHECK_HDF5 (H5Dget_space (h_data));
163
- // get the number of dimensions from the dataset
164
- int rank1 = H5Sget_simple_extent_ndims (filespace);
165
- hsize_t dims[rank1];
166
- // read in what the dimensions are
167
- CHECK_HDF5 (H5Sget_simple_extent_dims (filespace, dims, NULL ));
168
-
169
- read_hdf5_hyperslab (h_data, filespace, world_rank, sample);
170
- // close data set
171
- CHECK_HDF5 (H5Dclose (h_data));
163
+ if (get_sample) {
164
+ // load in dataset
165
+ hid_t h_data = CHECK_HDF5 (H5Dopen (h_file, m_key_data.c_str (), H5P_DEFAULT));
166
+ hid_t filespace = CHECK_HDF5 (H5Dget_space (h_data));
167
+ // get the number of dimensions from the dataset
168
+ int rank1 = H5Sget_simple_extent_ndims (filespace);
169
+ hsize_t dims[rank1];
170
+ // read in what the dimensions are
171
+ CHECK_HDF5 (H5Sget_simple_extent_dims (filespace, dims, NULL ));
172
+
173
+ read_hdf5_hyperslab (h_data, filespace, world_rank, sample);
174
+ // close data set
175
+ CHECK_HDF5 (H5Dclose (h_data));
176
+ }
172
177
173
- if (this ->has_labels () && labels != nullptr ) {
178
+ if (this ->has_labels () && labels != nullptr && get_labels ) {
174
179
assert_always (m_hyperslab_labels);
175
180
hid_t h_labels =
176
181
CHECK_HDF5 (H5Dopen (h_file, m_key_labels.c_str (), H5P_DEFAULT));
177
182
hid_t filespace_labels = CHECK_HDF5 (H5Dget_space (h_labels));
178
183
read_hdf5_hyperslab (h_labels, filespace_labels, world_rank, labels);
179
184
CHECK_HDF5 (H5Dclose (h_labels));
180
185
}
181
- else if (this ->has_responses ()) {
186
+ else if (this ->has_responses () && get_responses ) {
182
187
assert_always (labels == nullptr );
183
- h_data = CHECK_HDF5 (H5Dopen (h_file, m_key_responses.c_str (), H5P_DEFAULT));
188
+ hid_t h_data = CHECK_HDF5 (H5Dopen (h_file, m_key_responses.c_str (), H5P_DEFAULT));
184
189
CHECK_HDF5 (H5Dread (h_data,
185
190
H5T_NATIVE_FLOAT,
186
191
H5S_ALL,
@@ -390,17 +395,17 @@ bool hdf5_reader<TensorDataType>::fetch_datum(Mat& X,
390
395
391
396
auto X_v = create_datum_view (X, mb_idx);
392
397
if (m_use_data_store) {
393
- fetch_datum_conduit (X_v, data_id);
398
+ fetch_sample_conduit (X_v, data_id, " slab " );
394
399
}
395
400
else {
396
- read_hdf5_sample (data_id, (TensorDataType*)X_v.Buffer (), nullptr );
401
+ read_hdf5_sample (data_id, (TensorDataType*)X_v.Buffer (), nullptr , true , false , false );
397
402
}
398
403
prof_region_end (" fetch_datum" , false );
399
404
return true ;
400
405
}
401
406
402
407
template <typename TensorDataType>
403
- void hdf5_reader<TensorDataType>::fetch_datum_conduit (Mat& X, uint64_t data_id)
408
+ void hdf5_reader<TensorDataType>::fetch_sample_conduit (Mat& X, uint64_t data_id, std::string field )
404
409
{
405
410
const std::string conduit_key = LBANN_DATA_ID_STR (data_id);
406
411
// Create a node to hold all of the data
@@ -417,7 +422,7 @@ void hdf5_reader<TensorDataType>::fetch_datum_conduit(Mat& X, uint64_t data_id)
417
422
}
418
423
prof_region_begin (" set_external" , prof_colors[0 ], false );
419
424
conduit::Node slab;
420
- slab.set_external (node[conduit_key + " /slab " ]);
425
+ slab.set_external (node[conduit_key + " /" + field ]);
421
426
prof_region_end (" set_external" , false );
422
427
TensorDataType* data = slab.value ();
423
428
prof_region_begin (" copy_to_buffer" , prof_colors[0 ], false );
@@ -438,35 +443,21 @@ bool hdf5_reader<TensorDataType>::fetch_response(Mat& Y,
438
443
}
439
444
440
445
prof_region_begin (" fetch_response" , prof_colors[0 ], false );
441
- float * buf = nullptr ;
442
446
if (m_hyperslab_labels) {
443
447
assert_eq ((unsigned long )Y.Height (), m_num_features);
444
- const std::string conduit_key = LBANN_DATA_ID_STR (data_id);
445
- conduit::Node node;
446
- const conduit::Node& ds_node = m_data_store->get_conduit_node (data_id);
447
- node.set_external (ds_node);
448
- conduit::Node slab;
449
- slab.set_external (node[conduit_key + " /responses_slab" ]);
450
- prof_region_end (" set_external" , false );
451
- buf = slab.value ();
452
448
auto Y_v = create_datum_view (Y, mb_idx);
453
- std::memcpy (Y_v. Buffer (), buf, m_num_features * sizeof (TensorDataType) );
449
+ fetch_sample_conduit (Y_v, data_id, " responses_slab " );
454
450
}
455
451
else {
456
452
assert_eq ((unsigned long )Y.Height (), m_all_responses.size ());
457
- conduit::Node node;
458
- if (m_use_data_store &&
459
- (data_store_active () || m_data_store->has_conduit_node (data_id))) {
460
- const conduit::Node& ds_node = m_data_store->get_conduit_node (data_id);
461
- node.set_external (ds_node);
453
+ auto Y_v = create_datum_view (Y, mb_idx);
454
+ if (m_use_data_store) {
455
+ fetch_sample_conduit (Y_v, data_id, " responses" );
462
456
}
463
457
else {
464
- load_sample (node, data_id);
458
+ read_hdf5_sample (data_id, nullptr , nullptr , false , false , true );
459
+ std::memcpy (Y_v.Buffer (), &m_all_responses[0 ], m_all_responses.size () * sizeof (DataType));
465
460
}
466
- const std::string conduit_obj = LBANN_DATA_ID_STR (data_id);
467
- buf = node[conduit_obj + " /responses" ].value ();
468
- auto Y_v = create_datum_view (Y, mb_idx);
469
- std::memcpy (Y_v.Buffer (), buf, m_all_responses.size () * sizeof (DataType));
470
461
}
471
462
prof_region_end (" fetch_response" , false );
472
463
return true ;
0 commit comments