29
29
#include < mpi.h>
30
30
#include < stdio.h>
31
31
32
+ // Add test-specific options
32
33
void construct_opts (int argc, char **argv) {
33
34
auto & arg_parser = lbann::global_argument_parser ();
35
+ lbann::construct_std_options ();
36
+ lbann::construct_datastore_options ();
34
37
arg_parser.add_option (" samples" ,
35
38
{" -n" },
36
39
" Number of samples to run inference on" ,
@@ -52,20 +55,76 @@ void construct_opts(int argc, char **argv) {
52
55
" Number of labels in dataset" ,
53
56
10 );
54
57
arg_parser.add_option (" minibatchsize" ,
55
- {" -mbs" },
58
+ {" -- mbs" },
56
59
" Number of samples in a mini-batch" ,
57
60
16 );
61
+ arg_parser.add_flag (" use_conduit" ,
62
+ {" --conduit" },
63
+ " Use Conduit node samples (Default is non-distributed matrix)" );
64
+ arg_parser.add_flag (" use_dist_matrix" ,
65
+ {" --dist" },
66
+ " Use Hydrogen distributed matrix (Default is non-distributed matrix)" );
58
67
arg_parser.add_required_argument <std::string>
59
68
(" model" ,
60
69
" Directory containing checkpointed model" );
61
70
arg_parser.parse (argc, argv);
62
71
}
63
72
64
- El::DistMatrix<float , El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
65
- random_samples (El::Grid const & g, int n, int c, int h, int w) {
73
+ // Generates random samples and labels for mnist data in Hydrogen matrix
74
+ std::map<
75
+ std::string,
76
+ El::Matrix<float , El::Device::CPU>>
77
+ mat_mnist_samples (int n, int c, int h, int w)
78
+ {
79
+ El::Matrix<float , El::Device::CPU>
80
+ samples (c * h * w, n);
81
+ El::MakeUniform (samples);
82
+ El::Matrix<float , El::Device::CPU>
83
+ labels (1 , n);
84
+ El::MakeUniform (labels);
85
+ std::map<
86
+ std::string,
87
+ El::Matrix<float , El::Device::CPU>>
88
+ samples_map = {{" data/samples" , samples}, {" data/labels" , labels}};
89
+ return samples_map;
90
+ }
91
+
92
+ // Generates random samples and labels for mnist data in Hydrogen distributed matrix
93
+ std::map<
94
+ std::string,
95
+ El::DistMatrix<float , El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>>
96
+ distmat_mnist_samples (El::Grid const & g, int n, int c, int h, int w)
97
+ {
66
98
El::DistMatrix<float , El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
67
- samples (n, c * h * w, g);
99
+ samples (c * h * w, n , g);
68
100
El::MakeUniform (samples);
101
+ El::DistMatrix<float , El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
102
+ labels (1 , n, g);
103
+ El::MakeUniform (labels);
104
+ std::map<
105
+ std::string,
106
+ El::DistMatrix<float , El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>>
107
+ samples_map = {{" data/samples" , samples}, {" data/labels" , labels}};
108
+ return samples_map;
109
+ }
110
+
111
// Fills the first `size` entries of `arr` with pseudo-random values.
//
// Each entry is (std::rand() % max_val) / max_val, so it lies between 0 and
// (max_val - 1) / max_val.
//
// arr     - destination buffer; must hold at least `size` floats
// size    - number of entries to write; nothing is written when size <= 0
// max_val - modulus and divisor; must be positive, since a zero modulus is
//           undefined behavior
void random_fill(float *arr, int size, int max_val = 255) {
  // The counter must start at 0 explicitly; a default-initialized int holds
  // an indeterminate value and reading it is undefined behavior.
  for (int i = 0; i < size; ++i) {
    arr[i] = static_cast<float>(std::rand() % max_val) /
             static_cast<float>(max_val);
  }
}
117
+
118
+ // Generates random samples and labels for mnist data in vector of Conduit nodes
119
+ std::vector<conduit::Node> conduit_mnist_samples (int n, int c, int h, int w) {
120
+ std::vector<conduit::Node> samples (n);
121
+ int sample_size = c * h * w;
122
+ float this_sample[sample_size];
123
+ for (int i; i<n; i++) {
124
+ random_fill (this_sample, sample_size);
125
+ samples[i][" data/samples" ].set (this_sample, sample_size);
126
+ samples[i][" data/labels" ] = std::rand () % 10 ;
127
+ }
69
128
return samples;
70
129
}
71
130
@@ -79,10 +138,13 @@ int main(int argc, char **argv) {
79
138
int rank;
80
139
MPI_Comm_rank (MPI_COMM_WORLD, &rank);
81
140
82
- // Get input arguments and print values
141
+ // Get input arguments, check and print values
83
142
construct_opts (argc, argv);
84
143
auto & arg_parser = lbann::global_argument_parser ();
85
144
if (rank == 0 ) {
145
+ if (arg_parser.get <bool >(" use_conduit" ) && arg_parser.get <bool >(" use_dist_matrix" )) {
146
+ LBANN_ERROR (" Cannot use conduit node and distributed matrix together, choose one: --conduit --dist" );
147
+ }
86
148
std::stringstream msg;
87
149
msg << " Model: " << arg_parser.get <std::string>(" model" ) << std::endl;
88
150
msg << " { N, c, h, w } = { " << arg_parser.get <int >(" samples" ) << " , " ;
@@ -94,8 +156,8 @@ int main(int argc, char **argv) {
94
156
std::cout << msg.str ();
95
157
}
96
158
97
- // Load model and run inference on samples
98
159
auto lbann_comm = lbann::initialize_lbann (MPI_COMM_WORLD);
160
+
99
161
auto m = lbann::load_inference_model (lbann_comm.get (),
100
162
arg_parser.get <std::string>(" model" ),
101
163
arg_parser.get <int >(" minibatchsize" ),
@@ -105,14 +167,31 @@ int main(int argc, char **argv) {
105
167
arg_parser.get <int >(" width" )
106
168
},
107
169
{arg_parser.get <int >(" labels" )});
108
- auto samples = random_samples (lbann_comm->get_trainer_grid (),
109
- arg_parser.get <int >(" samples" ),
110
- arg_parser.get <int >(" channels" ),
111
- arg_parser.get <int >(" height" ),
112
- arg_parser.get <int >(" width" ));
113
- auto labels = lbann::infer (m.get (),
114
- samples,
115
- arg_parser.get <int >(" minibatchsize" ));
170
+
171
+ // three options for data generation
172
+ if (arg_parser.get <bool >(" use_conduit" )) {
173
+ auto samples = conduit_mnist_samples (arg_parser.get <int >(" samples" ),
174
+ arg_parser.get <int >(" channels" ),
175
+ arg_parser.get <int >(" height" ),
176
+ arg_parser.get <int >(" width" ));
177
+ lbann::set_inference_samples (samples);
178
+ } else if (arg_parser.get <bool >(" use_dist_matrix" )) {
179
+ auto samples = distmat_mnist_samples (lbann_comm->get_trainer_grid (),
180
+ arg_parser.get <int >(" samples" ),
181
+ arg_parser.get <int >(" channels" ),
182
+ arg_parser.get <int >(" height" ),
183
+ arg_parser.get <int >(" width" ));
184
+ lbann::set_inference_samples (samples);
185
+ } else {
186
+ auto samples = mat_mnist_samples (
187
+ arg_parser.get <int >(" samples" ),
188
+ arg_parser.get <int >(" channels" ),
189
+ arg_parser.get <int >(" height" ),
190
+ arg_parser.get <int >(" width" ));
191
+ lbann::set_inference_samples (samples);
192
+ }
193
+
194
+ auto labels = lbann::inference (m.get ());
116
195
117
196
// Print inference results
118
197
if (lbann_comm->am_world_master ()) {
0 commit comments