@@ -63,18 +63,15 @@ void assertIsCompiledToHostIrContainer(
63
63
} \
64
64
} while (0 )
65
65
66
- using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
67
-
68
66
static constexpr int kTensorSize = 4 ;
69
67
70
68
class LowerGatherTest
71
69
: public MultiDeviceTest,
72
- public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
70
+ public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
73
71
74
72
TEST_P (LowerGatherTest, ) {
75
73
EnableOptionsGuard opt_guard;
76
- const auto & [meshes, enable_host_ir_lowering] = GetParam ();
77
- const auto & [in_mesh, out_mesh] = meshes;
74
+ const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
78
75
79
76
if (enable_host_ir_lowering) {
80
77
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -112,17 +109,14 @@ TEST_P(LowerGatherTest, ) {
112
109
INSTANTIATE_TEST_SUITE_P (
113
110
HostIrLowering,
114
111
LowerGatherTest,
115
- // Create product of InOutMesh configurations and HostIrLowering options
112
+ // Create product of mesh configurations and HostIrLowering options
116
113
testing::Combine (
117
- testing::ValuesIn (std::vector<InOutMesh>(
118
- {{{ 0 , 1 }, {0 }}, {{ 0 , 1 }, {1 }}, {{ 1 , 2 }, { 0 , 2 } }})),
114
+ testing::ValuesIn (std::vector<DeviceMesh>({{ 0 , 1 }, { 0 , 1 }, { 1 , 2 }})),
115
+ testing::ValuesIn(std::vector<DeviceMesh>({{ 0 }, {1 }, {0 , 2 }})),
119
116
testing::Values(false )), // TODO: testing::Bool() after implementing
120
117
// communication lowering
121
- [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
122
- const auto & meshes = std::get<0 >(info.param );
123
- const auto & in_mesh = meshes.first ;
124
- const auto & out_mesh = meshes.second ;
125
- const auto enable_hir = std::get<1 >(info.param );
118
+ [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
119
+ const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
126
120
127
121
std::stringstream ss;
128
122
ss << " InMesh" ;
@@ -140,12 +134,11 @@ INSTANTIATE_TEST_SUITE_P(
140
134
141
135
class LowerScatterTest
142
136
: public MultiDeviceTest,
143
- public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
137
+ public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
144
138
145
139
TEST_P (LowerScatterTest, ) {
146
140
EnableOptionsGuard opt_guard;
147
- const auto & [meshes, enable_host_ir_lowering] = GetParam ();
148
- const auto & [in_mesh, out_mesh] = meshes;
141
+ const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
149
142
150
143
if (enable_host_ir_lowering) {
151
144
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -184,17 +177,12 @@ INSTANTIATE_TEST_SUITE_P(
184
177
HostIrLowering,
185
178
LowerScatterTest,
186
179
testing::Combine (
187
- testing::ValuesIn (std::vector<InOutMesh>(
188
- {{{0 }, {0 , 1 }},
189
- {{1 }, {0 , 1 }},
190
- {{0 , 2 }, {1 , 2 }}})),
180
+ testing::ValuesIn (std::vector<DeviceMesh>({{0 }, {1 }, {0 , 2 }})),
181
+ testing::ValuesIn(std::vector<DeviceMesh>({{0 , 1 }, {0 , 1 }, {1 , 2 }})),
191
182
testing::Values(false )), // TODO: testing::Bool() after implementing
192
183
// communication lowering
193
- [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
194
- const auto & meshes = std::get<0 >(info.param );
195
- const auto & in_mesh = meshes.first ;
196
- const auto & out_mesh = meshes.second ;
197
- const auto enable_hir = std::get<1 >(info.param );
184
+ [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
185
+ const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
198
186
199
187
std::stringstream ss;
200
188
ss << " InMesh" ;
@@ -212,12 +200,11 @@ INSTANTIATE_TEST_SUITE_P(
212
200
213
201
class LowerSendRecvTest
214
202
: public MultiDeviceTest,
215
- public testing::WithParamInterface<std::tuple<InOutMesh , bool >> {};
203
+ public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh , bool >> {};
216
204
217
205
TEST_P (LowerSendRecvTest, ) {
218
206
EnableOptionsGuard opt_guard;
219
- const auto & [meshes, enable_host_ir_lowering] = GetParam ();
220
- const auto & [in_mesh, out_mesh] = meshes;
207
+ const auto & [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam ();
221
208
222
209
if (enable_host_ir_lowering) {
223
210
EnableOptionsGuard::getCurOptions ().set (EnableOption::HostIrLowering);
@@ -258,15 +245,12 @@ INSTANTIATE_TEST_SUITE_P(
258
245
HostIrLowering,
259
246
LowerSendRecvTest,
260
247
testing::Combine (
261
- testing::ValuesIn (std::vector<InOutMesh>(
262
- {{{ 0 }, { 1 }} , {{ 1 }, { 0 }} , {{ 1 , 2 }, { 0 , 1 }} , {{ 1 , 2 }, { 1 , 0 } }})),
248
+ testing::ValuesIn (std::vector<DeviceMesh>({{ 0 }, { 1 }, { 1 , 2 }, { 1 , 2 }})),
249
+ testing::ValuesIn(std::vector<DeviceMesh>({{ 1 } , {0 } , {0 , 1 }, {1 , 0 }})),
263
250
testing::Values(false )), // TODO: testing::Bool() after implementing
264
251
// communication lowering
265
- [](const testing::TestParamInfo<std::tuple<InOutMesh, bool >>& info) {
266
- const auto & meshes = std::get<0 >(info.param );
267
- const auto & in_mesh = meshes.first ;
268
- const auto & out_mesh = meshes.second ;
269
- const auto enable_hir = std::get<1 >(info.param );
252
+ [](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool >>& info) {
253
+ const auto & [in_mesh, out_mesh, enable_hir] = info.param ;
270
254
271
255
std::stringstream ss;
272
256
ss << " InMesh" ;
0 commit comments