Skip to content

Commit 01bf472

Browse files
committed
linter
1 parent ae93f2a commit 01bf472

File tree

2 files changed

+51
-44
lines changed

2 files changed

+51
-44
lines changed

csrc/runtime/fusion_kernel_runtime.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,16 @@ flatbuffers::Offset<serde::FusionKernelRuntime> FusionKernelRuntime::serialize(
200200
}
201201

202202
namespace {
203-
std::vector<Expr*> toposortExprs(SegmentedFusion* fusion, SegmentedGroup* group) {
203+
std::vector<Expr*> toposortExprs(
204+
SegmentedFusion* fusion,
205+
SegmentedGroup* group) {
204206
std::vector<Expr*> sorted_exprs;
205207
{
206-
auto [/*IrCloner*/group_ir_cloner, /*std::unique_ptr<Fusion>*/group_fusion] = fusion->makeFusion(group);
208+
auto
209+
[/*IrCloner*/ group_ir_cloner,
210+
/*std::unique_ptr<Fusion>*/ group_fusion] = fusion->makeFusion(group);
207211
std::unordered_map<Expr*, Expr*> inverse_clone_map;
208-
for (auto expr: group->exprs()) { // Sorts the exprs in the group
212+
for (auto expr : group->exprs()) { // Sorts the exprs in the group
209213
inverse_clone_map[group_ir_cloner.clone(expr)] = expr;
210214
}
211215
for (auto cloned_expr : group_fusion->exprs()) {
@@ -480,7 +484,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
480484
} else {
481485
// push back segment's exprs into the container as top level
482486
// expressions
483-
for (auto* expr : toposortExprs(segmented_fusion_.get(), group_to_run)) {
487+
for (auto* expr :
488+
toposortExprs(segmented_fusion_.get(), group_to_run)) {
484489
auto cloned_expr = ir_cloner.clone(expr);
485490
hic->pushBackTopLevelExprs(cloned_expr);
486491
}

tests/cpp/test_multidevice_lower_communication.cpp

+42-40
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,19 @@ using InOutMesh = std::pair<DeviceMesh, DeviceMesh>;
6060

6161
static constexpr int kTensorSize = 4;
6262

63-
class LowerGatherTest : public MultiDeviceTest,
64-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
63+
class LowerGatherTest
64+
: public MultiDeviceTest,
65+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
6566

6667
TEST_P(LowerGatherTest, ) {
6768
EnableOptionsGuard opt_guard;
6869
const auto& [meshes, enable_host_ir_lowering] = GetParam();
6970
const auto& [in_mesh, out_mesh] = meshes;
70-
71+
7172
if (enable_host_ir_lowering) {
7273
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
7374
}
74-
75+
7576
SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh);
7677

7778
auto fusion = std::make_unique<Fusion>();
@@ -107,16 +108,15 @@ INSTANTIATE_TEST_SUITE_P(
107108
// Create product of InOutMesh configurations and HostIrLowering options
108109
testing::Combine(
109110
testing::ValuesIn(std::vector<InOutMesh>(
110-
{{{0, 1}, {0}},
111-
{{0, 1}, {1}},
112-
{{1, 2}, {0, 2}}})),
113-
testing::Values(false)), // TODO: testing::Bool() after implementing communication lowering
111+
{{{0, 1}, {0}}, {{0, 1}, {1}}, {{1, 2}, {0, 2}}})),
112+
testing::Values(false)), // TODO: testing::Bool() after implementing
113+
// communication lowering
114114
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
115115
const auto& meshes = std::get<0>(info.param);
116116
const auto& in_mesh = meshes.first;
117117
const auto& out_mesh = meshes.second;
118118
const auto enable_hir = std::get<1>(info.param);
119-
119+
120120
std::stringstream ss;
121121
ss << "InMesh";
122122
for (auto id : in_mesh.vector()) {
@@ -127,22 +127,23 @@ INSTANTIATE_TEST_SUITE_P(
127127
ss << "_" << id;
128128
}
129129
ss << (enable_hir ? "_HirEnabled" : "_HirDisabled");
130-
130+
131131
return ss.str();
132132
});
133133

134-
class LowerScatterTest : public MultiDeviceTest,
135-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
134+
class LowerScatterTest
135+
: public MultiDeviceTest,
136+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
136137

137138
TEST_P(LowerScatterTest, ) {
138139
EnableOptionsGuard opt_guard;
139140
const auto& [meshes, enable_host_ir_lowering] = GetParam();
140141
const auto& [in_mesh, out_mesh] = meshes;
141-
142+
142143
if (enable_host_ir_lowering) {
143144
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
144145
}
145-
146+
146147
SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh);
147148

148149
auto fusion = std::make_unique<Fusion>();
@@ -180,13 +181,14 @@ INSTANTIATE_TEST_SUITE_P(
180181
{{{0}, {0, 1}}, //
181182
{{1}, {0, 1}}, //
182183
{{0, 2}, {1, 2}}})),
183-
testing::Values(false)), // TODO: testing::Bool() after implementing communication lowering
184+
testing::Values(false)), // TODO: testing::Bool() after implementing
185+
// communication lowering
184186
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
185187
const auto& meshes = std::get<0>(info.param);
186188
const auto& in_mesh = meshes.first;
187189
const auto& out_mesh = meshes.second;
188190
const auto enable_hir = std::get<1>(info.param);
189-
191+
190192
std::stringstream ss;
191193
ss << "InMesh";
192194
for (auto id : in_mesh.vector()) {
@@ -197,22 +199,23 @@ INSTANTIATE_TEST_SUITE_P(
197199
ss << "_" << id;
198200
}
199201
ss << (enable_hir ? "_HirEnabled" : "_HirDisabled");
200-
202+
201203
return ss.str();
202204
});
203205

