Skip to content

Commit 5983947

Browse files
author
Nicholas Sarkauskas
committed
Revert "Change to not use pair"
This reverts commit ec23416.
1 parent ec23416 commit 5983947

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

tests/cpp/test_multidevice_lower_communication.cpp

+35-19
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,18 @@ void assertIsCompiledToHostIrContainer(
6363
} \
6464
} while (0)
6565

66+
using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
67+
6668
static constexpr int kTensorSize = 4;
6769

6870
class LowerGatherTest
6971
: public MultiDeviceTest,
70-
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
72+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
7173

7274
TEST_P(LowerGatherTest, ) {
7375
EnableOptionsGuard opt_guard;
74-
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
76+
const auto& [meshes, enable_host_ir_lowering] = GetParam();
77+
const auto& [in_mesh, out_mesh] = meshes;
7578

7679
if (enable_host_ir_lowering) {
7780
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -109,14 +112,17 @@ TEST_P(LowerGatherTest, ) {
109112
INSTANTIATE_TEST_SUITE_P(
110113
HostIrLowering,
111114
LowerGatherTest,
112-
// Create product of mesh configurations and HostIrLowering options
115+
// Create product of InOutMesh configurations and HostIrLowering options
113116
testing::Combine(
114-
testing::ValuesIn(std::vector<DeviceMesh>({{0, 1}, {0, 1}, {1, 2}})),
115-
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {0, 2}})),
117+
testing::ValuesIn(std::vector<InOutMesh>(
118+
{{{0, 1}, {0}}, {{0, 1}, {1}}, {{1, 2}, {0, 2}}})),
116119
testing::Values(false)), // TODO: testing::Bool() after implementing
117120
// communication lowering
118-
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
119-
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
121+
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
122+
const auto& meshes = std::get<0>(info.param);
123+
const auto& in_mesh = meshes.first;
124+
const auto& out_mesh = meshes.second;
125+
const auto enable_hir = std::get<1>(info.param);
120126

121127
std::stringstream ss;
122128
ss << "InMesh";
@@ -134,11 +140,12 @@ INSTANTIATE_TEST_SUITE_P(
134140

135141
class LowerScatterTest
136142
: public MultiDeviceTest,
137-
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
143+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
138144

139145
TEST_P(LowerScatterTest, ) {
140146
EnableOptionsGuard opt_guard;
141-
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
147+
const auto& [meshes, enable_host_ir_lowering] = GetParam();
148+
const auto& [in_mesh, out_mesh] = meshes;
142149

143150
if (enable_host_ir_lowering) {
144151
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -177,12 +184,17 @@ INSTANTIATE_TEST_SUITE_P(
177184
HostIrLowering,
178185
LowerScatterTest,
179186
testing::Combine(
180-
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {0, 2}})),
181-
testing::ValuesIn(std::vector<DeviceMesh>({{0, 1}, {0, 1}, {1, 2}})),
187+
testing::ValuesIn(std::vector<InOutMesh>(
188+
{{{0}, {0, 1}},
189+
{{1}, {0, 1}},
190+
{{0, 2}, {1, 2}}})),
182191
testing::Values(false)), // TODO: testing::Bool() after implementing
183192
// communication lowering
184-
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
185-
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
193+
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
194+
const auto& meshes = std::get<0>(info.param);
195+
const auto& in_mesh = meshes.first;
196+
const auto& out_mesh = meshes.second;
197+
const auto enable_hir = std::get<1>(info.param);
186198

187199
std::stringstream ss;
188200
ss << "InMesh";
@@ -200,11 +212,12 @@ INSTANTIATE_TEST_SUITE_P(
200212

201213
class LowerSendRecvTest
202214
: public MultiDeviceTest,
203-
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
215+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
204216

205217
TEST_P(LowerSendRecvTest, ) {
206218
EnableOptionsGuard opt_guard;
207-
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
219+
const auto& [meshes, enable_host_ir_lowering] = GetParam();
220+
const auto& [in_mesh, out_mesh] = meshes;
208221

209222
if (enable_host_ir_lowering) {
210223
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -245,12 +258,15 @@ INSTANTIATE_TEST_SUITE_P(
245258
HostIrLowering,
246259
LowerSendRecvTest,
247260
testing::Combine(
248-
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {1, 2}, {1, 2}})),
249-
testing::ValuesIn(std::vector<DeviceMesh>({{1}, {0}, {0, 1}, {1, 0}})),
261+
testing::ValuesIn(std::vector<InOutMesh>(
262+
{{{0}, {1}}, {{1}, {0}}, {{1, 2}, {0, 1}}, {{1, 2}, {1, 0}}})),
250263
testing::Values(false)), // TODO: testing::Bool() after implementing
251264
// communication lowering
252-
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
253-
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
265+
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
266+
const auto& meshes = std::get<0>(info.param);
267+
const auto& in_mesh = meshes.first;
268+
const auto& out_mesh = meshes.second;
269+
const auto enable_hir = std::get<1>(info.param);
254270

255271
std::stringstream ss;
256272
ss << "InMesh";

0 commit comments

Comments
 (0)