Skip to content

Commit 008e912

Browse files
mrwyattiibvanessen
authored andcommitted
added template for data reader to pass conduit node from driver
added conduit to cmakelist fixed error with global_trainer_ added simple conduit datareader to hold conduit node prototyping use of data_store to hold conduit nodes fixing bug with input buffers not being sized correctly fixed problem with unpacking conduit node moving {trainer, dc, dr, ds setup} and {loading inference samples} to separate functions extended core API for many different input types removed old code from first lbann-core impl added simple run script Fix things that have drifted in LBANN Get core-drive compiling again clang-format batch_functional_inference_algorithm Steps toward debugging the segfault in the inference algo test The test no longer segfaults. Now it just fails. Don't shuffle when setting up for inference Fix a spacing issue Updated CMake to install the core driver Build the core-driver
1 parent cae84af commit 008e912

File tree

19 files changed

+607
-218
lines changed

19 files changed

+607
-218
lines changed

Diff for: CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,7 @@ add_subdirectory(applications/CANDLE/pilot2/tools)
920920
add_subdirectory(applications/ATOM/utils)
921921
add_subdirectory(tests)
922922
add_subdirectory(scripts)
923+
add_subdirectory(core-driver)
923924

924925
################################################################
925926
# Install LBANN

Diff for: cmake/configure_files/LBANNConfig.cmake.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ set(LBANN_HAS_DIHYDROGEN @LBANN_HAS_DIHYDROGEN@)
7474
set(LBANN_HAS_DISTCONV @LBANN_HAS_DISTCONV@)
7575
set(LBANN_HAS_DOXYGEN @LBANN_HAS_DOXYGEN@)
7676
set(LBANN_HAS_EMBEDDED_PYTHON @LBANN_HAS_EMBEDDED_PYTHON@)
77-
set(LBANN_HAS_FFTW @LBANN_HAS_FFTW@
77+
set(LBANN_HAS_FFTW @LBANN_HAS_FFTW@)
7878
set(LBANN_HAS_FFTW_FLOAT @LBANN_HAS_FFTW_FLOAT@)
7979
set(LBANN_HAS_FFTW_DOUBLE @LBANN_HAS_FFTW_DOUBLE@)
8080
set(LBANN_HAS_GPU_FP16 @LBANN_HAS_GPU_FP16@)

Diff for: core-driver/CMakeLists.txt

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
1-
cmake_minimum_required(VERSION 3.18.0)
2-
project(my_lbann_test C CXX)
1+
cmake_minimum_required(VERSION 3.21.0)
2+
project(my_lbann_test CXX)
33
find_package(LBANN 0.102.0 REQUIRED)
4-
add_executable(Main main.cpp)
5-
target_link_libraries(Main PRIVATE LBANN::lbann)
4+
find_package(Conduit CONFIG REQUIRED)
5+
add_executable(lbann-core main.cpp)
6+
target_link_libraries(lbann-core PRIVATE LBANN::lbann)
7+
8+
#target_link_libraries(lbann-bin lbann)
9+
set_target_properties(lbann-core
10+
PROPERTIES
11+
OUTPUT_NAME lbann-core-driver
12+
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
13+
14+
#list(APPEND LBANN_EXE_TGTS lbann-core)
15+
16+
install(TARGETS lbann-core
17+
EXPORT LBANNTargets
18+
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})

Diff for: core-driver/main.cpp

+93-14
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@
2929
#include <mpi.h>
3030
#include <stdio.h>
3131

32+
// Add test-specific options
3233
void construct_opts(int argc, char **argv) {
3334
auto& arg_parser = lbann::global_argument_parser();
35+
lbann::construct_std_options();
36+
lbann::construct_datastore_options();
3437
arg_parser.add_option("samples",
3538
{"-n"},
3639
"Number of samples to run inference on",
@@ -52,20 +55,76 @@ void construct_opts(int argc, char **argv) {
5255
"Number of labels in dataset",
5356
10);
5457
arg_parser.add_option("minibatchsize",
55-
{"-mbs"},
58+
{"--mbs"},
5659
"Number of samples in a mini-batch",
5760
16);
61+
arg_parser.add_flag("use_conduit",
62+
{"--conduit"},
63+
"Use Conduit node samples (Default is non-distributed matrix)");
64+
arg_parser.add_flag("use_dist_matrix",
65+
{"--dist"},
66+
"Use Hydrogen distributed matrix (Default is non-distributed matrix)");
5867
arg_parser.add_required_argument<std::string>
5968
("model",
6069
"Directory containing checkpointed model");
6170
arg_parser.parse(argc, argv);
6271
}
6372

64-
El::DistMatrix<float, El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
65-
random_samples(El::Grid const& g, int n, int c, int h, int w) {
73+
// Generates random samples and labels for mnist data in Hydrogen matrix
74+
std::map<
75+
std::string,
76+
El::Matrix<float, El::Device::CPU>>
77+
mat_mnist_samples(int n, int c, int h, int w)
78+
{
79+
El::Matrix<float, El::Device::CPU>
80+
samples(c * h * w, n);
81+
El::MakeUniform(samples);
82+
El::Matrix<float, El::Device::CPU>
83+
labels(1, n);
84+
El::MakeUniform(labels);
85+
std::map<
86+
std::string,
87+
El::Matrix<float, El::Device::CPU>>
88+
samples_map = {{"data/samples", samples}, {"data/labels", labels}};
89+
return samples_map;
90+
}
91+
92+
// Generates random samples and labels for mnist data in Hydrogen distributed matrix
93+
std::map<
94+
std::string,
95+
El::DistMatrix<float, El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>>
96+
distmat_mnist_samples(El::Grid const& g, int n, int c, int h, int w)
97+
{
6698
El::DistMatrix<float, El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
67-
samples(n, c * h * w, g);
99+
samples(c * h * w, n, g);
68100
El::MakeUniform(samples);
101+
El::DistMatrix<float, El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>
102+
labels(1, n, g);
103+
El::MakeUniform(labels);
104+
std::map<
105+
std::string,
106+
El::DistMatrix<float, El::STAR, El::STAR, El::ELEMENT, El::Device::CPU>>
107+
samples_map = {{"data/samples", samples}, {"data/labels", labels}};
108+
return samples_map;
109+
}
110+
111+
// Fills array with random values
112+
void random_fill(float *arr, int size, int max_val=255) {
113+
for (int i; i < size; i++) {
114+
arr[i] = (float)(std::rand() % max_val) / (float)max_val;
115+
}
116+
}
117+
118+
// Generates random samples and labels for mnist data in vector of Conduit nodes
119+
std::vector<conduit::Node> conduit_mnist_samples(int n, int c, int h, int w) {
120+
std::vector<conduit::Node> samples(n);
121+
int sample_size = c * h * w;
122+
float this_sample[sample_size];
123+
for (int i; i<n; i++) {
124+
random_fill(this_sample, sample_size);
125+
samples[i]["data/samples"].set(this_sample, sample_size);
126+
samples[i]["data/labels"] = std::rand() % 10;
127+
}
69128
return samples;
70129
}
71130

@@ -79,10 +138,13 @@ int main(int argc, char **argv) {
79138
int rank;
80139
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
81140

82-
// Get input arguments and print values
141+
// Get input arguments, check and print values
83142
construct_opts(argc, argv);
84143
auto& arg_parser = lbann::global_argument_parser();
85144
if (rank == 0) {
145+
if (arg_parser.get<bool>("use_conduit") && arg_parser.get<bool>("use_dist_matrix")) {
146+
LBANN_ERROR("Cannot use conduit node and distributed matrix together, choose one: --conduit --dist");
147+
}
86148
std::stringstream msg;
87149
msg << "Model: " << arg_parser.get<std::string>("model") << std::endl;
88150
msg << "{ N, c, h, w } = { " << arg_parser.get<int>("samples") << ", ";
@@ -94,8 +156,8 @@ int main(int argc, char **argv) {
94156
std::cout << msg.str();
95157
}
96158

97-
// Load model and run inference on samples
98159
auto lbann_comm = lbann::initialize_lbann(MPI_COMM_WORLD);
160+
99161
auto m = lbann::load_inference_model(lbann_comm.get(),
100162
arg_parser.get<std::string>("model"),
101163
arg_parser.get<int>("minibatchsize"),
@@ -105,14 +167,31 @@ int main(int argc, char **argv) {
105167
arg_parser.get<int>("width")
106168
},
107169
{arg_parser.get<int>("labels")});
108-
auto samples = random_samples(lbann_comm->get_trainer_grid(),
109-
arg_parser.get<int>("samples"),
110-
arg_parser.get<int>("channels"),
111-
arg_parser.get<int>("height"),
112-
arg_parser.get<int>("width"));
113-
auto labels = lbann::infer(m.get(),
114-
samples,
115-
arg_parser.get<int>("minibatchsize"));
170+
171+
// three options for data generation
172+
if (arg_parser.get<bool>("use_conduit")) {
173+
auto samples = conduit_mnist_samples(arg_parser.get<int>("samples"),
174+
arg_parser.get<int>("channels"),
175+
arg_parser.get<int>("height"),
176+
arg_parser.get<int>("width"));
177+
lbann::set_inference_samples(samples);
178+
} else if (arg_parser.get<bool>("use_dist_matrix")) {
179+
auto samples = distmat_mnist_samples(lbann_comm->get_trainer_grid(),
180+
arg_parser.get<int>("samples"),
181+
arg_parser.get<int>("channels"),
182+
arg_parser.get<int>("height"),
183+
arg_parser.get<int>("width"));
184+
lbann::set_inference_samples(samples);
185+
} else {
186+
auto samples = mat_mnist_samples(
187+
arg_parser.get<int>("samples"),
188+
arg_parser.get<int>("channels"),
189+
arg_parser.get<int>("height"),
190+
arg_parser.get<int>("width"));
191+
lbann::set_inference_samples(samples);
192+
}
193+
194+
auto labels = lbann::inference(m.get());
116195

117196
// Print inference results
118197
if (lbann_comm->am_world_master()) {

Diff for: core-driver/run.sh

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
export AL_PROGRESS_RANKS_PER_NUMA_NODE=2
2+
export OMP_NUM_THREADS=8
3+
export MV2_USE_RDMA_CM=0
4+
5+
# This should be a checkpointed lenet model
6+
MODEL_LOC="path/to/checkpointed/model"
7+
8+
./Main $MODEL_LOC
9+
./Main $MODEL_LOC --dist
10+
./Main $MODEL_LOC --conduit

Diff for: include/lbann/data_ingestion/readers/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ set_full_path(THIS_DIR_HEADERS
2929
metadata.hpp
3030
# Data readers
3131
data_reader_cifar10.hpp
32+
data_reader_conduit.hpp
3233
data_reader_csv.hpp
3334
data_reader_image.hpp
3435
data_reader_HDF5.hpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// Copyright (c) 2014-2021, Lawrence Livermore National Security, LLC.
3+
// Produced at the Lawrence Livermore National Laboratory.
4+
// Written by the LBANN Research Team (B. Van Essen, et al.) listed in
5+
// the CONTRIBUTORS file. <[email protected]>
6+
//
7+
// LLNL-CODE-697807.
8+
// All rights reserved.
9+
//
10+
// This file is part of LBANN: Livermore Big Artificial Neural Network
11+
// Toolkit. For details, see http://software.llnl.gov/LBANN or
12+
// https://github.com/LLNL/LBANN.
13+
//
14+
// Licensed under the Apache License, Version 2.0 (the "Licensee"); you
15+
// may not use this file except in compliance with the License. You may
16+
// obtain a copy of the License at:
17+
//
18+
// http://www.apache.org/licenses/LICENSE-2.0
19+
//
20+
// Unless required by applicable law or agreed to in writing, software
21+
// distributed under the License is distributed on an "AS IS" BASIS,
22+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
23+
// implied. See the License for the specific language governing
24+
// permissions and limitations under the license.
25+
////////////////////////////////////////////////////////////////////////////////
26+
27+
#ifndef LBANN_DATA_READER_CONDUIT_HPP
28+
#define LBANN_DATA_READER_CONDUIT_HPP
29+
30+
#include "lbann/data_readers/data_reader.hpp"
31+
#include "lbann/data_store/data_store_conduit.hpp"
32+
33+
namespace lbann {
34+
/**
35+
* A generalized data reader for passed in conduit nodes.
36+
*/
37+
class conduit_data_reader : public generic_data_reader
38+
{
39+
public:
40+
conduit_data_reader* copy() const override { return new conduit_data_reader(*this); }
41+
bool has_conduit_output() override { return true; }
42+
void load() override;
43+
bool fetch_conduit_node(conduit::Node& sample, int data_id) override;
44+
45+
void set_data_dims(std::vector<int> dims);
46+
void set_label_dims(std::vector<int> dims);
47+
48+
std::string get_type() const override { return "conduit_data_reader"; }
49+
int get_linearized_data_size() const override {
50+
int data_size = 1;
51+
for(int i : m_data_dims) {
52+
data_size *= i;
53+
}
54+
return data_size;
55+
}
56+
int get_linearized_label_size() const override {
57+
int label_size = 1;
58+
for(int i : m_label_dims) {
59+
label_size *= i;
60+
}
61+
return label_size;
62+
}
63+
64+
protected:
65+
std::vector<int> m_data_dims;
66+
std::vector<int> m_label_dims;
67+
68+
}; // END: class conduit_data_reader
69+
70+
} // namespace lbann
71+
72+
#endif // LBANN_DATA_READER_CONDUIT_HPP

0 commit comments

Comments
 (0)