diff --git a/Autoencoder/main.swift b/Autoencoder/main.swift
index f575c00760c..0d69e0925f4 100644
--- a/Autoencoder/main.swift
+++ b/Autoencoder/main.swift
@@ -94,7 +94,7 @@ struct Autoencoder: Layer {
         activation: tanh)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
         return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
     }
diff --git a/CIFAR/Data.swift b/CIFAR/Data.swift
index dafac78848a..a2dd55948c9 100644
--- a/CIFAR/Data.swift
+++ b/CIFAR/Data.swift
@@ -29,9 +29,32 @@ func downloadCIFAR10IfNotPresent(to directory: String = ".") {
     }
 }

+extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
+    public var _tfeTensorHandle: _AnyTensorHandle {
+        TFETensorHandle(_owning: handle._cTensorHandle)
+    }
+}
+
 struct Example: TensorGroup {
     var label: Tensor
     var data: Tensor
+
+    init(label: Tensor, data: Tensor) {
+        self.label = label
+        self.data = data
+    }
+
+    public init(
+        _handles: C
+    ) where C.Element: _AnyTensorHandle {
+        precondition(_handles.count == 2)
+        let labelIndex = _handles.startIndex
+        let dataIndex = _handles.index(labelIndex, offsetBy: 1)
+        label = Tensor(handle: TensorHandle(handle: _handles[labelIndex]))
+        data = Tensor(handle: TensorHandle(handle: _handles[dataIndex]))
+    }
+
+    public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
 }

 // Each CIFAR data file is provided as a Python pickle of NumPy arrays
diff --git a/CIFAR/Models.swift b/CIFAR/Models.swift
index 516b9f14838..4608beada6c 100644
--- a/CIFAR/Models.swift
+++ b/CIFAR/Models.swift
@@ -29,7 +29,7 @@ struct PyTorchModel: Layer {
     var dense3 = Dense(inputSize: 84, outputSize: 10, activation: identity)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let convolved = input.sequenced(through: conv1, pool1, conv2, pool2)
         return convolved.sequenced(through: flatten, dense1, dense2, dense3)
     }
@@ -54,7 +54,7 @@ struct KerasModel: Layer {
     var dense2 = Dense(inputSize: 512, outputSize: 10, activation: identity)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let conv1 = input.sequenced(through: conv1a, conv1b, pool1, dropout1)
         let conv2 = conv1.sequenced(through: conv2a, conv2b, pool2, dropout2)
         return conv2.sequenced(through: flatten, dense1, dropout3, dense2)
diff --git a/CIFAR/ResNet.swift b/CIFAR/ResNet.swift
index b14fab3d71f..9ef5100b10d 100644
--- a/CIFAR/ResNet.swift
+++ b/CIFAR/ResNet.swift
@@ -37,7 +37,7 @@ struct Conv2DBatchNorm: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: conv, norm)
     }
 }
@@ -68,7 +68,7 @@ struct BasicBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let blocksReduced = blocks.differentiableReduce(input) { last, layer in
             relu(layer(last))
         }
@@ -97,7 +97,7 @@ struct ResNet: Layer {
     var classifier = Dense(inputSize: 64, outputSize: 10, activation: softmax)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let tmp = relu(inputLayer(input))
         let convolved = tmp.sequenced(through: basicBlock1, basicBlock2, basicBlock3)
         return convolved.sequenced(through: averagePool, flatten, classifier)
diff --git a/CIFAR/WideResNet.swift b/CIFAR/WideResNet.swift
index a58b0529714..0677c0a1c82 100644
--- a/CIFAR/WideResNet.swift
+++ b/CIFAR/WideResNet.swift
@@ -41,7 +41,7 @@ struct BatchNormConv2DBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let firstLayer = conv1(relu(norm1(input)))
         return conv2(relu(norm2(firstLayer)))
     }
@@ -87,7 +87,7 @@ struct WideResNetBasicBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let blocksReduced = blocks.differentiableReduce(input) { last, layer in
             relu(layer(last))
         }
@@ -126,7 +126,7 @@ struct WideResNet: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = input.sequenced(through: l1, l2, l3, l4)
         let finalNorm = relu(norm(inputLayer))
         return finalNorm.sequenced(through: avgPool, flatten, classifier)
diff --git a/Catch/README.md b/Catch/README.md
index 566f7d449e1..136be5630c0 100644
--- a/Catch/README.md
+++ b/Catch/README.md
@@ -23,5 +23,5 @@ installed. Make sure you've added the correct version of `swift` to your path.
 To train the model, run:

 ```
-swift -O catch.swift
+swift -O main.swift
 ```
diff --git a/Catch/main.swift b/Catch/main.swift
index a6192e6df3c..c5d92aa6b1e 100644
--- a/Catch/main.swift
+++ b/Catch/main.swift
@@ -53,7 +53,7 @@ struct Model: Layer {
         generator: &rng)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: layer1, layer2)
     }
 }
@@ -83,7 +83,7 @@ extension CatchAgent {
         let (ŷ, backprop) = model.appliedForBackpropagation(to: x)
         let maxIndex = ŷ.argmax().scalarized()

-        let 𝛁loss = -log(Tensor(ŷ.max())).broadcast(like: ŷ) * previousReward
+        let 𝛁loss = -log(Tensor(ŷ.max())).broadcasted(like: ŷ) * previousReward
         let (𝛁model, _) = backprop(𝛁loss)
         optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
diff --git a/Gym/CartPole/main.swift b/Gym/CartPole/main.swift
index e989fe4266b..fb2224aa721 100644
--- a/Gym/CartPole/main.swift
+++ b/Gym/CartPole/main.swift
@@ -48,7 +48,7 @@ struct Net: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: l1, l2)
     }
 }
diff --git a/MNIST/main.swift b/MNIST/main.swift
index ecf10f0f572..56b82b1d220 100644
--- a/MNIST/main.swift
+++ b/MNIST/main.swift
@@ -64,7 +64,7 @@ struct Classifier: Layer {
     var layer1b = Dense(inputSize: 128, outputSize: 10, activation: softmax)

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let convolved = input.sequenced(through: conv1a, conv1b, pool1)
         return convolved.sequenced(through: dropout1a, flatten, layer1a, dropout1b, layer1b)
     }
diff --git a/MiniGo/Models/GoModel.swift b/MiniGo/Models/GoModel.swift
index e46d759ddff..23941100aed 100644
--- a/MiniGo/Models/GoModel.swift
+++ b/MiniGo/Models/GoModel.swift
@@ -61,7 +61,7 @@ struct ConvBN: Layer {
     }

     @differentiable
-    func call(_ input: Tensor) -> Tensor {
+    func callAsFunction(_ input: Tensor) -> Tensor {
         return norm(conv(input))
     }
 }
@@ -90,7 +90,7 @@ struct ResidualIdentityBlock: Layer {
     }

     @differentiable
-    func call(_ input: Tensor) -> Tensor {
+    func callAsFunction(_ input: Tensor) -> Tensor {
         var tmp = relu(layer1(input))
         tmp = layer2(tmp)
         return relu(tmp + input)
@@ -158,7 +158,7 @@ public struct GoModel: Layer {
     }

     @differentiable(wrt: (self, input), vjp: _vjpCall)
-    public func call(_ input: Tensor) -> GoModelOutput {
+    public func callAsFunction(_ input: Tensor) -> GoModelOutput {
         let batchSize = input.shape[0]

         var output = relu(initialConv(input))
diff --git a/ResNet/Data.swift b/ResNet/Data.swift
index cf101551b79..e94bbd48e39 100644
--- a/ResNet/Data.swift
+++ b/ResNet/Data.swift
@@ -32,9 +32,32 @@ func downloadCIFAR10IfNotPresent(to directory: String = ".") {
     }
 }

+extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
+    public var _tfeTensorHandle: _AnyTensorHandle {
+        TFETensorHandle(_owning: handle._cTensorHandle)
+    }
+}
+
 struct Example: TensorGroup {
     var label: Tensor
     var data: Tensor
+
+    init(label: Tensor, data: Tensor) {
+        self.label = label
+        self.data = data
+    }
+
+    public init(
+        _handles: C
+    ) where C.Element: _AnyTensorHandle {
+        precondition(_handles.count == 2)
+        let labelIndex = _handles.startIndex
+        let dataIndex = _handles.index(labelIndex, offsetBy: 1)
+        label = Tensor(handle: TensorHandle(handle: _handles[labelIndex]))
+        data = Tensor(handle: TensorHandle(handle: _handles[dataIndex]))
+    }
+
+    public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
 }

 // Each CIFAR data file is provided as a Python pickle of NumPy arrays
diff --git a/ResNet/ResNet50.swift b/ResNet/ResNet50.swift
index 487cee5eded..6684677237d 100644
--- a/ResNet/ResNet50.swift
+++ b/ResNet/ResNet50.swift
@@ -37,7 +37,7 @@ struct ConvBN: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: conv, norm)
     }
 }
