@@ -63,15 +63,18 @@ void assertIsCompiledToHostIrContainer(
63
63
} \
64
64
} while (0 )
65
65
66
+ using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
67
+
66
68
static constexpr int kTensorSize = 4 ;
67
69
68
70
class LowerGatherTest
69
71
: public MultiDeviceTest,
70
- public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
72
+ public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
71
73
72
74
TEST_P (LowerGatherTest, ) {
73
75
EnableOptionsGuard opt_guard;
74
- const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
76
+ const auto & [meshes, enable_host_ir_lowering] = GetParam ();
77
+ const auto & [in_mesh, out_mesh] = meshes;
75
78
76
79
if (enable_host_ir_lowering) {
77
80
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -109,14 +112,17 @@ TEST_P(LowerGatherTest, ) {
109
112
INSTANTIATE_TEST_SUITE_P (
110
113
HostIrLowering,
111
114
LowerGatherTest,
112
- // Create product of mesh configurations and HostIrLowering options
115
+ // Create product of InOutMesh configurations and HostIrLowering options
113
116
testing::Combine (
114
- testing::ValuesIn (std::vector<DeviceMesh>({{ 0 , 1 }, { 0 , 1 }, { 1 , 2 }})),
115
- testing::ValuesIn(std::vector<DeviceMesh>({{ 0 }, {1 }, {0 , 2 }})),
117
+ testing::ValuesIn (std::vector<InOutMesh>(
118
+ {{{ 0 , 1 }, {0 }}, {{ 0 , 1 }, {1 }}, {{ 1 , 2 }, { 0 , 2 } }})),
116
119
testing::Values(false )), // TODO: testing::Bool() after implementing
117
120
// communication lowering
118
- [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
119
- const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
121
+ [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
122
+ const auto & meshes = std::get<0 >(info.param );
123
+ const auto & in_mesh = meshes.first ;
124
+ const auto & out_mesh = meshes.second ;
125
+ const auto enable_hir = std::get<1 >(info.param );
120
126
121
127
std::stringstream ss;
122
128
ss << " InMesh" ;
@@ -134,11 +140,12 @@ INSTANTIATE_TEST_SUITE_P(
134
140
135
141
class LowerScatterTest
136
142
: public MultiDeviceTest,
137
- public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
143
+ public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
138
144
139
145
TEST_P (LowerScatterTest, ) {
140
146
EnableOptionsGuard opt_guard;
141
- const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
147
+ const auto & [meshes, enable_host_ir_lowering] = GetParam ();
148
+ const auto & [in_mesh, out_mesh] = meshes;
142
149
143
150
if (enable_host_ir_lowering) {
144
151
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -177,12 +184,17 @@ INSTANTIATE_TEST_SUITE_P(
177
184
HostIrLowering,
178
185
LowerScatterTest,
179
186
testing::Combine (
180
- testing::ValuesIn (std::vector<DeviceMesh>({{0 }, {1 }, {0 , 2 }})),
181
- testing::ValuesIn(std::vector<DeviceMesh>({{0 , 1 }, {0 , 1 }, {1 , 2 }})),
187
+ testing::ValuesIn (std::vector<InOutMesh>(
188
+ {{{0 }, {0 , 1 }},
189
+ {{1 }, {0 , 1 }},
190
+ {{0 , 2 }, {1 , 2 }}})),
182
191
testing::Values(false )), // TODO: testing::Bool() after implementing
183
192
// communication lowering
184
- [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
185
- const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
193
+ [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
194
+ const auto & meshes = std::get<0 >(info.param );
195
+ const auto & in_mesh = meshes.first ;
196
+ const auto & out_mesh = meshes.second ;
197
+ const auto enable_hir = std::get<1 >(info.param );
186
198
187
199
std::stringstream ss;
188
200
ss << " InMesh" ;
@@ -200,11 +212,12 @@ INSTANTIATE_TEST_SUITE_P(
200
212
201
213
class LowerSendRecvTest
202
214
: public MultiDeviceTest,
203
- public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
215
+ public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
204
216
205
217
TEST_P (LowerSendRecvTest, ) {
206
218
EnableOptionsGuard opt_guard;
207
- const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
219
+ const auto & [meshes, enable_host_ir_lowering] = GetParam ();
220
+ const auto & [in_mesh, out_mesh] = meshes;
208
221
209
222
if (enable_host_ir_lowering) {
210
223
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -245,12 +258,15 @@ INSTANTIATE_TEST_SUITE_P(
245
258
HostIrLowering,
246
259
LowerSendRecvTest,
247
260
testing::Combine (
248
- testing::ValuesIn (std::vector<DeviceMesh>({{ 0 }, { 1 }, { 1 , 2 }, { 1 , 2 }})),
249
- testing::ValuesIn(std::vector<DeviceMesh>({{ 1 } , {0 } , {0 , 1 }, {1 , 0 }})),
261
+ testing::ValuesIn (std::vector<InOutMesh>(
262
+ {{{ 0 }, { 1 }} , {{ 1 }, { 0 }} , {{ 1 , 2 }, { 0 , 1 }} , {{ 1 , 2 }, { 1 , 0 } }})),
250
263
testing::Values(false )), // TODO: testing::Bool() after implementing
251
264
// communication lowering
252
- [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
253
- const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
265
+ [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
266
+ const auto & meshes = std::get<0 >(info.param );
267
+ const auto & in_mesh = meshes.first ;
268
+ const auto & out_mesh = meshes.second ;
269
+ const auto enable_hir = std::get<1 >(info.param );
254
270
255
271
std::stringstream ss;
256
272
ss << " InMesh" ;
0 commit comments