204-
class LowerSendRecvTest : public MultiDeviceTest,
205-
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
206+
class LowerSendRecvTest
207+
: public MultiDeviceTest,
208+
public testing::WithParamInterface<std::tuple<InOutMesh, bool>> {};
206209

207210
TEST_P(LowerSendRecvTest, ) {
208211
EnableOptionsGuard opt_guard;
209212
const auto& [meshes, enable_host_ir_lowering] = GetParam();
210213
const auto& [in_mesh, out_mesh] = meshes;
211-
214+
212215
if (enable_host_ir_lowering) {
213216
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
214217
}
215-
218+
216219
SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh);
217220

218221
auto fusion = std::make_unique<Fusion>();
@@ -249,17 +252,15 @@ INSTANTIATE_TEST_SUITE_P(
249252
LowerSendRecvTest,
250253
testing::Combine(
251254
testing::ValuesIn(std::vector<InOutMesh>(
252-
{{{0}, {1}},
253-
{{1}, {0}},
254-
{{1, 2}, {0, 1}},
255-
{{1, 2}, {1, 0}}})),
256-
testing::Values(false)), // TODO: testing::Bool() after implementing communication lowering
255+
{{{0}, {1}}, {{1}, {0}}, {{1, 2}, {0, 1}}, {{1, 2}, {1, 0}}})),
256+
testing::Values(false)), // TODO: testing::Bool() after implementing
257+
// communication lowering
257258
[](const testing::TestParamInfo<std::tuple<InOutMesh, bool>>& info) {
258259
const auto& meshes = std::get<0>(info.param);
259260
const auto& in_mesh = meshes.first;
260261
const auto& out_mesh = meshes.second;
261262
const auto enable_hir = std::get<1>(info.param);
262-
263+
263264
std::stringstream ss;
264265
ss << "InMesh";
265266
for (auto id : in_mesh.vector()) {
@@ -270,12 +271,12 @@ INSTANTIATE_TEST_SUITE_P(
270271
ss << "_" << id;
271272
}
272273
ss << (enable_hir ? "_HirEnabled" : "_HirDisabled");
273-
274+
274275
return ss.str();
275276
});
276277

277278
class LowerCollectiveTest : public MultiDeviceTest,
278-
public testing::WithParamInterface<bool> {};
279+
public testing::WithParamInterface<bool> {};
279280

