
Commit 610d640

cargo +nightly fmt (#1017)
1 parent 1a5f252 · commit 610d640


29 files changed · +424 −323 lines changed

burn-autodiff/src/graph/traversal.rs

+4-3
@@ -16,9 +16,10 @@ impl BreadthFirstSearch {
         let mut visited = HashSet::with_capacity(root.order);
         let mut parents = Vec::with_capacity(root.order);
         let mut steps = graph.steps();
-        let root_step = steps
-            .remove(&root.id)
-            .expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?");
+        let root_step = steps.remove(&root.id).expect(
+            "Root node should have a step registered, did you forget to call \
+             `Tensor::register_grad` on the tensor where you need gradients?",
+        );

         visited.insert(root.id.clone());
         parents.append(&mut root.parents.clone());
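The wrapped `expect` message relies on Rust's string-continuation escape: a trailing `\` inside a string literal makes the compiler skip the backslash, the newline, and the leading whitespace of the next line, so the reflowed literal is byte-for-byte identical to the original single-line message. A minimal, standalone sketch of that equivalence (not part of the crate; the message text is copied from the hunk above):

fn main() {
    // The message as it read before formatting, on a single source line.
    let single = "Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?";

    // The same message wrapped with a trailing `\`; the escape removes the
    // newline and the indentation of the continuation line.
    let wrapped = "Root node should have a step registered, did you forget to call \
                   `Tensor::register_grad` on the tensor where you need gradients?";

    // Both literals produce the same string at runtime.
    assert_eq!(single, wrapped);
}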

burn-autodiff/src/ops/module.rs

+154-170
@@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }

         match bias {
-            Some(bias) => {
-                match Conv2DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
+            Some(bias) => match Conv2DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match Conv2DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv2d(x.primitive, weight.primitive, None, options),
+                    B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match Conv2DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => {
-                        prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
-                    }
+                    B::conv2d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => {
+                    prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
                 }
-            }
+            },
         }
     }

@@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }

         match bias {
-            Some(bias) => {
-                match ConvTranspose2DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose2d(
-                            x.primitive,
-                            weight.primitive,
-                            Some(bias.primitive),
-                            options,
-                        ),
+            Some(bias) => match ConvTranspose2DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
+                    B::conv_transpose2d(
                         x.primitive,
                         weight.primitive,
                         Some(bias.primitive),
                         options,
-                    )),
-                }
-            }
-            None => {
-                match ConvTranspose2DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose2d(x.primitive, weight.primitive, None, options),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
-                        x.primitive,
-                        weight.primitive,
-                        None,
-                        options,
-                    )),
-                }
-            }
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match ConvTranspose2DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose2d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
+                    x.primitive,
+                    weight.primitive,
+                    None,
+                    options,
+                )),
+            },
         }
     }

@@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
             }
         }
         match bias {
-            Some(bias) => {
-                match Conv1DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
+            Some(bias) => match Conv1DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
-                        x.primitive,
-                        weight.primitive,
-                        Some(bias.primitive),
-                        options,
-                    )),
-                }
-            }
-            None => {
-                match Conv1DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv1d(x.primitive, weight.primitive, None, options),
+                    B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match Conv1DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => {
-                        prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
-                    }
+                    B::conv1d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => {
+                    prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
                 }
-            }
+            },
         }
     }

@@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
         }

         match bias {
-            Some(bias) => {
-                match ConvTranspose1DWithBias
-                    .prepare(
-                        [x.node, weight.node, bias.node],
-                        [x.graph, weight.graph, bias.graph],
-                    )
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            bias.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose1d(
-                            x.primitive,
-                            weight.primitive,
-                            Some(bias.primitive),
-                            options,
-                        ),
+            Some(bias) => match ConvTranspose1DWithBias
+                .prepare(
+                    [x.node, weight.node, bias.node],
+                    [x.graph, weight.graph, bias.graph],
+                )
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        bias.primitive.clone(),
+                        options.clone(),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
+                    B::conv_transpose1d(
                         x.primitive,
                         weight.primitive,
                         Some(bias.primitive),
                         options,
-                    )),
-                }
-            }
-            None => {
-                match ConvTranspose1DNoBias
-                    .prepare([x.node, weight.node], [x.graph, weight.graph])
-                    .stateful()
-                {
-                    OpsKind::Tracked(prep) => prep.finish(
-                        (
-                            x.primitive.clone(),
-                            weight.primitive.clone(),
-                            options.clone(),
-                        ),
-                        B::conv_transpose1d(x.primitive, weight.primitive, None, options),
                     ),
-                    OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
-                        x.primitive,
-                        weight.primitive,
-                        None,
-                        options,
-                    )),
-                }
-            }
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
+                    x.primitive,
+                    weight.primitive,
+                    Some(bias.primitive),
+                    options,
+                )),
+            },
+            None => match ConvTranspose1DNoBias
+                .prepare([x.node, weight.node], [x.graph, weight.graph])
+                .stateful()
+            {
+                OpsKind::Tracked(prep) => prep.finish(
+                    (
+                        x.primitive.clone(),
+                        weight.primitive.clone(),
+                        options.clone(),
+                    ),
+                    B::conv_transpose1d(x.primitive, weight.primitive, None, options),
+                ),
+                OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
+                    x.primitive,
+                    weight.primitive,
+                    None,
+                    options,
+                )),
+            },
         }
     }

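Every hunk in this file applies the same transformation: a match arm whose body is nothing but a nested match loses its surrounding braces, so `Some(bias) => { match ... { ... } }` becomes `Some(bias) => match ... { ... },` (the trailing comma is required because the arm body is now an expression rather than a block). A reduced sketch of that shape, with a hypothetical `Tracking` enum and `describe` function standing in for the conv ops (not from the crate):

// A reduced example of the layout nightly rustfmt produces here: an arm
// whose body is only a `match` drops the surrounding block.
enum Tracking {
    Tracked,
    UnTracked,
}

fn describe(bias: Option<i32>, kind: Tracking) -> String {
    match bias {
        // Before formatting this arm read `Some(bias) => { match kind { ... } }`;
        // without the braces the arm body is an expression and ends with a comma.
        Some(bias) => match kind {
            Tracking::Tracked => format!("tracked conv with bias {bias}"),
            Tracking::UnTracked => format!("untracked conv with bias {bias}"),
        },
        None => match kind {
            Tracking::Tracked => "tracked conv without bias".to_string(),
            Tracking::UnTracked => "untracked conv without bias".to_string(),
        },
    }
}

fn main() {
    // Both layouts compile to the same code; only the source formatting changes.
    println!("{}", describe(Some(3), Tracking::Tracked));
    println!("{}", describe(None, Tracking::UnTracked));
}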

burn-core/src/nn/conv/checks.rs

+4-1
@@ -3,6 +3,9 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
     let channels_out_div_by_group = channels_out % groups == 0;

     if !channels_in_div_by_group && !channels_out_div_by_group {
-        panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}");
+        panic!(
+            "Both channels must be divisible by the number of groups. Got \
+             channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
+        );
     }
 }
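The reflowed panic! message keeps its inline format-argument captures (`{channels_in}`, `{channels_out}`, `{groups}`), which pull the variables directly from the surrounding scope, and the `\` continuation keeps the rendered text identical to the single-line original. A small standalone sketch of the same kind of check, using a hypothetical `check_channels` function rather than the crate's exact code:

// A hypothetical, standalone version of the grouped-convolution channel check.
fn check_channels(channels_in: usize, channels_out: usize, groups: usize) {
    let channels_in_div_by_group = channels_in % groups == 0;
    let channels_out_div_by_group = channels_out % groups == 0;

    if !channels_in_div_by_group && !channels_out_div_by_group {
        // `{channels_in}` etc. capture the local names; the trailing `\`
        // splits the literal across two source lines without adding whitespace.
        panic!(
            "Both channels must be divisible by the number of groups. Got \
             channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
        );
    }
}

fn main() {
    // Divisible channels pass silently, e.g. 8 in / 16 out with 4 groups.
    check_channels(8, 16, 4);
    println!("check passed");
}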

0 commit comments
