@@ -25,79 +25,67 @@ public struct BatchNormConv2DBlock: Layer {
2525 public var conv1 : Conv2D < Float >
2626 public var norm2 : BatchNorm < Float >
2727 public var conv2 : Conv2D < Float >
28+ public var shortcut : Conv2D < Float >
29+ let isExpansion : Bool
30+ let dropout : Dropout < Float > = Dropout ( probability: 0.3 )
2831
2932 public init (
30- filterShape: ( Int , Int , Int , Int ) ,
33+ featureCounts: ( Int , Int ) ,
34+ kernelSize: Int = 3 ,
3135 strides: ( Int , Int ) = ( 1 , 1 ) ,
3236 padding: Padding = . same
3337 ) {
34- self . norm1 = BatchNorm ( featureCount: filterShape. 2 )
35- self . conv1 = Conv2D ( filterShape: filterShape, strides: strides, padding: padding)
36- self . norm2 = BatchNorm ( featureCount: filterShape. 3 )
37- self . conv2 = Conv2D ( filterShape: filterShape, strides: ( 1 , 1 ) , padding: padding)
38+ self . norm1 = BatchNorm ( featureCount: featureCounts. 0 )
39+ self . conv1 = Conv2D (
40+ filterShape: ( kernelSize, kernelSize, featureCounts. 0 , featureCounts. 1 ) ,
41+ strides: strides,
42+ padding: padding)
43+ self . norm2 = BatchNorm ( featureCount: featureCounts. 1 )
44+ self . conv2 = Conv2D ( filterShape: ( kernelSize, kernelSize, featureCounts. 1 , featureCounts. 1 ) ,
45+ strides: ( 1 , 1 ) ,
46+ padding: padding)
47+ self . shortcut = Conv2D ( filterShape: ( 1 , 1 , featureCounts. 0 , featureCounts. 1 ) ,
48+ strides: strides,
49+ padding: padding)
50+ self . isExpansion = featureCounts. 1 != featureCounts. 0 || strides != ( 1 , 1 )
3851 }
3952
4053 @differentiable
4154 public func callAsFunction( _ input: Tensor < Float > ) -> Tensor < Float > {
42- let firstLayer = conv1 ( relu ( norm1 ( input) ) )
43- return conv2 ( relu ( norm2 ( firstLayer) ) )
55+ let preact1 = relu ( norm1 ( input) )
56+ var residual = conv1 ( preact1)
57+ let preact2 : Tensor < Float >
58+ let shortcutResult : Tensor < Float >
59+ if isExpansion {
60+ shortcutResult = shortcut ( preact1)
61+ preact2 = relu ( norm2 ( residual) )
62+ } else {
63+ shortcutResult = input
64+ preact2 = dropout ( relu ( norm2 ( residual) ) )
65+ }
66+ residual = conv2 ( preact2)
67+ return residual + shortcutResult
4468 }
4569}
4670
/// A group of `depthFactor` pre-activation residual blocks.
///
/// The first block performs the channel transition (and any spatial
/// downsampling via `initialStride`); the remaining blocks keep the output
/// feature count and unit stride.
public struct WideResNetBasicBlock: Layer {
  public var blocks: [BatchNormConv2DBlock]

  /// - Parameters:
  ///   - featureCounts: `(input, output)` channel counts for the group.
  ///   - kernelSize: side length of the square convolution kernels.
  ///   - depthFactor: number of residual blocks in the group; must be >= 1.
  ///   - initialStride: stride of the first block; later blocks use `(1, 1)`.
  public init(
    featureCounts: (Int, Int),
    kernelSize: Int = 3,
    depthFactor: Int = 2,
    initialStride: (Int, Int) = (2, 2)
  ) {
    // Fix: `kernelSize` was previously accepted but never forwarded, so any
    // non-default value was silently ignored (blocks always used 3x3 kernels).
    self.blocks = [
      BatchNormConv2DBlock(
        featureCounts: featureCounts, kernelSize: kernelSize, strides: initialStride)
    ]
    for _ in 1..<depthFactor {
      self.blocks += [
        BatchNormConv2DBlock(
          featureCounts: (featureCounts.1, featureCounts.1), kernelSize: kernelSize)
      ]
    }
  }

  /// Threads `input` through each residual block in order.
  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
    return blocks.differentiableReduce(input) { $1($0) }
  }
}
10391
@@ -116,15 +104,12 @@ public struct WideResNet: Layer {
116104 public init ( depthFactor: Int = 2 , widenFactor: Int = 8 ) {
117105 self . l1 = Conv2D ( filterShape: ( 3 , 3 , 3 , 16 ) , strides: ( 1 , 1 ) , padding: . same)
118106
119- l2 = WideResNetBasicBlock (
120- featureCounts: ( 16 , 16 ) , depthFactor: depthFactor,
121- widenFactor: widenFactor, initialStride: ( 1 , 1 ) )
122- l3 = WideResNetBasicBlock (
123- featureCounts: ( 16 , 32 ) , depthFactor: depthFactor,
124- widenFactor: widenFactor)
125- l4 = WideResNetBasicBlock (
126- featureCounts: ( 32 , 64 ) , depthFactor: depthFactor,
127- widenFactor: widenFactor)
107+ self . l2 = WideResNetBasicBlock (
108+ featureCounts: ( 16 , 16 * widenFactor) , depthFactor: depthFactor, initialStride: ( 1 , 1 ) )
109+ self . l3 = WideResNetBasicBlock ( featureCounts: ( 16 * widenFactor, 32 * widenFactor) ,
110+ depthFactor: depthFactor)
111+ self . l4 = WideResNetBasicBlock ( featureCounts: ( 32 * widenFactor, 64 * widenFactor) ,
112+ depthFactor: depthFactor)
128113
129114 self . norm = BatchNorm ( featureCount: 64 * widenFactor)
130115 self . avgPool = AvgPool2D ( poolSize: ( 8 , 8 ) , strides: ( 8 , 8 ) )
0 commit comments