diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 4233e8121..73f1b263a 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -1009,12 +1009,32 @@ public struct AvgPool2D: Layer { } } + +/// A global average pooling layer for temporal data. +@_fixed_layout +public struct GlobalAveragePooling1D: Layer { + /// Creates a global average pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.mean(alongAxes: 1).reshaped(to: [input.shape[0], input.shape[2]]) + } +} + /// A global average pooling layer for spatial data. @_fixed_layout public struct GlobalAveragePooling2D: Layer { /// Creates a global average pooling layer. public init() {} - + /// Returns the output obtained from applying the layer to the given input. /// /// - Parameters: