Skip to content

Commit

Permalink
Save dataset pickle in work dir
Browse files Browse the repository at this point in the history
  • Loading branch information
fiedorowicz1 committed Feb 27, 2024
1 parent c33a569 commit 861931c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 58 deletions.
38 changes: 0 additions & 38 deletions ci_test/common_python/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,44 +889,6 @@ def create_python_data_reader(lbann,

return reader

def create_python_dataset_reader(lbann,
                                 file_name,
                                 dataset,
                                 execution_mode,
                                 dataset_path=None):
    """Create protobuf message for a Python dataset reader.

    A Python dataset reader gets data by loading a pickled
    lbann.util.data.Dataset object. The dataset is pickled to
    ``dataset_path`` as a side effect of this call.

    Args:
        lbann (module): Module for LBANN Python frontend.
        file_name (str): Python file with dataset class definition;
            its directory is recorded so the reader can import the
            class when unpickling.
        dataset (lbann.util.data.Dataset): Dataset object to be pickled.
        execution_mode (str): 'train', 'validation', or 'test'.
        dataset_path (str, optional): Path where the pickled dataset
            is written. Defaults to ``$TMPDIR/dataset_<execution_mode>.pkl``.

    Returns:
        lbann.reader_pb2.Reader: Configured data reader message.
    """

    # Pickle the dataset so the reader can load it at runtime.
    if dataset_path is None:
        dataset_path = os.path.join(os.environ['TMPDIR'],
                                    f'dataset_{execution_mode}.pkl')
    import pickle
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)

    # Construct protobuf message for data reader
    reader = lbann.reader_pb2.Reader()
    reader.name = 'python_dataset'
    reader.role = execution_mode
    reader.shuffle = False
    reader.fraction_of_data_to_use = 1.0
    reader.python_dataset.dataset_path = dataset_path
    reader.python_dataset.module_dir = os.path.dirname(os.path.abspath(file_name))

    return reader


def numpy_l2norm2(x):
"""Square of L2 norm, computed with NumPy
Expand Down
27 changes: 17 additions & 10 deletions ci_test/unit_tests/test_unit_datareader_python_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path
import sys
import numpy as np
from lbann.util.data import Dataset, Sample, SampleDims
from lbann.util.data import Dataset, Sample, SampleDims, construct_python_dataset_reader

# Bamboo utilities
current_file = os.path.realpath(__file__)
Expand Down Expand Up @@ -100,23 +100,25 @@ def construct_data_reader(lbann):
"""

dataset_path = os.path.join(work_dir, 'dataset.pkl')

# Note: The training data reader should be removed when
# https://github.com/LLNL/lbann/issues/1098 is resolved.
message = lbann.reader_pb2.DataReader()
message.reader.extend([
tools.create_python_dataset_reader(
lbann,
__file__,
construct_python_dataset_reader(
test_dataset,
'train'
dataset_path,
'train',
shuffle=False
)
])
message.reader.extend([
tools.create_python_dataset_reader(
lbann,
__file__,
construct_python_dataset_reader(
test_dataset,
'test'
dataset_path,
'test',
shuffle=False
)
])
return message
Expand All @@ -125,6 +127,11 @@ def construct_data_reader(lbann):
# Setup PyTest
# ==============================================

work_dir = os.path.join(os.path.dirname(__file__),
'experiments',
os.path.basename(__file__).split('.py')[0])
os.makedirs(work_dir, exist_ok=True)

# Create test functions that can interact with PyTest
for _test_func in tools.create_tests(setup_experiment, __file__):
for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir):
globals()[_test_func.__name__] = _test_func
27 changes: 17 additions & 10 deletions ci_test/unit_tests/test_unit_datareader_python_dataset_distconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import numpy as np
import lbann.contrib.args
from lbann.util.data import DistConvDataset, Sample, SampleDims
from lbann.util.data import DistConvDataset, Sample, SampleDims, construct_python_dataset_reader

# Bamboo utilities
current_file = os.path.realpath(__file__)
Expand Down Expand Up @@ -142,23 +142,25 @@ def construct_data_reader(lbann):
"""

dataset_path = os.path.join(work_dir, 'dataset.pkl')

# Note: The training data reader should be removed when
# https://github.com/LLNL/lbann/issues/1098 is resolved.
message = lbann.reader_pb2.DataReader()
message.reader.extend([
tools.create_python_dataset_reader(
lbann,
__file__,
construct_python_dataset_reader(
test_dataset,
'train'
dataset_path,
'train',
shuffle=False
)
])
message.reader.extend([
tools.create_python_dataset_reader(
lbann,
__file__,
construct_python_dataset_reader(
test_dataset,
'test'
dataset_path,
'test',
shuffle=False
)
])
return message
Expand All @@ -167,8 +169,13 @@ def construct_data_reader(lbann):
# Setup PyTest
# ==============================================

work_dir = os.path.join(os.path.dirname(__file__),
'experiments',
os.path.basename(__file__).split('.py')[0])
os.makedirs(work_dir, exist_ok=True)

# Create test functions that can interact with PyTest
for _test_func in tools.create_tests(setup_experiment, __file__,
for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir,
environment=lbann.contrib.args.get_distconv_environment(
num_io_partitions=tools.gpus_per_node(lbann)
)):
Expand Down

0 comments on commit 861931c

Please sign in to comment.