|
| 1 | +import os |
| 2 | +import os.path |
| 3 | +import sys |
| 4 | +import numpy as np |
| 5 | +from lbann.util.data import Dataset, Sample, SampleDims, construct_python_dataset_reader |
| 6 | + |
| 7 | +# Bamboo utilities |
| 8 | +current_file = os.path.realpath(__file__) |
| 9 | +current_dir = os.path.dirname(current_file) |
| 10 | +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) |
| 11 | +import tools |
| 12 | + |
| 13 | +# ============================================== |
| 14 | +# Objects for Python dataset data reader |
| 15 | +# ============================================== |
| 16 | +# Note: The Python dataset data reader loads the dataset constructed below. |
| 17 | + |
| 18 | +# Data |
| 19 | +class TestDataset(Dataset): |
| 20 | + def __init__(self): |
| 21 | + np.random.seed(20240109) |
| 22 | + self.num_samples = 29 |
| 23 | + self.sample_size = 7 |
| 24 | + self.samples = np.random.normal(size=(self.num_samples,self.sample_size)).astype(np.float32) |
| 25 | + |
| 26 | + def __len__(self): |
| 27 | + return self.num_samples |
| 28 | + |
| 29 | + def __getitem__(self, index): |
| 30 | + return Sample(sample=self.samples[index,:]) |
| 31 | + |
| 32 | + def get_sample_dims(self): |
| 33 | + return SampleDims(sample=[self.sample_size]) |
| 34 | + |
| 35 | +test_dataset = TestDataset() |
| 36 | + |
| 37 | +# ============================================== |
| 38 | +# Setup LBANN experiment |
| 39 | +# ============================================== |
| 40 | + |
| 41 | +def setup_experiment(lbann, weekly): |
| 42 | + """Construct LBANN experiment. |
| 43 | +
|
| 44 | + Args: |
| 45 | + lbann (module): Module for LBANN Python frontend |
| 46 | +
|
| 47 | + """ |
| 48 | + mini_batch_size = len(test_dataset) // 4 |
| 49 | + trainer = lbann.Trainer(mini_batch_size) |
| 50 | + model = construct_model(lbann) |
| 51 | + data_reader = construct_data_reader(lbann) |
| 52 | + optimizer = lbann.NoOptimizer() |
| 53 | + return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes |
| 54 | + |
| 55 | +def construct_model(lbann): |
| 56 | + """Construct LBANN model. |
| 57 | +
|
| 58 | + Args: |
| 59 | + lbann (module): Module for LBANN Python frontend |
| 60 | +
|
| 61 | + """ |
| 62 | + |
| 63 | + # Layer graph |
| 64 | + x = lbann.Input(data_field='samples') |
| 65 | + y = lbann.L2Norm2(x) |
| 66 | + layers = list(lbann.traverse_layer_graph(x)) |
| 67 | + metric = lbann.Metric(y, name='obj') |
| 68 | + callbacks = [] |
| 69 | + |
| 70 | + # Compute expected value with NumPy |
| 71 | + vals = [] |
| 72 | + for i in range(len(test_dataset)): |
| 73 | + x = test_dataset[i].sample.astype(np.float64) |
| 74 | + y = tools.numpy_l2norm2(x) |
| 75 | + vals.append(y) |
| 76 | + val = np.mean(vals) |
| 77 | + tol = 8 * val * np.finfo(np.float32).eps |
| 78 | + callbacks.append(lbann.CallbackCheckMetric( |
| 79 | + metric=metric.name, |
| 80 | + lower_bound=val-tol, |
| 81 | + upper_bound=val+tol, |
| 82 | + error_on_failure=True, |
| 83 | + execution_modes='test')) |
| 84 | + |
| 85 | + # Construct model |
| 86 | + num_epochs = 0 |
| 87 | + return lbann.Model(num_epochs, |
| 88 | + layers=layers, |
| 89 | + metrics=[metric], |
| 90 | + callbacks=callbacks) |
| 91 | + |
| 92 | +def construct_data_reader(lbann): |
| 93 | + """Construct Protobuf message for Python dataset data reader. |
| 94 | +
|
| 95 | + The Python data reader will import the current Python file to |
| 96 | + access the sample access functions. |
| 97 | +
|
| 98 | + Args: |
| 99 | + lbann (module): Module for LBANN Python frontend |
| 100 | +
|
| 101 | + """ |
| 102 | + |
| 103 | + dataset_path = os.path.join(work_dir, 'dataset.pkl') |
| 104 | + |
| 105 | + # Note: The training data reader should be removed when |
| 106 | + # https://github.com/LLNL/lbann/issues/1098 is resolved. |
| 107 | + message = lbann.reader_pb2.DataReader() |
| 108 | + message.reader.extend([ |
| 109 | + construct_python_dataset_reader( |
| 110 | + test_dataset, |
| 111 | + dataset_path, |
| 112 | + 'train', |
| 113 | + shuffle=False |
| 114 | + ) |
| 115 | + ]) |
| 116 | + message.reader.extend([ |
| 117 | + construct_python_dataset_reader( |
| 118 | + test_dataset, |
| 119 | + dataset_path, |
| 120 | + 'test', |
| 121 | + shuffle=False |
| 122 | + ) |
| 123 | + ]) |
| 124 | + return message |
| 125 | + |
| 126 | +# ============================================== |
| 127 | +# Setup PyTest |
| 128 | +# ============================================== |
| 129 | + |
| 130 | +work_dir = os.path.join(os.path.dirname(__file__), |
| 131 | + 'experiments', |
| 132 | + os.path.basename(__file__).split('.py')[0]) |
| 133 | +os.makedirs(work_dir, exist_ok=True) |
| 134 | + |
| 135 | +# Create test functions that can interact with PyTest |
| 136 | +for _test_func in tools.create_tests(setup_experiment, __file__, work_dir=work_dir): |
| 137 | + globals()[_test_func.__name__] = _test_func |
0 commit comments