-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathlbann_inf.cpp
125 lines (108 loc) · 4.15 KB
/
lbann_inf.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
// Produced at the Lawrence Livermore National Laboratory.
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
// the CONTRIBUTORS file. <[email protected]>
//
// LLNL-CODE-697807.
// All rights reserved.
//
// This file is part of LBANN: Livermore Big Artificial Neural Network
// Toolkit. For details, see http://software.llnl.gov/LBANN or
// https://github.com/LLNL/LBANN.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you
// may not use this file except in compliance with the License. You may
// obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the license.
//
// lbann_inf.cpp - prototext-driven inference application
////////////////////////////////////////////////////////////////////////////////
#include "lbann/lbann.hpp"
#include "lbann/proto/proto_common.hpp"
#include "lbann/utils/argument_parser.hpp"
#include "lbann/utils/protobuf_utils.hpp"
#include "lbann/proto/lbann.pb.h"
#include "lbann/proto/model.pb.h"
#include <dirent.h>
#include <cstdlib>
using namespace lbann;
int main(int argc, char* argv[])
{
auto& arg_parser = global_argument_parser();
construct_all_options();
try {
arg_parser.parse(argc, argv);
}
catch (std::exception const& e) {
std::cerr << "Error during argument parsing:\n\ne.what():\n\n " << e.what()
<< "\n\nProcess terminating." << std::endl;
std::terminate();
}
auto comm = initialize(argc, argv);
const bool master = comm->am_world_master();
try {
// Split MPI into trainers
allocate_trainer_resources(comm.get());
if (arg_parser.help_requested() or argc == 1) {
if (master)
std::cout << arg_parser << std::endl;
return EXIT_SUCCESS;
}
std::ostringstream err;
auto pbs = protobuf_utils::load_prototext(master);
// Optionally over-ride some values in the prototext for each model
for (size_t i = 0; i < pbs.size(); i++) {
get_cmdline_overrides(*comm, *(pbs[i]));
}
lbann_data::LbannPB& pb = *(pbs[0]);
lbann_data::Trainer* pb_trainer = pb.mutable_trainer();
// Construct the trainer
auto& trainer = construct_trainer(comm.get(), pb_trainer, *(pbs[0]));
thread_pool& io_thread_pool = trainer.get_io_thread_pool();
auto* dr =
trainer.get_data_coordinator().get_data_reader(execution_mode::testing);
if (dr == nullptr) {
LBANN_ERROR("No testing data reader defined");
}
auto& ds =
trainer.get_data_coordinator().get_dataset(execution_mode::testing);
std::vector<std::unique_ptr<model>> models;
for (auto&& pb_model : pbs) {
models.emplace_back(
build_model_from_prototext(argc,
argv,
pb_trainer,
*pb_model,
comm.get(),
io_thread_pool,
trainer.get_callbacks_with_ownership()));
}
/// Interleave the inference between the models so that they can use a
/// shared data reader Enable shared testing data readers on the command
/// line via --share_testing_data_readers=1
El::Int num_samples = ds.get_num_iterations_per_epoch();
if (num_samples == 0) {
LBANN_ERROR("The testing data reader does not have any samples");
}
for (El::Int s = 0; s < num_samples; s++) {
for (auto&& m : models) {
trainer.evaluate(m.get(), execution_mode::testing, 1);
}
}
}
catch (std::exception& e) {
El::ReportException(e);
// It's possible that a proper subset of ranks throw some
// exception. But we want to tear down the whole world.
El::mpi::Abort(El::mpi::COMM_WORLD, EXIT_FAILURE);
}
return EXIT_SUCCESS;
}