Skip to content

Commit ec23416

Browse files
author
Nicholas Sarkauskas
committed
Change to not use pair
1 parent 72295e7 commit ec23416

File tree

1 file changed

+19
-35
lines changed

1 file changed

+19
-35
lines changed

tests/cpp/test_multidevice_lower_communication.cpp

+19-35
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,15 @@ void assertIsCompiledToHostIrContainer(
6363
} \
6464
} while (0)
6565

66-
using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
67-
6866
static constexpr int kTensorSize = 4;
6967

7068
class LowerGatherTest
7169
: public MultiDeviceTest,
72-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
70+
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
7371

7472
TEST_P(LowerGatherTest, ) {
7573
EnableOptionsGuard opt_guard;
76-
const auto& [meshes, enable_host_ir_lowering] = GetParam();
77-
const auto& [in_mesh, out_mesh] = meshes;
74+
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
7875

7976
if (enable_host_ir_lowering) {
8077
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -112,17 +109,14 @@ TEST_P(LowerGatherTest, ) {
112109
INSTANTIATE_TEST_SUITE_P(
113110
HostIrLowering,
114111
LowerGatherTest,
115-
// Create product of InOutMesh configurations and HostIrLowering options
112+
// Create product of mesh configurations and HostIrLowering options
116113
testing::Combine(
117-
testing::ValuesIn(std::vector<InOutMesh>(
118-
{{{0, 1}, {0}}, {{0, 1}, {1}}, {{1, 2}, {0, 2}}})),
114+
testing::ValuesIn(std::vector<DeviceMesh>({{0, 1}, {0, 1}, {1, 2}})),
115+
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {0, 2}})),
119116
testing::Values(false)), // TODO: testing::Bool() after implementing
120117
// communication lowering
121-
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
122-
const auto& meshes = std::get<0>(info.param);
123-
const auto& in_mesh = meshes.first;
124-
const auto& out_mesh = meshes.second;
125-
const auto enable_hir = std::get<1>(info.param);
118+
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
119+
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
126120

127121
std::stringstream ss;
128122
ss << "InMesh";
@@ -140,12 +134,11 @@ INSTANTIATE_TEST_SUITE_P(
140134

141135
class LowerScatterTest
142136
: public MultiDeviceTest,
143-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
137+
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
144138

145139
TEST_P(LowerScatterTest, ) {
146140
EnableOptionsGuard opt_guard;
147-
const auto& [meshes, enable_host_ir_lowering] = GetParam();
148-
const auto& [in_mesh, out_mesh] = meshes;
141+
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
149142

150143
if (enable_host_ir_lowering) {
151144
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -184,17 +177,12 @@ INSTANTIATE_TEST_SUITE_P(
184177
HostIrLowering,
185178
LowerScatterTest,
186179
testing::Combine(
187-
testing::ValuesIn(std::vector<InOutMesh>(
188-
{{{0}, {0, 1}},
189-
{{1}, {0, 1}},
190-
{{0, 2}, {1, 2}}})),
180+
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {0, 2}})),
181+
testing::ValuesIn(std::vector<DeviceMesh>({{0, 1}, {0, 1}, {1, 2}})),
191182
testing::Values(false)), // TODO: testing::Bool() after implementing
192183
// communication lowering
193-
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
194-
const auto& meshes = std::get<0>(info.param);
195-
const auto& in_mesh = meshes.first;
196-
const auto& out_mesh = meshes.second;
197-
const auto enable_hir = std::get<1>(info.param);
184+
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
185+
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
198186

199187
std::stringstream ss;
200188
ss << "InMesh";
@@ -212,12 +200,11 @@ INSTANTIATE_TEST_SUITE_P(
212200

213201
class LowerSendRecvTest
214202
: public MultiDeviceTest,
215-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
203+
public testing::WithParamInterface<std::tuple<DeviceMesh, DeviceMesh, bool>> {};
216204

217205
TEST_P(LowerSendRecvTest, ) {
218206
EnableOptionsGuard opt_guard;
219-
const auto& [meshes, enable_host_ir_lowering] = GetParam();
220-
const auto& [in_mesh, out_mesh] = meshes;
207+
const auto& [in_mesh, out_mesh, enable_host_ir_lowering] = GetParam();
221208

222209
if (enable_host_ir_lowering) {
223210
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
@@ -258,15 +245,12 @@ INSTANTIATE_TEST_SUITE_P(
258245
HostIrLowering,
259246
LowerSendRecvTest,
260247
testing::Combine(
261-
testing::ValuesIn(std::vector<InOutMesh>(
262-
{{{0}, {1}}, {{1}, {0}}, {{1, 2}, {0, 1}}, {{1, 2}, {1, 0}}})),
248+
testing::ValuesIn(std::vector<DeviceMesh>({{0}, {1}, {1, 2}, {1, 2}})),
249+
testing::ValuesIn(std::vector<DeviceMesh>({{1}, {0}, {0, 1}, {1, 0}})),
263250
testing::Values(false)), // TODO: testing::Bool() after implementing
264251
// communication lowering
265-
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
266-
const auto& meshes = std::get<0>(info.param);
267-
const auto& in_mesh = meshes.first;
268-
const auto& out_mesh = meshes.second;
269-
const auto enable_hir = std::get<1>(info.param);
252+
[](const testing::TestParamInfo<std::tuple<DeviceMesh, DeviceMesh, bool>>& info) {
253+
const auto& [in_mesh, out_mesh, enable_hir] = info.param;
270254

271255
std::stringstream ss;
272256
ss << "InMesh";

0 commit comments

Comments
 (0)