@@ -65,7 +65,7 @@ struct ResidualBasicBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return layer2(relu(layer1(input)))
     }
 }
@@ -94,7 +94,7 @@ struct ResidualBasicBlockShortcut: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return layer2(relu(layer1(input))) + shortcut(input)
     }
 }
@@ -127,7 +127,7 @@ struct ResidualConvBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let tmp = relu(layer2(relu(layer1(input))))
         return relu(layer3(tmp) + shortcut(input))
     }
@@ -150,7 +150,7 @@ struct ResidualIdentityBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let tmp = relu(layer2(relu(layer1(input))))
         return relu(layer3(tmp) + input)
     }
@@ -175,7 +175,7 @@ struct ResidualIdentityBlockStack: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: block1, block2, block3, block4, block5)
     }
 }
@@ -218,7 +218,7 @@ struct ResNet18: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = maxPool(relu(l1(input)))
         let level2 = inputLayer.sequenced(through: l2a, l2b)
         let level3 = level2.sequenced(through: l3a, l3b)
@@ -274,7 +274,7 @@ struct ResNet34: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = maxPool(relu(l1(input)))
         let level2 = inputLayer.sequenced(through: l2a, l2b, l2c)
         let level3 = level2.sequenced(through: l3a, l3b, l3c, l3d)
@@ -326,7 +326,7 @@ struct ResNet50: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = maxPool(relu(l1(input)))
         let level2 = inputLayer.sequenced(through: l2a, l2b, l2c)
         let level3 = level2.sequenced(through: l3a, l3b, l3c, l3d)
@@ -383,7 +383,7 @@ struct ResNet101: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = maxPool(relu(l1(input)))
         let level2 = inputLayer.sequenced(through: l2a, l2b, l2c)
         let level3 = level2.sequenced(through: l3a, l3b, l3c, l3d)
@@ -441,7 +441,7 @@ struct ResNet152: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = maxPool(relu(l1(input)))
         let level2 = inputLayer.sequenced(through: l2a, l2b, l2c)
         let level3 = level2.sequenced(through: l3a, l3b, l3c, l3d)
diff --git a/ResNet/ResNetV2.swift b/ResNet/ResNetV2.swift
index b8545e8ee48..cb3dbb5c90b 100644
--- a/ResNet/ResNetV2.swift
+++ b/ResNet/ResNetV2.swift
@@ -38,7 +38,7 @@ struct Conv2DBatchNorm: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: conv, norm)
     }
 }
@@ -60,7 +60,7 @@ struct BatchNormConv2D: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return conv(relu(norm(input)))
     }
 }
@@ -88,7 +88,7 @@ struct PreActivatedResidualBasicBlock: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: layer1, layer2)
     }
 }
@@ -117,7 +117,7 @@ struct PreActivatedResidualBasicBlockShortcut: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         return input.sequenced(through: layer1, layer2) + shortcut(input)
     }
 }
@@ -162,7 +162,7 @@ struct PreActivatedResNet18: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = input.sequenced(through: l1, maxPool)
         let level2 = inputLayer.sequenced(through: l2a, l2b)
         let level3 = level2.sequenced(through: l3a, l3b)
@@ -221,7 +221,7 @@ struct PreActivatedResNet34: Layer {
     }

     @differentiable
