Commit 40c516b

Save dataset pickle in work dir

1 parent 5d0e695 · commit 40c516b

3 files changed (+34, -58 lines)

Diff for: ci_test/common_python/tools.py (-38 lines)

@@ -889,44 +889,6 @@ def create_python_data_reader(lbann,
 
     return reader
 
-def create_python_dataset_reader(lbann,
-                                 file_name,
-                                 dataset,
-                                 execution_mode,
-                                 dataset_path=None):
-    """Create protobuf message for Pythond dataset reader
-
-    A Python dataset reader gets data by loading a pickled
-    lbann.util.data.Dataset object.
-
-    Args:
-        lbann (module): Module for LBANN Python frontend.
-        file_name (str): Python file with dataset class definition.
-        dataset (lbann.util.data.Dataset): Dataset object to be pickled.
-        execution_mode (str): 'train', 'validation', or 'test'
-
-    """
-
-    # Extract paths
-    if dataset_path is None:
-        dataset_path = os.path.join(os.environ['TMPDIR'], f'dataset_{execution_mode}.pkl')
-    import pickle
-    with open(dataset_path, 'wb') as f:
-        pickle.dump(dataset, f)
-
-    import inspect
-
-    # Construct protobuf message for data reader
-    reader = lbann.reader_pb2.Reader()
-    reader.name = 'python_dataset'
-    reader.role = execution_mode
-    reader.shuffle = False
-    reader.fraction_of_data_to_use = 1.0
-    reader.python_dataset.dataset_path = dataset_path
-    reader.python_dataset.module_dir = os.path.dirname(os.path.abspath(file_name))
-
-    return reader
-
 
 def numpy_l2norm2(x):
     """Square of L2 norm, computed with NumPy

Diff for: ci_test/unit_tests/test_unit_datareader_python_dataset.py (+17, -10 lines)

@@ -2,7 +2,7 @@
 import os.path
 import sys
 import numpy as np
-from lbann.util.data import Dataset, Sample, SampleDims
+from lbann.util.data import Dataset, Sample, SampleDims, construct_python_dataset_reader
 
 # Bamboo utilities
 current_file = os.path.realpath(__file__)
@@ -100,23 +100,25 @@ def construct_data_reader(lbann):
 
     """
 
+    dataset_path = os.path.join(work_dir, 'dataset.pkl')
+
     # Note: The training data reader should be removed when
     # https://github.com/LLNL/lbann/issues/1098 is resolved.
     message = lbann.reader_pb2.DataReader()
     message.reader.extend([
-        tools.create_python_dataset_reader(
-            lbann,
-            __file__,
+        construct_python_dataset_reader(
             test_dataset,
-            'train'
+            dataset_path,
+            'train',
+            shuffle=False
         )
     ])
     message.reader.extend([
-        tools.create_python_dataset_reader(
-            lbann,
-            __file__,
+        construct_python_dataset_reader(
             test_dataset,
-            'test'
+            dataset_path,
+            'test',
+            shuffle=False
         )
     ])
     return message
@@ -125,6 +127,11 @@ def construct_data_reader(lbann):
 # Setup PyTest
 # ==============================================
 
+work_dir = os.path.join(os.path.dirname(__file__),
+                        'experiments',
+                        os.path.basename(__file__).split('.py')[0])
+os.makedirs(work_dir, exist_ok=True)
+
 # Create test functions that can interact with PyTest
-for _test_func in tools.create_tests(setup_experiment, __file__):
+for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir):
     globals()[_test_func.__name__] = _test_func
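With the work_dir change, the pickle ends up beside the rest of this test's experiment output instead of in a temporary directory. A self-contained illustration of the same path arithmetic (the literal test_file value is simply this test's repository path, written out for clarity):

    import os

    # Mirrors the work_dir construction added above, with __file__ written out.
    test_file = 'ci_test/unit_tests/test_unit_datareader_python_dataset.py'
    work_dir = os.path.join(os.path.dirname(test_file),
                            'experiments',
                            os.path.basename(test_file).split('.py')[0])
    dataset_path = os.path.join(work_dir, 'dataset.pkl')
    print(dataset_path)
    # ci_test/unit_tests/experiments/test_unit_datareader_python_dataset/dataset.pkl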

Diff for: ci_test/unit_tests/test_unit_datareader_python_dataset_distconv.py (+17, -10 lines)

@@ -4,7 +4,7 @@
 import pytest
 import numpy as np
 import lbann.contrib.args
-from lbann.util.data import DistConvDataset, Sample, SampleDims
+from lbann.util.data import DistConvDataset, Sample, SampleDims, construct_python_dataset_reader
 
 # Bamboo utilities
 current_file = os.path.realpath(__file__)
@@ -142,23 +142,25 @@ def construct_data_reader(lbann):
 
     """
 
+    dataset_path = os.path.join(work_dir, 'dataset.pkl')
+
     # Note: The training data reader should be removed when
     # https://github.com/LLNL/lbann/issues/1098 is resolved.
     message = lbann.reader_pb2.DataReader()
     message.reader.extend([
-        tools.create_python_dataset_reader(
-            lbann,
-            __file__,
+        construct_python_dataset_reader(
             test_dataset,
-            'train'
+            dataset_path,
+            'train',
+            shuffle=False
        )
    ])
    message.reader.extend([
-        tools.create_python_dataset_reader(
-            lbann,
-            __file__,
+        construct_python_dataset_reader(
            test_dataset,
-            'test'
+            dataset_path,
+            'train',
+            shuffle=False
        )
    ])
    return message
@@ -167,8 +169,13 @@ def construct_data_reader(lbann):
 # Setup PyTest
 # ==============================================
 
+work_dir = os.path.join(os.path.dirname(__file__),
+                        'experiments',
+                        os.path.basename(__file__).split('.py')[0])
+os.makedirs(work_dir, exist_ok=True)
+
 # Create test functions that can interact with PyTest
-for _test_func in tools.create_tests(setup_experiment, __file__,
+for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir,
                                      environment=lbann.contrib.args.get_distconv_environment(
                                          num_io_partitions=tools.gpus_per_node(lbann)
                                      )):
