Skip to content

Commit 1db91a2

Browse files
Python Dataset Reader (#2414)
* Add skeleton for new python data reader * Implement basic functionality * Fix initialization for distconv * Add support for labels * Add python library supporting classes * clang format * Raise exception if rank/io parts not set * Rename to python dataset * Add optional module dir argument to add to path * Add unit tests * Simplify naming * Add cosmoflow example and reader helper * Update release notes * Save dataset pickle in work dir * Overhaul new data reader to support prefetching multiple samples/batches * Fix worker index calculation * clang-format * Clarify proto comments * Throw error if file fails to open * Add docstrings and type hints * Update CosmoFlow example and enable parallel IO * Add basic sample size checking, remove label reconstruction, general clean up * Switch to multiprocessing pool * Implement response shuffling for distconv * fix typo Co-authored-by: Tal Ben-Nun <[email protected]> --------- Co-authored-by: Tal Ben-Nun <[email protected]>
1 parent 811af60 commit 1db91a2

File tree

14 files changed

+1306
-10
lines changed

14 files changed

+1306
-10
lines changed

Diff for: ReleaseNotes.txt

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Experiments & Applications:
1919
Internal features:
2020

2121
I/O & data ingestion:
22+
- Added a new python dataset reader for simple, flexible, and distconv-supported
23+
python data readers.
2224

2325
Build system:
2426

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
from glob import glob
3+
from lbann.util.data import Sample, SampleDims, Dataset, DistConvDataset
4+
import h5py as h5
5+
import os
6+
7+
8+
class CosmoFlowDataset(DistConvDataset):
9+
def __init__(self, data_dir, input_width, num_secrets):
10+
self.data_dir = data_dir
11+
self.input_width = input_width
12+
self.num_secrets = num_secrets
13+
self.samples = glob(os.path.join(data_dir, '*.hdf5'))
14+
self.samples.sort()
15+
16+
def __len__(self):
17+
return len(self.samples)
18+
19+
def __getitem__(self, index) -> Sample:
20+
data = h5.File(self.samples[index], 'r')
21+
slice_width = self.input_width // self.num_io_partitions
22+
slice_ind = self.rank % self.num_io_partitions
23+
full = data['full'][:,
24+
slice_ind*slice_width:(slice_ind+1)*slice_width,
25+
:self.input_width,
26+
:self.input_width].astype(np.float32)
27+
par = data['unitPar'][:].astype(np.float32)
28+
return Sample(sample=np.ascontiguousarray(full), response=par)
29+
30+
def get_sample_dims(self):
31+
return SampleDims(sample=[4, self.input_width, self.input_width, self.input_width], response=self.num_secrets)

Diff for: applications/physics/cosmology/cosmoflow/train_cosmoflow.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,21 @@
1010
import lbann.contrib.args
1111
import lbann.contrib.launcher
1212
from lbann.core.util import get_parallel_strategy_args
13+
import lbann.util.data
1314
import os
15+
from cosmoflow_dataset import CosmoFlowDataset
16+
17+
def create_python_dataset_reader(args):
18+
"""Create a python dataset reader for CosmoFlow."""
19+
20+
readers = []
21+
for role in ['train', 'val', 'test']:
22+
role_dir = getattr(args, f'{role}_dir')
23+
dataset = CosmoFlowDataset(role_dir, args.input_width, args.num_secrets)
24+
reader = lbann.util.data.construct_python_dataset_reader(dataset, role=role)
25+
readers.append(reader)
26+
27+
return lbann.reader_pb2.DataReader(reader=readers)
1428

1529
def create_cosmoflow_data_reader(
1630
train_path, val_path, test_path, num_responses):
@@ -133,6 +147,9 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
133147
parser.add_argument(
134148
'--synthetic', action='store_true',
135149
help='Use synthetic data')
150+
parser.add_argument(
151+
'--python-dataset', action='store_true',
152+
help='Use python dataset reader')
136153
parser.add_argument(
137154
'--no-datastore', action='store_true',
138155
help='Disable the data store')
@@ -220,22 +237,26 @@ def create_synthetic_data_reader(input_width: int, num_responses: int) -> Any:
220237
# optimizer.learn_rate *= 1e-2
221238

222239
# Setup data reader
240+
serialize_io = False
223241
if args.synthetic:
224242
data_reader = create_synthetic_data_reader(
225243
args.input_width, args.num_secrets)
244+
elif args.python_dataset:
245+
data_reader = create_python_dataset_reader(args)
226246
else:
227247
data_reader = create_cosmoflow_data_reader(
228248
args.train_dir,
229249
args.val_dir,
230250
args.test_dir,
231251
num_responses=args.num_secrets)
252+
serialize_io = True
232253

233254
# Setup trainer
234255
random_seed_arg = {'random_seed': args.random_seed} \
235256
if args.random_seed is not None else {}
236257
trainer = lbann.Trainer(
237258
mini_batch_size=args.mini_batch_size,
238-
serialize_io=True,
259+
serialize_io=serialize_io,
239260
**random_seed_arg)
240261

241262
# Runtime parameters/arguments
+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import os
2+
import os.path
3+
import sys
4+
import numpy as np
5+
from lbann.util.data import Dataset, Sample, SampleDims, construct_python_dataset_reader
6+
7+
# Bamboo utilities
8+
current_file = os.path.realpath(__file__)
9+
current_dir = os.path.dirname(current_file)
10+
sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python'))
11+
import tools
12+
13+
# ==============================================
14+
# Objects for Python dataset data reader
15+
# ==============================================
16+
# Note: The Python dataset data reader loads the dataset constructed below.
17+
18+
# Data
19+
class TestDataset(Dataset):
20+
def __init__(self):
21+
np.random.seed(20240109)
22+
self.num_samples = 29
23+
self.sample_size = 7
24+
self.samples = np.random.normal(size=(self.num_samples,self.sample_size)).astype(np.float32)
25+
26+
def __len__(self):
27+
return self.num_samples
28+
29+
def __getitem__(self, index):
30+
return Sample(sample=self.samples[index,:])
31+
32+
def get_sample_dims(self):
33+
return SampleDims(sample=[self.sample_size])
34+
35+
test_dataset = TestDataset()
36+
37+
# ==============================================
38+
# Setup LBANN experiment
39+
# ==============================================
40+
41+
def setup_experiment(lbann, weekly):
42+
"""Construct LBANN experiment.
43+
44+
Args:
45+
lbann (module): Module for LBANN Python frontend
46+
47+
"""
48+
mini_batch_size = len(test_dataset) // 4
49+
trainer = lbann.Trainer(mini_batch_size)
50+
model = construct_model(lbann)
51+
data_reader = construct_data_reader(lbann)
52+
optimizer = lbann.NoOptimizer()
53+
return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes
54+
55+
def construct_model(lbann):
56+
"""Construct LBANN model.
57+
58+
Args:
59+
lbann (module): Module for LBANN Python frontend
60+
61+
"""
62+
63+
# Layer graph
64+
x = lbann.Input(data_field='samples')
65+
y = lbann.L2Norm2(x)
66+
layers = list(lbann.traverse_layer_graph(x))
67+
metric = lbann.Metric(y, name='obj')
68+
callbacks = []
69+
70+
# Compute expected value with NumPy
71+
vals = []
72+
for i in range(len(test_dataset)):
73+
x = test_dataset[i].sample.astype(np.float64)
74+
y = tools.numpy_l2norm2(x)
75+
vals.append(y)
76+
val = np.mean(vals)
77+
tol = 8 * val * np.finfo(np.float32).eps
78+
callbacks.append(lbann.CallbackCheckMetric(
79+
metric=metric.name,
80+
lower_bound=val-tol,
81+
upper_bound=val+tol,
82+
error_on_failure=True,
83+
execution_modes='test'))
84+
85+
# Construct model
86+
num_epochs = 0
87+
return lbann.Model(num_epochs,
88+
layers=layers,
89+
metrics=[metric],
90+
callbacks=callbacks)
91+
92+
def construct_data_reader(lbann):
93+
"""Construct Protobuf message for Python dataset data reader.
94+
95+
The Python data reader will import the current Python file to
96+
access the sample access functions.
97+
98+
Args:
99+
lbann (module): Module for LBANN Python frontend
100+
101+
"""
102+
103+
dataset_path = os.path.join(work_dir, 'dataset.pkl')
104+
105+
# Note: The training data reader should be removed when
106+
# https://github.com/LLNL/lbann/issues/1098 is resolved.
107+
message = lbann.reader_pb2.DataReader()
108+
message.reader.extend([
109+
construct_python_dataset_reader(
110+
test_dataset,
111+
dataset_path,
112+
'train',
113+
shuffle=False
114+
)
115+
])
116+
message.reader.extend([
117+
construct_python_dataset_reader(
118+
test_dataset,
119+
dataset_path,
120+
'test',
121+
shuffle=False
122+
)
123+
])
124+
return message
125+
126+
# ==============================================
127+
# Setup PyTest
128+
# ==============================================
129+
130+
work_dir = os.path.join(os.path.dirname(__file__),
131+
'experiments',
132+
os.path.basename(__file__).split('.py')[0])
133+
os.makedirs(work_dir, exist_ok=True)
134+
135+
# Create test functions that can interact with PyTest
136+
for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir):
137+
globals()[_test_func.__name__] = _test_func

0 commit comments

Comments
 (0)