@@ -15,23 +15,27 @@
  * limitations under the License.
  */
 
-#include "config.h"
-#include <iostream>
+#include <gflags/gflags.h>
 #include <nixl.h>
 #include <sys/time.h>
-#include <gflags/gflags.h>
-#include "utils/utils.h"
+
+#include <iostream>
+
+#include "config.h"
 #include "utils/scope_guard.h"
+#include "utils/utils.h"
 #include "worker/nixl/nixl_worker.h"
 #if HAVE_NVSHMEM && HAVE_CUDA
 #include "worker/nvshmem/nvshmem_worker.h"
 #endif
 #include <unistd.h>
-#include <memory>
+
 #include <csignal>
+#include <memory>
 
-static std::pair<size_t, size_t> getStrideScheme(xferBenchWorker &worker, int num_threads) {
-    int initiator_device, target_device;
+static std::pair<size_t, size_t> getStrideScheme(xferBenchWorker &worker, int num_threads)
+{
+    int initiator_device, target_device;
     size_t buffer_size, count, stride;
 
     initiator_device = xferBenchConfig::num_initiator_dev;
@@ -44,22 +48,14 @@ static std::pair<size_t, size_t> getStrideScheme(xferBenchWorker &worker, int nu
     // TODO: add macro for schemes
     // Maybe, we can squeze ONE_TO_MANY and MANY_TO_ONE into TP scheme
     if (XFERBENCH_SCHEME_ONE_TO_MANY == xferBenchConfig::scheme) {
-        if (worker.isInitiator()) {
-            count = target_device;
-        }
+        if (worker.isInitiator()) { count = target_device; }
     } else if (XFERBENCH_SCHEME_MANY_TO_ONE == xferBenchConfig::scheme) {
-        if (worker.isTarget()) {
-            count = initiator_device;
-        }
+        if (worker.isTarget()) { count = initiator_device; }
     } else if (XFERBENCH_SCHEME_TP == xferBenchConfig::scheme) {
         if (worker.isInitiator()) {
-            if (initiator_device < target_device) {
-                count = target_device / initiator_device;
-            }
+            if (initiator_device < target_device) { count = target_device / initiator_device; }
         } else if (worker.isTarget()) {
-            if (target_device < initiator_device) {
-                count = initiator_device / target_device;
-            }
+            if (target_device < initiator_device) { count = initiator_device / target_device; }
         }
     }
     stride = buffer_size / count;
@@ -68,14 +64,13 @@ static std::pair<size_t, size_t> getStrideScheme(xferBenchWorker &worker, int nu
 }
 
 static std::vector<std::vector<xferBenchIOV>> createTransferDescLists(xferBenchWorker &worker,
-                                                                      std::vector<std::vector<xferBenchIOV>> &iov_lists,
-                                                                      size_t block_size,
-                                                                      size_t batch_size,
-                                                                      int num_threads) {
+    std::vector<std::vector<xferBenchIOV>> &iov_lists, size_t block_size, size_t batch_size,
+    int num_threads)
+{
     auto [count, stride] = getStrideScheme(worker, num_threads);
     std::vector<std::vector<xferBenchIOV>> xfer_lists;
 
-    for (const auto &iov_list: iov_lists) {
+    for (const auto &iov_list : iov_lists) {
         std::vector<xferBenchIOV> xfer_list;
 
         for (const auto &iov : iov_list) {
@@ -84,9 +79,8 @@ static std::vector<std::vector<xferBenchIOV>> createTransferDescLists(xferBenchW
 
                 for (size_t j = 0; j < batch_size; j++) {
                     size_t block_offset = ((j * block_size) % iov.len);
-                    xfer_list.push_back(xferBenchIOV((iov.addr + dev_offset) + block_offset,
-                                                     block_size,
-                                                     iov.devId));
+                    xfer_list.push_back(xferBenchIOV(
+                        (iov.addr + dev_offset) + block_offset, block_size, iov.devId));
                 }
             }
         }
@@ -97,24 +91,20 @@ static std::vector<std::vector<xferBenchIOV>> createTransferDescLists(xferBenchW
     return xfer_lists;
 }
 
-static int processBatchSizes(xferBenchWorker &worker,
-                             std::vector<std::vector<xferBenchIOV>> &iov_lists,
-                             size_t block_size, int num_threads) {
+static int processBatchSizes(xferBenchWorker &worker,
+    std::vector<std::vector<xferBenchIOV>> &iov_lists, size_t block_size, int num_threads)
+{
     for (size_t batch_size = xferBenchConfig::start_batch_size;
-         !worker.signaled() &&
-         batch_size <= xferBenchConfig::max_batch_size;
-         batch_size *= 2) {
-        auto local_trans_lists = createTransferDescLists(worker,
-                                                         iov_lists,
-                                                         block_size,
-                                                         batch_size,
-                                                         num_threads);
+         !worker.signaled() && batch_size <= xferBenchConfig::max_batch_size; batch_size *= 2) {
+        auto local_trans_lists =
+            createTransferDescLists(worker, iov_lists, block_size, batch_size, num_threads);
 
         if (worker.isTarget()) {
             worker.exchangeIOV(local_trans_lists);
             worker.poll(block_size);
 
-            if (xferBenchConfig::check_consistency && xferBenchConfig::op_type == XFERBENCH_OP_WRITE) {
+            if (xferBenchConfig::check_consistency &&
+                xferBenchConfig::op_type == XFERBENCH_OP_WRITE) {
                 xferBenchUtils::checkConsistency(local_trans_lists);
             }
             if (IS_PAIRWISE_AND_SG()) {
@@ -123,28 +113,26 @@ static int processBatchSizes(xferBenchWorker &worker,
                 xferBenchUtils::printStats(true, block_size, batch_size, 0);
             }
         } else if (worker.isInitiator()) {
-            std::vector<std::vector<xferBenchIOV>> remote_trans_lists(worker.exchangeIOV(local_trans_lists));
+            std::vector<std::vector<xferBenchIOV>> remote_trans_lists(
+                worker.exchangeIOV(local_trans_lists));
 
-            auto result = worker.transfer(block_size,
-                                          local_trans_lists,
-                                          remote_trans_lists);
-            if (std::holds_alternative<int>(result)) {
-                return 1;
-            }
+            auto result = worker.transfer(block_size, local_trans_lists, remote_trans_lists);
+            if (std::holds_alternative<int>(result)) { return 1; }
 
-            if (xferBenchConfig::check_consistency && xferBenchConfig::op_type == XFERBENCH_OP_READ) {
+            if (xferBenchConfig::check_consistency &&
+                xferBenchConfig::op_type == XFERBENCH_OP_READ) {
                 xferBenchUtils::checkConsistency(local_trans_lists);
             }
 
-            xferBenchUtils::printStats(false, block_size, batch_size,
-                                       std::get<double>(result));
+            xferBenchUtils::printStats(false, block_size, batch_size, std::get<double>(result));
         }
     }
 
     return 0;
 }
 
-static std::unique_ptr<xferBenchWorker> createWorker(int *argc, char ***argv) {
+static std::unique_ptr<xferBenchWorker> createWorker(int *argc, char ***argv)
+{
     if (xferBenchConfig::worker_type == "nixl") {
         std::vector<std::string> devices = xferBenchConfig::parseDeviceList();
         if (devices.empty()) {
@@ -165,21 +153,18 @@ static std::unique_ptr<xferBenchWorker> createWorker(int *argc, char ***argv) {
     }
 }
 
-int main(int argc, char *argv[]) {
+int main(int argc, char *argv[])
+{
     gflags::ParseCommandLineFlags(&argc, &argv, true);
 
     int ret = xferBenchConfig::loadFromFlags();
-    if (0 != ret) {
-        return EXIT_FAILURE;
-    }
+    if (0 != ret) { return EXIT_FAILURE; }
 
     int num_threads = xferBenchConfig::num_threads;
 
     // Create the appropriate worker based on worker configuration
     std::unique_ptr<xferBenchWorker> worker_ptr = createWorker(&argc, &argv);
-    if (!worker_ptr) {
-        return EXIT_FAILURE;
-    }
+    if (!worker_ptr) { return EXIT_FAILURE; }
 
     std::signal(SIGINT, worker_ptr->signalHandler);
 
@@ -191,34 +176,25 @@ int main(int argc, char *argv[]) {
     }
 
     std::vector<std::vector<xferBenchIOV>> iov_lists = worker_ptr->allocateMemory(num_threads);
-    auto mem_guard = make_scope_guard([&] {
-        worker_ptr->deallocateMemory(iov_lists);
-    });
+    auto mem_guard = make_scope_guard([&] { worker_ptr->deallocateMemory(iov_lists); });
 
     ret = worker_ptr->exchangeMetadata();
-    if (0 != ret) {
-        return EXIT_FAILURE;
-    }
+    if (0 != ret) { return EXIT_FAILURE; }
 
     if (worker_ptr->isInitiator() && worker_ptr->isMasterRank()) {
         xferBenchConfig::printConfig();
         xferBenchUtils::printStatsHeader();
     }
 
     for (size_t block_size = xferBenchConfig::start_block_size;
-         !worker_ptr->signaled() &&
-         block_size <= xferBenchConfig::max_block_size;
-         block_size *= 2) {
+         !worker_ptr->signaled() && block_size <= xferBenchConfig::max_block_size;
+         block_size *= 2) {
         ret = processBatchSizes(*worker_ptr, iov_lists, block_size, num_threads);
-        if (0 != ret) {
-            return EXIT_FAILURE;
-        }
+        if (0 != ret) { return EXIT_FAILURE; }
     }
 
-    ret = worker_ptr->synchronize(); // Make sure environment is not used anymore
-    if (0 != ret) {
-        return EXIT_FAILURE;
-    }
+    ret = worker_ptr->synchronize(); // Make sure environment is not used anymore
+    if (0 != ret) { return EXIT_FAILURE; }
 
     gflags::ShutDownCommandLineFlags();
 
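For context when reading the reformatted loops: the benchmark sweeps a two-level geometric grid, with main() doubling block_size and processBatchSizes() doubling batch_size between the configured start and max bounds. A minimal standalone sketch of that sweep pattern follows; the bounds are illustrative placeholders, not nixlbench's actual flag defaults.

#include <cstddef>
#include <cstdio>

int main() {
    // Placeholder bounds; nixlbench reads these from gflags-defined flags.
    const size_t start_block_size = 4096, max_block_size = 1 << 20;
    const size_t start_batch_size = 1, max_batch_size = 64;

    for (size_t block_size = start_block_size; block_size <= max_block_size; block_size *= 2) {
        for (size_t batch_size = start_batch_size; batch_size <= max_batch_size; batch_size *= 2) {
            // One benchmark configuration per (block_size, batch_size) pair.
            std::printf("block=%zu batch=%zu\n", block_size, batch_size);
        }
    }
    return 0;
}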