@@ -60,18 +60,19 @@ using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
60
60
61
61
static constexpr int kTensorSize = 4 ;
62
62
63
- class LowerGatherTest : public MultiDeviceTest ,
64
- public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
63
+ class LowerGatherTest
64
+ : public MultiDeviceTest,
65
+ public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
65
66
66
67
TEST_P (LowerGatherTest, ) {
67
68
EnableOptionsGuard opt_guard;
68
69
const auto & [meshes, enable_host_ir_lowering] = GetParam ();
69
70
const auto & [in_mesh, out_mesh] = meshes;
70
-
71
+
71
72
if (enable_host_ir_lowering) {
72
73
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
73
74
}
74
-
75
+
75
76
SKIP_IF_NOT_ENOUGH_DEVICES (in_mesh, out_mesh);
76
77
77
78
auto fusion = std::make_unique<Fusion>();
@@ -107,16 +108,15 @@ INSTANTIATE_TEST_SUITE_P(
107
108
// Create product of InOutMesh configurations and HostIrLowering options
108
109
testing::Combine (
109
110
testing::ValuesIn (std::vector<InOutMesh>(
110
- {{{0 , 1 }, {0 }},
111
- {{0 , 1 }, {1 }},
112
- {{1 , 2 }, {0 , 2 }}})),
113
- testing::Values(false )), // TODO: testing::Bool() after implementing communication lowering
111
+ {{{0 , 1 }, {0 }}, {{0 , 1 }, {1 }}, {{1 , 2 }, {0 , 2 }}})),
112
+ testing::Values(false )), // TODO: testing::Bool() after implementing
113
+ // communication lowering
114
114
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
115
115
const auto & meshes = std::get<0 >(info.param );
116
116
const auto & in_mesh = meshes.first ;
117
117
const auto & out_mesh = meshes.second ;
118
118
const auto enable_hir = std::get<1 >(info.param );
119
-
119
+
120
120
std::stringstream ss;
121
121
ss << " InMesh" ;
122
122
for (auto id : in_mesh.vector ()) {
@@ -127,22 +127,23 @@ INSTANTIATE_TEST_SUITE_P(
127
127
ss << " _" << id;
128
128
}
129
129
ss << (enable_hir ? " _HirEnabled" : " _HirDisabled" );
130
-
130
+
131
131
return ss.str ();
132
132
});
133
133
134
- class LowerScatterTest : public MultiDeviceTest ,
135
- public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
134
+ class LowerScatterTest
135
+ : public MultiDeviceTest,
136
+ public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
136
137
137
138
TEST_P (LowerScatterTest, ) {
138
139
EnableOptionsGuard opt_guard;
139
140
const auto & [meshes, enable_host_ir_lowering] = GetParam ();
140
141
const auto & [in_mesh, out_mesh] = meshes;
141
-
142
+
142
143
if (enable_host_ir_lowering) {
143
144
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
144
145
}
145
-
146
+
146
147
SKIP_IF_NOT_ENOUGH_DEVICES (in_mesh, out_mesh);
147
148
148
149
auto fusion = std::make_unique<Fusion>();
@@ -180,13 +181,14 @@ INSTANTIATE_TEST_SUITE_P(
180
181
{{{0 }, {0 , 1 }}, //
181
182
{{1 }, {0 , 1 }}, //
182
183
{{0 , 2 }, {1 , 2 }}})),
183
- testing::Values(false )), // TODO: testing::Bool() after implementing communication lowering
184
+ testing::Values(false )), // TODO: testing::Bool() after implementing
185
+ // communication lowering
184
186
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
185
187
const auto & meshes = std::get<0 >(info.param );
186
188
const auto & in_mesh = meshes.first ;
187
189
const auto & out_mesh = meshes.second ;
188
190
const auto enable_hir = std::get<1 >(info.param );
189
-
191
+
190
192
std::stringstream ss;
191
193
ss << " InMesh" ;
192
194
for (auto id : in_mesh.vector ()) {
@@ -197,22 +199,23 @@ INSTANTIATE_TEST_SUITE_P(
197
199
ss << " _" << id;
198
200
}
199
201
ss << (enable_hir ? " _HirEnabled" : " _HirDisabled" );
200
-
202
+
201
203
return ss.str ();
202
204
});
203
205
204
- class LowerSendRecvTest : public MultiDeviceTest ,
205
- public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
206
+ class LowerSendRecvTest
207
+ : public MultiDeviceTest,
208
+ public testing::WithParamInterface<std::tuple<InOutMesh, bool >> {};
206
209
207
210
TEST_P (LowerSendRecvTest, ) {
208
211
EnableOptionsGuard opt_guard;
209
212
const auto & [meshes, enable_host_ir_lowering] = GetParam ();
210
213
const auto & [in_mesh, out_mesh] = meshes;
211
-
214
+
212
215
if (enable_host_ir_lowering) {
213
216
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
214
217
}
215
-
218
+
216
219
SKIP_IF_NOT_ENOUGH_DEVICES (in_mesh, out_mesh);
217
220
218
221
auto fusion = std::make_unique<Fusion>();
@@ -249,17 +252,15 @@ INSTANTIATE_TEST_SUITE_P(
249
252
LowerSendRecvTest,
250
253
testing::Combine (
251
254
testing::ValuesIn (std::vector<InOutMesh>(
252
- {{{0 }, {1 }},
253
- {{1 }, {0 }},
254
- {{1 , 2 }, {0 , 1 }},
255
- {{1 , 2 }, {1 , 0 }}})),
256
- testing::Values(false )), // TODO: testing::Bool() after implementing communication lowering
255
+ {{{0 }, {1 }}, {{1 }, {0 }}, {{1 , 2 }, {0 , 1 }}, {{1 , 2 }, {1 , 0 }}})),
256
+ testing::Values(false )), // TODO: testing::Bool() after implementing
257
+ // communication lowering
257
258
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
258
259
const auto & meshes = std::get<0 >(info.param );
259
260
const auto & in_mesh = meshes.first ;
260
261
const auto & out_mesh = meshes.second ;
261
262
const auto enable_hir = std::get<1 >(info.param );
262
-
263
+
263
264
std::stringstream ss;
264
265
ss << " InMesh" ;
265
266
for (auto id : in_mesh.vector ()) {
@@ -270,12 +271,12 @@ INSTANTIATE_TEST_SUITE_P(
270
271
ss << " _" << id;
271
272
}
272
273
ss << (enable_hir ? " _HirEnabled" : " _HirDisabled" );
273
-
274
+
274
275
return ss.str ();
275
276
});
276
277
277
278
class LowerCollectiveTest : public MultiDeviceTest ,
278
- public testing::WithParamInterface<bool > {};
279
+ public testing::WithParamInterface<bool > {};
279
280
280
281
TEST_P (LowerCollectiveTest, Allgather) {
281
282
EnableOptionsGuard opt_guard;
@@ -284,7 +285,7 @@ TEST_P(LowerCollectiveTest, Allgather) {
284
285
if (enable_host_ir_lowering) {
285
286
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
286
287
}
287
-
288
+
288
289
auto fusion = std::make_unique<Fusion>();
289
290
FusionGuard fg (fusion.get ());
290
291
@@ -319,7 +320,7 @@ TEST_P(LowerCollectiveTest, Allgather_LoopSplit) {
319
320
// Skip this test when HostIrLowering is enabled
320
321
GTEST_SKIP () << " Disabled for HostIrLowering enabled configuration" ;
321
322
}
322
-
323
+
323
324
auto fusion = std::make_unique<Fusion>();
324
325
FusionGuard fg (fusion.get ());
325
326
@@ -362,7 +363,7 @@ TEST_P(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) {
362
363
if (enable_host_ir_lowering) {
363
364
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
364
365
}
365
-
366
+
366
367
auto fusion = std::make_unique<Fusion>();
367
368
FusionGuard fg (fusion.get ());
368
369
@@ -402,7 +403,7 @@ TEST_P(LowerCollectiveTest, Broadcast) {
402
403
if (enable_host_ir_lowering) {
403
404
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
404
405
}
405
-
406
+
406
407
auto fusion = std::make_unique<Fusion>();
407
408
FusionGuard fg (fusion.get ());
408
409
@@ -425,7 +426,7 @@ TEST_P(LowerCollectiveTest, Broadcast) {
425
426
FusionExecutorCache executor_cache (std::move (fusion));
426
427
at::Tensor out_tensor =
427
428
executor_cache.runFusionWithInputs ({in_tensor})[0 ].as <at::Tensor>();
428
-
429
+
429
430
if (num_devices > 1 ) {
430
431
assertIsCompiledToHostIrContainer (executor_cache);
431
432
}
@@ -441,7 +442,7 @@ TEST_P(LowerCollectiveTest, Reduce) {
441
442
if (enable_host_ir_lowering) {
442
443
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
443
444
}
444
-
445
+
445
446
auto fusion = std::make_unique<Fusion>();
446
447
FusionGuard fg (fusion.get ());
447
448
@@ -480,7 +481,7 @@ TEST_P(LowerCollectiveTest, Allreduce) {
480
481
if (enable_host_ir_lowering) {
481
482
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
482
483
}
483
-
484
+
484
485
auto fusion = std::make_unique<Fusion>();
485
486
FusionGuard fg (fusion.get ());
486
487
@@ -514,7 +515,7 @@ TEST_P(LowerCollectiveTest, Allreduce_Concrete) {
514
515
if (enable_host_ir_lowering) {
515
516
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
516
517
}
517
-
518
+
518
519
auto fusion = std::make_unique<Fusion>();
519
520
FusionGuard fg (fusion.get ());
520
521
@@ -552,7 +553,7 @@ TEST_P(LowerCollectiveTest, ReduceScatter) {
552
553
if (enable_host_ir_lowering) {
553
554
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
554
555
}
555
-
556
+
556
557
auto fusion = std::make_unique<Fusion>();
557
558
FusionGuard fg (fusion.get ());
558
559
@@ -588,7 +589,7 @@ TEST_P(LowerCollectiveTest, ReduceScatter_Allgather) {
588
589
if (enable_host_ir_lowering) {
589
590
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
590
591
}
591
-
592
+
592
593
// Allreduce = ReduceScatter + Allgather
593
594
auto fusion = std::make_unique<Fusion>();
594
595
FusionGuard fg (fusion.get ());
@@ -623,7 +624,8 @@ TEST_P(LowerCollectiveTest, ReduceScatter_Allgather) {
623
624
INSTANTIATE_TEST_SUITE_P (
624
625
HostIrLowering,
625
626
LowerCollectiveTest,
626
- testing::Values (false ), // TODO: testing::Bool() after implementing communication lowering
627
+ testing::Values (false ), // TODO: testing::Bool() after implementing
628
+ // communication lowering
627
629
[](const testing::TestParamInfo<bool >& info) {
628
630
return info.param ? " HirLowerEnabled" : " HirLowerDisabled" ;
629
631
});
0 commit comments