-    func call(_ input: Input) -> Output {
+    func callAsFunction(_ input: Input) -> Output {
         let inputLayer = input.sequenced(through: l1, maxPool)
         let level2 = inputLayer.sequenced(through: l2a, l2b, l2c)
         let level3 = level2.sequenced(through: l3a, l3b, l3c, l3d)
diff --git a/Transformer/Model.swift b/Transformer/Model.swift
index cad4dbf65f1..3fcf71ccd50 100644
--- a/Transformer/Model.swift
+++ b/Transformer/Model.swift
@@ -23,7 +23,7 @@ struct TimeDistributed: Layer {
     }

     @differentiable(wrt: (self, input))
-    func call(_ input: Tensor) -> Tensor {
+    func callAsFunction(_ input: Tensor) -> Tensor {
         let (batchSize, timeSteps, features) = (input.shape[0], input.shape[1], input.shape[2])
         let reshaped = input.reshaped(to: [batchSize * timeSteps, features])
         let output = dense(reshaped)
@@ -45,7 +45,7 @@ struct FeedForward: Layer {
     }

     @differentiable(wrt: (self, input))
-    func call(_ input: Tensor) -> Tensor {
+    func callAsFunction(_ input: Tensor) -> Tensor {
         return input.sequenced(through: dense1, dropout, dense2)
     }
 }
@@ -117,13 +117,13 @@ struct Attention: Layer {
     }

     @differentiable(wrt: (self, input))
-    func call(_ input: AttentionInput) -> Tensor {
+    func callAsFunction(_ input: AttentionInput) -> Tensor {
         var dotProducts = batchedMatmul(input.query, input.key, adjointRight: true)
         dotProducts = causallyMasked(dotProducts, enable: causal) / scale
         return batchedMatmul(dropout(softmax(dotProducts)), input.value)
batchedMatmul(dropout(softmax(dotProducts)), input.value) } - func call(_ input: AttentionInput, state: inout AttentionContext) -> Tensor { + func callAsFunction(_ input: AttentionInput, state: inout AttentionContext) -> Tensor { state = AttentionContext( key: state.key.concatenated(with: input.key, alongAxis: 1), value: state.value.concatenated(with: input.value, alongAxis: 1)) @@ -192,7 +192,7 @@ struct MultiHeadAttention: Layer { } @differentiable(wrt: (self, input)) - func call(_ input: Tensor) -> Tensor { + func callAsFunction(_ input: Tensor) -> Tensor { let qkvProjected = wqkv(input) let qkvSplit = splitHeads(qkvProjected, headCount: headCount) let attentionInput = splitQKV(qkvSplit) @@ -200,7 +200,7 @@ struct MultiHeadAttention: Layer { return wo(joinHeads(outputs, headCount: headCount)) } - func call(_ input: Tensor, state: inout AttentionContext) -> Tensor { + func callAsFunction(_ input: Tensor, state: inout AttentionContext) -> Tensor { let qkvProjected = wqkv(input) let qkvSplit = splitQKV(qkvProjected) let attentionInput = makeAttentionInput( @@ -234,14 +234,14 @@ struct EncoderLayer: Layer { } @differentiable(wrt: (self, input)) - func call(_ input: Tensor) -> Tensor { + func callAsFunction(_ input: Tensor) -> Tensor { let attended = input + input.sequenced( through: selfAttentionNorm, selfAttention, selfAttentionDropout) return attended + attended.sequenced( through: feedForwardNorm, feedForward, feedForwardDropout) } - func call(_ input: Tensor, state: inout AttentionContext) -> Tensor { + func callAsFunction(_ input: Tensor, state: inout AttentionContext) -> Tensor { var tmp = input tmp = selfAttentionNorm(tmp) tmp = selfAttention(tmp, state: &state) @@ -264,7 +264,7 @@ struct Embedding: Differentiable { } @differentiable(wrt: self) - func call(_ input: Tensor) -> Tensor { + func callAsFunction(_ input: Tensor) -> Tensor { return weight.gathering(atIndices: input) } } @@ -275,14 +275,14 @@ struct TransformerLM { var layers: [EncoderLayer] var norm: LayerNorm - func call(_ tokens: Tensor, states: inout [AttentionContext]) -> Tensor { + func callAsFunction(_ tokens: Tensor, states: inout [AttentionContext]) -> Tensor { let positions = (0..(shape: [1, tokens.shape[1]], scalars: positions) var h = embedding(tokens) h = h + positionalEmbeddings.gathering(atIndices: positionsTensor) for i in 0..( } }) } - -extension Tensor where Scalar: TensorFlowFloatingPoint { - /// Gathers slices of self at the specified indices along the first axis. The result has the - /// same size in the first axis as the scalar count of the index tensor, and the same - /// size in subsequent axes as self. - @differentiable(wrt: self, vjp: _vjpGathering) - func gathering(atIndices indices: Tensor) -> Tensor { - return Raw.gather(params: self, indices: indices) - } - - func _vjpGathering(atIndices indices: Tensor) -> (Tensor, (Tensor) -> Tensor) { - let value = gathering(atIndices: indices) - return (value, { [wShape = shape] seed in - var valuesShape = wShape - valuesShape[0] = indices.scalarCount - let values = seed.reshaped(to: valuesShape) - let indices = indices.reshaped(to: [indices.scalarCount]) - // TODO provide an option for sparse embedding gradients (e.g. equivalent of Python - // IndexedSlices) - return Raw.unsortedSegmentSum( - data: values, - segmentIds: indices, - numSegments: Tensor(Int32(wShape[0]))) - }) - } -}