diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 73f1b263a..7c4cf1d23 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1022,7 +1022,7 @@ public struct GlobalAveragePooling1D: Layer { /// - input: The input to the layer. /// - context: The contextual information for the layer application, e.g. the current learning /// phase. - /// - Returns: The output + /// - Returns: The output. @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { return input.mean(alongAxes: 1).reshaped(to: [input.shape[0], input.shape[2]]) @@ -1048,6 +1048,25 @@ public struct GlobalAveragePooling2D: Layer { } } +/// A global average pooling layer for spatial and spatio-temporal data. +@_fixed_layout +public struct GlobalAveragePooling3D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.mean(alongAxes: [1, 2, 3]).reshaped(to: [input.shape[0], input.shape[4]]) + } +} + /// A layer that applies layer normalization over a mini-batch of inputs. /// /// Reference: [Layer Normalization](https://arxiv.org/abs/1607.06450).