@@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
103
103
}
104
104
105
105
match bias {
106
- Some ( bias) => {
107
- match Conv2DWithBias
108
- . prepare (
109
- [ x. node , weight. node , bias. node ] ,
110
- [ x. graph , weight. graph , bias. graph ] ,
111
- )
112
- . stateful ( )
113
- {
114
- OpsKind :: Tracked ( prep) => prep. finish (
115
- (
116
- x. primitive . clone ( ) ,
117
- weight. primitive . clone ( ) ,
118
- bias. primitive . clone ( ) ,
119
- options. clone ( ) ,
120
- ) ,
121
- B :: conv2d ( x. primitive , weight. primitive , Some ( bias. primitive ) , options) ,
106
+ Some ( bias) => match Conv2DWithBias
107
+ . prepare (
108
+ [ x. node , weight. node , bias. node ] ,
109
+ [ x. graph , weight. graph , bias. graph ] ,
110
+ )
111
+ . stateful ( )
112
+ {
113
+ OpsKind :: Tracked ( prep) => prep. finish (
114
+ (
115
+ x. primitive . clone ( ) ,
116
+ weight. primitive . clone ( ) ,
117
+ bias. primitive . clone ( ) ,
118
+ options. clone ( ) ,
122
119
) ,
123
- OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv2d (
124
- x. primitive ,
125
- weight. primitive ,
126
- Some ( bias. primitive ) ,
127
- options,
128
- ) ) ,
129
- }
130
- }
131
- None => {
132
- match Conv2DNoBias
133
- . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
134
- . stateful ( )
135
- {
136
- OpsKind :: Tracked ( prep) => prep. finish (
137
- (
138
- x. primitive . clone ( ) ,
139
- weight. primitive . clone ( ) ,
140
- options. clone ( ) ,
141
- ) ,
142
- B :: conv2d ( x. primitive , weight. primitive , None , options) ,
120
+ B :: conv2d ( x. primitive , weight. primitive , Some ( bias. primitive ) , options) ,
121
+ ) ,
122
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv2d (
123
+ x. primitive ,
124
+ weight. primitive ,
125
+ Some ( bias. primitive ) ,
126
+ options,
127
+ ) ) ,
128
+ } ,
129
+ None => match Conv2DNoBias
130
+ . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
131
+ . stateful ( )
132
+ {
133
+ OpsKind :: Tracked ( prep) => prep. finish (
134
+ (
135
+ x. primitive . clone ( ) ,
136
+ weight. primitive . clone ( ) ,
137
+ options. clone ( ) ,
143
138
) ,
144
- OpsKind :: UnTracked ( prep) => {
145
- prep. finish ( B :: conv2d ( x. primitive , weight. primitive , None , options) )
146
- }
139
+ B :: conv2d ( x. primitive , weight. primitive , None , options) ,
140
+ ) ,
141
+ OpsKind :: UnTracked ( prep) => {
142
+ prep. finish ( B :: conv2d ( x. primitive , weight. primitive , None , options) )
147
143
}
148
- }
144
+ } ,
149
145
}
150
146
}
151
147
@@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
211
207
}
212
208
213
209
match bias {
214
- Some ( bias) => {
215
- match ConvTranspose2DWithBias
216
- . prepare (
217
- [ x. node , weight. node , bias. node ] ,
218
- [ x. graph , weight. graph , bias. graph ] ,
219
- )
220
- . stateful ( )
221
- {
222
- OpsKind :: Tracked ( prep) => prep. finish (
223
- (
224
- x. primitive . clone ( ) ,
225
- weight. primitive . clone ( ) ,
226
- bias. primitive . clone ( ) ,
227
- options. clone ( ) ,
228
- ) ,
229
- B :: conv_transpose2d (
230
- x. primitive ,
231
- weight. primitive ,
232
- Some ( bias. primitive ) ,
233
- options,
234
- ) ,
210
+ Some ( bias) => match ConvTranspose2DWithBias
211
+ . prepare (
212
+ [ x. node , weight. node , bias. node ] ,
213
+ [ x. graph , weight. graph , bias. graph ] ,
214
+ )
215
+ . stateful ( )
216
+ {
217
+ OpsKind :: Tracked ( prep) => prep. finish (
218
+ (
219
+ x. primitive . clone ( ) ,
220
+ weight. primitive . clone ( ) ,
221
+ bias. primitive . clone ( ) ,
222
+ options. clone ( ) ,
235
223
) ,
236
- OpsKind :: UnTracked ( prep ) => prep . finish ( B :: conv_transpose2d (
224
+ B :: conv_transpose2d (
237
225
x. primitive ,
238
226
weight. primitive ,
239
227
Some ( bias. primitive ) ,
240
228
options,
241
- ) ) ,
242
- }
243
- }
244
- None => {
245
- match ConvTranspose2DNoBias
246
- . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
247
- . stateful ( )
248
- {
249
- OpsKind :: Tracked ( prep) => prep. finish (
250
- (
251
- x. primitive . clone ( ) ,
252
- weight. primitive . clone ( ) ,
253
- options. clone ( ) ,
254
- ) ,
255
- B :: conv_transpose2d ( x. primitive , weight. primitive , None , options) ,
256
229
) ,
257
- OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose2d (
258
- x. primitive ,
259
- weight. primitive ,
260
- None ,
261
- options,
262
- ) ) ,
263
- }
264
- }
230
+ ) ,
231
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose2d (
232
+ x. primitive ,
233
+ weight. primitive ,
234
+ Some ( bias. primitive ) ,
235
+ options,
236
+ ) ) ,
237
+ } ,
238
+ None => match ConvTranspose2DNoBias
239
+ . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
240
+ . stateful ( )
241
+ {
242
+ OpsKind :: Tracked ( prep) => prep. finish (
243
+ (
244
+ x. primitive . clone ( ) ,
245
+ weight. primitive . clone ( ) ,
246
+ options. clone ( ) ,
247
+ ) ,
248
+ B :: conv_transpose2d ( x. primitive , weight. primitive , None , options) ,
249
+ ) ,
250
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose2d (
251
+ x. primitive ,
252
+ weight. primitive ,
253
+ None ,
254
+ options,
255
+ ) ) ,
256
+ } ,
265
257
}
266
258
}
267
259
@@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
322
314
}
323
315
}
324
316
match bias {
325
- Some ( bias) => {
326
- match Conv1DWithBias
327
- . prepare (
328
- [ x. node , weight. node , bias. node ] ,
329
- [ x. graph , weight. graph , bias. graph ] ,
330
- )
331
- . stateful ( )
332
- {
333
- OpsKind :: Tracked ( prep) => prep. finish (
334
- (
335
- x. primitive . clone ( ) ,
336
- weight. primitive . clone ( ) ,
337
- bias. primitive . clone ( ) ,
338
- options. clone ( ) ,
339
- ) ,
340
- B :: conv1d ( x. primitive , weight. primitive , Some ( bias. primitive ) , options) ,
317
+ Some ( bias) => match Conv1DWithBias
318
+ . prepare (
319
+ [ x. node , weight. node , bias. node ] ,
320
+ [ x. graph , weight. graph , bias. graph ] ,
321
+ )
322
+ . stateful ( )
323
+ {
324
+ OpsKind :: Tracked ( prep) => prep. finish (
325
+ (
326
+ x. primitive . clone ( ) ,
327
+ weight. primitive . clone ( ) ,
328
+ bias. primitive . clone ( ) ,
329
+ options. clone ( ) ,
341
330
) ,
342
- OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv1d (
343
- x. primitive ,
344
- weight. primitive ,
345
- Some ( bias. primitive ) ,
346
- options,
347
- ) ) ,
348
- }
349
- }
350
- None => {
351
- match Conv1DNoBias
352
- . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
353
- . stateful ( )
354
- {
355
- OpsKind :: Tracked ( prep) => prep. finish (
356
- (
357
- x. primitive . clone ( ) ,
358
- weight. primitive . clone ( ) ,
359
- options. clone ( ) ,
360
- ) ,
361
- B :: conv1d ( x. primitive , weight. primitive , None , options) ,
331
+ B :: conv1d ( x. primitive , weight. primitive , Some ( bias. primitive ) , options) ,
332
+ ) ,
333
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv1d (
334
+ x. primitive ,
335
+ weight. primitive ,
336
+ Some ( bias. primitive ) ,
337
+ options,
338
+ ) ) ,
339
+ } ,
340
+ None => match Conv1DNoBias
341
+ . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
342
+ . stateful ( )
343
+ {
344
+ OpsKind :: Tracked ( prep) => prep. finish (
345
+ (
346
+ x. primitive . clone ( ) ,
347
+ weight. primitive . clone ( ) ,
348
+ options. clone ( ) ,
362
349
) ,
363
- OpsKind :: UnTracked ( prep) => {
364
- prep. finish ( B :: conv1d ( x. primitive , weight. primitive , None , options) )
365
- }
350
+ B :: conv1d ( x. primitive , weight. primitive , None , options) ,
351
+ ) ,
352
+ OpsKind :: UnTracked ( prep) => {
353
+ prep. finish ( B :: conv1d ( x. primitive , weight. primitive , None , options) )
366
354
}
367
- }
355
+ } ,
368
356
}
369
357
}
370
358
@@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
430
418
}
431
419
432
420
match bias {
433
- Some ( bias) => {
434
- match ConvTranspose1DWithBias
435
- . prepare (
436
- [ x. node , weight. node , bias. node ] ,
437
- [ x. graph , weight. graph , bias. graph ] ,
438
- )
439
- . stateful ( )
440
- {
441
- OpsKind :: Tracked ( prep) => prep. finish (
442
- (
443
- x. primitive . clone ( ) ,
444
- weight. primitive . clone ( ) ,
445
- bias. primitive . clone ( ) ,
446
- options. clone ( ) ,
447
- ) ,
448
- B :: conv_transpose1d (
449
- x. primitive ,
450
- weight. primitive ,
451
- Some ( bias. primitive ) ,
452
- options,
453
- ) ,
421
+ Some ( bias) => match ConvTranspose1DWithBias
422
+ . prepare (
423
+ [ x. node , weight. node , bias. node ] ,
424
+ [ x. graph , weight. graph , bias. graph ] ,
425
+ )
426
+ . stateful ( )
427
+ {
428
+ OpsKind :: Tracked ( prep) => prep. finish (
429
+ (
430
+ x. primitive . clone ( ) ,
431
+ weight. primitive . clone ( ) ,
432
+ bias. primitive . clone ( ) ,
433
+ options. clone ( ) ,
454
434
) ,
455
- OpsKind :: UnTracked ( prep ) => prep . finish ( B :: conv_transpose1d (
435
+ B :: conv_transpose1d (
456
436
x. primitive ,
457
437
weight. primitive ,
458
438
Some ( bias. primitive ) ,
459
439
options,
460
- ) ) ,
461
- }
462
- }
463
- None => {
464
- match ConvTranspose1DNoBias
465
- . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
466
- . stateful ( )
467
- {
468
- OpsKind :: Tracked ( prep) => prep. finish (
469
- (
470
- x. primitive . clone ( ) ,
471
- weight. primitive . clone ( ) ,
472
- options. clone ( ) ,
473
- ) ,
474
- B :: conv_transpose1d ( x. primitive , weight. primitive , None , options) ,
475
440
) ,
476
- OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose1d (
477
- x. primitive ,
478
- weight. primitive ,
479
- None ,
480
- options,
481
- ) ) ,
482
- }
483
- }
441
+ ) ,
442
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose1d (
443
+ x. primitive ,
444
+ weight. primitive ,
445
+ Some ( bias. primitive ) ,
446
+ options,
447
+ ) ) ,
448
+ } ,
449
+ None => match ConvTranspose1DNoBias
450
+ . prepare ( [ x. node , weight. node ] , [ x. graph , weight. graph ] )
451
+ . stateful ( )
452
+ {
453
+ OpsKind :: Tracked ( prep) => prep. finish (
454
+ (
455
+ x. primitive . clone ( ) ,
456
+ weight. primitive . clone ( ) ,
457
+ options. clone ( ) ,
458
+ ) ,
459
+ B :: conv_transpose1d ( x. primitive , weight. primitive , None , options) ,
460
+ ) ,
461
+ OpsKind :: UnTracked ( prep) => prep. finish ( B :: conv_transpose1d (
462
+ x. primitive ,
463
+ weight. primitive ,
464
+ None ,
465
+ options,
466
+ ) ) ,
467
+ } ,
484
468
}
485
469
}
486
470
0 commit comments