280281
TEST_P(LowerCollectiveTest, Allgather) {
281282
EnableOptionsGuard opt_guard;
@@ -284,7 +285,7 @@ TEST_P(LowerCollectiveTest, Allgather) {
284285
if (enable_host_ir_lowering) {
285286
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
286287
}
287-
288+
288289
auto fusion = std::make_unique<Fusion>();
289290
FusionGuard fg(fusion.get());
290291

@@ -319,7 +320,7 @@ TEST_P(LowerCollectiveTest, Allgather_LoopSplit) {
319320
// Skip this test when HostIrLowering is enabled
320321
GTEST_SKIP() << "Disabled for HostIrLowering enabled configuration";
321322
}
322-
323+
323324
auto fusion = std::make_unique<Fusion>();
324325
FusionGuard fg(fusion.get());
325326

@@ -362,7 +363,7 @@ TEST_P(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) {
362363
if (enable_host_ir_lowering) {
363364
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
364365
}
365-
366+
366367
auto fusion = std::make_unique<Fusion>();
367368
FusionGuard fg(fusion.get());
368369

@@ -402,7 +403,7 @@ TEST_P(LowerCollectiveTest, Broadcast) {
402403
if (enable_host_ir_lowering) {
403404
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
404405
}
405-
406+
406407
auto fusion = std::make_unique<Fusion>();
407408
FusionGuard fg(fusion.get());
408409

@@ -425,7 +426,7 @@ TEST_P(LowerCollectiveTest, Broadcast) {
425426
FusionExecutorCache executor_cache(std::move(fusion));
426427
at::Tensor out_tensor =
427428
executor_cache.runFusionWithInputs({in_tensor})[0].as<at::Tensor>();
428-
429+
429430
if (num_devices > 1) {
430431
assertIsCompiledToHostIrContainer(executor_cache);
431432
}
@@ -441,7 +442,7 @@ TEST_P(LowerCollectiveTest, Reduce) {
441442
if (enable_host_ir_lowering) {
442443
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
443444
}
444-
445+
445446
auto fusion = std::make_unique<Fusion>();
446447
FusionGuard fg(fusion.get());
447448

@@ -480,7 +481,7 @@ TEST_P(LowerCollectiveTest, Allreduce) {
480481
if (enable_host_ir_lowering) {
481482
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
482483
}
483-
484+
484485
auto fusion = std::make_unique<Fusion>();
485486
FusionGuard fg(fusion.get());
486487

@@ -514,7 +515,7 @@ TEST_P(LowerCollectiveTest, Allreduce_Concrete) {
514515
if (enable_host_ir_lowering) {
515516
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
516517
}
517-
518+
518519
auto fusion = std::make_unique<Fusion>();
519520
FusionGuard fg(fusion.get());
520521

@@ -552,7 +553,7 @@ TEST_P(LowerCollectiveTest, ReduceScatter) {
552553
if (enable_host_ir_lowering) {
553554
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
554555
}
555-
556+
556557
auto fusion = std::make_unique<Fusion>();
557558
FusionGuard fg(fusion.get());
558559

@@ -588,7 +589,7 @@ TEST_P(LowerCollectiveTest, ReduceScatter_Allgather) {
588589
if (enable_host_ir_lowering) {
589590
EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering);
590591
}
591-
592+
592593
// Allreduce = ReduceScatter + Allgather
593594
auto fusion = std::make_unique<Fusion>();
594595
FusionGuard fg(fusion.get());
@@ -623,7 +624,8 @@ TEST_P(LowerCollectiveTest, ReduceScatter_Allgather) {
623624
INSTANTIATE_TEST_SUITE_P(
624625
HostIrLowering,
625626
LowerCollectiveTest,
626-
testing::Values(false), // TODO: testing::Bool() after implementing communication lowering
627+
testing::Values(false), // TODO: testing::Bool() after implementing
628+
// communication lowering
627629
[](const testing::TestParamInfo<bool>& info) {
628630
return info.param ? "HirLowerEnabled" : "HirLowerDisabled";
629631
});

0 commit comments

Comments
 (0)