diff --git a/docs/DifferentiableFunctions.md b/docs/DifferentiableFunctions.md index 64d81b96e75..c10c109eb95 100644 --- a/docs/DifferentiableFunctions.md +++ b/docs/DifferentiableFunctions.md @@ -259,8 +259,8 @@ extension Layer { /// gradients at the layer and at the input, respectively. func appliedForBackpropagation(to input: Input) -> (output: Output, - backpropagator: (_ direction: Output.CotangentVector) - -> (layerGradient: CotangentVector, inputGradient: Input.CotangentVector)) { + backpropagator: (_ direction: Output.TangentVector) + -> (layerGradient: TangentVector, inputGradient: Input.TangentVector)) { let (out, pullback) = valueWithPullback(at: input) { layer, input in return layer(input) } @@ -347,13 +347,13 @@ Internally, `differentiableFunction(from:)` is defined just using the ```swift /// Returns a differentiable function given its derivative. public func differentiableFunction( - from vjp: @escaping (T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) + from vjp: @escaping (T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) ) -> @differentiable (T) -> R { func original(_ x: T) -> R { return vjp(x).value } @differentiating(original) - func derivative(_ x: T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) { + func derivative(_ x: T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) { return vjp(x) } return original @@ -431,8 +431,8 @@ defined: // functions. It simply returns `pullback(1)`. func gradient( at x: T, in f: @differentiable (T) -> R -) -> T.CotangentVector - where T: Differentiable, R: FloatingPoint & Differentiable, R.CotangentVector == R +) -> T.TangentVector + where T: Differentiable, R: FloatingPoint & Differentiable, R.TangentVector == R { let (value, pullback) = valueWithPullback(at: x, in: f) return pullback(R(1)) diff --git a/docs/DifferentiableTypes.md b/docs/DifferentiableTypes.md index 01f9c8d39e5..266165a662e 100644 --- a/docs/DifferentiableTypes.md +++ b/docs/DifferentiableTypes.md @@ -51,7 +51,7 @@ print(𝛁v) // Vector(x: 2.0, y: 0.0, z: 0.0) ``` -A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` struct types. +A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector` and `AllDifferentiableVariables` struct types. Here’s an example deep learning layer with some `@noDerivative` properties: @@ -103,20 +103,12 @@ public protocol Differentiable { /// The tangent bundle of this differentiable manifold. associatedtype TangentVector: AdditiveArithmetic & Differentiable where TangentVector.TangentVector == TangentVector, - TangentVector.CotangentVector == CotangentVector, TangentVector.AllDifferentiableVariables == TangentVector - /// The cotangent bundle of this differentiable manifold. - associatedtype CotangentVector: AdditiveArithmetic & Differentiable - where CotangentVector.TangentVector == CotangentVector, - CotangentVector.CotangentVector == TangentVector, - CotangentVector.AllDifferentiableVariables == CotangentVector - /// The type of all differentiable variables in this type. associatedtype AllDifferentiableVariables: Differentiable where AllDifferentiableVariables.AllDifferentiableVariables == AllDifferentiableVariables, AllDifferentiableVariables.TangentVector == TangentVector, - AllDifferentiableVariables.CotangentVector == CotangentVector /// All differentiable variables in this type. var allDifferentiableVariables: AllDifferentiableVariables { get } @@ -124,9 +116,6 @@ public protocol Differentiable { /// Returns `self` moved along the value space towards the given tangent vector. /// In Riemannian geometry (mathematics), this represents exponential map. func moved(along direction: TangentVector) -> Self - - /// Converts a cotangent vector to its corresponding tangent vector. - func tangentVector(from cotangent: CotangentVector) -> TangentVector } ``` @@ -141,20 +130,15 @@ Mathematically, `Differentiable` represents a [differentiable manifold]: this is

Here is a detailed explanation of the `Differentiable` protocol: -* `associatedtype TangentVector` represents the type of directional derivatives computed via forward-mode differentiation. -* `associatedtype CotangentVector` represents the type of gradient values computed via reverse-mode differentiation. - * `CotangentVector` types are used and produced by differential operators like `gradient` and `pullback`. +* `associatedtype TangentVector` represents the type of derivatives. * `var allDifferentiableVariables: AllDifferentiableVariables` represents all differentiable variables in an instance of the conforming type, where `associatedtype AllDifferentiableVariables` is the type of all differentiable variables. * The motivation/design behind "all differentiable variables" is enabling key-path-based parameter optimization by making parameters and their gradients have the same type. Read the [synthesis rules](#compiler-synthesized-implementations) below and the [parameter optimization document][parameter-optimization] for more information. -* `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` are closely related. +* `TangentVector` and `AllDifferentiableVariables` are closely related. * All three associated types must themselves conform to `Differentiable`. * The `Differentiable` protocol associated types of the associated types themselves are defined to be mathematically correct. * `Foo.TangentVector.TangentVector` is `Foo.TangentVector` itself. - * `Foo.CotangentVector.TangentVector` is `Foo.CotangentVector` itself. - * `Foo.TangentVector.CotangentVector` is `Foo.CotangentVector`. - * `Foo.CotangentVector.CotangentVector` is `Foo.TangentVector`. - * `Foo.AllDifferentiableVariables` has the same `TangentVector` and `CotangentVector` as `Foo`. - * Additionally, `TangentVector` and `CotangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. These are necessary to perform the chain rule of differentiation. + * `Foo.AllDifferentiableVariables` has the same `TangentVector` as `Foo`. + * Additionally, `TangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. These are necessary to perform the chain rule of differentiation. * Manifold operations. * These currently involve `tangentVector(from:)` and `moved(along:)`. These operations can be useful for implementing manifold-related algorithms, like optimization on manifolds, but are not relevant for simple differentiation use cases. @@ -163,7 +147,6 @@ The standard library defines conformances to the `Differentiable` protocol for ` ```swift extension Float: Differentiable { public typealias TangentVector = Float - public typealias CotangentVector = Float public typealias AllDifferentiableVariables = Float } // Conformances for `Double` and `Float80` are defined similarly. @@ -171,7 +154,6 @@ extension Float: Differentiable { // `Tensor` is defined in the TensorFlow library and represents a multidimensional array. extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint { public typealias TangentVector = Tensor - public typealias CotangentVector = Tensor public typealias AllDifferentiableVariables = Tensor } ``` @@ -190,16 +172,16 @@ The synthesis behavior is explained below. ### Associated type synthesis -Here are the synthesis rules for the three `Differentiable` associated types: `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables`. +Here are the synthesis rules for the two `Differentiable` associated types: `TangentVector` and `AllDifferentiableVariables`. Let "differentiation properties" refer to all stored properties of the conforming type that are not marked with `@noDerivative`. These stored properties are guaranteed by the synthesis condition to all conform to `Differentiable`. The synthesis rules are: * Set associated types to `Self`, if possible. - * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`. -* Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` and `CotangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs. + * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`. +* Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs. * Regarding member struct synthesis: for each "differentiation property" in the conforming type, a corresponding stored property is synthesized in the member structs, with type equal to the property’s associated type. - * `TangentVector` and `CotangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type. + * `TangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type. A memberwise initializer is synthesized for the conforming type itself, in addition to all associated structs. This is important for differentiating struct properties accesses and synthesizing manifold operation requirements. @@ -229,21 +211,13 @@ Manifold operations are synthesized to forward the same operation defined on dif ```swift // Let `Foo` be the name of the type conforming to `Differentiable`. -func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return TangentVector(x: x.tangentVector(from: cotangent.x), ...) -} func moved(along tangent: TangentVector) -> Foo { - return Foo(x: x.moved(along: tangent.x), ...) + Foo(x: x.moved(along: tangent.x), ...) } -// Potential shortcuts for synthesis: -// When `TangentVector == CotangentVector`: -func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return cotangent -} -// When `Foo == TangentVector`: +// Potential shortcut for synthesis, when `Foo == TangentVector`: func moved(along tangent: TangentVector) -> Foo { - return tangent + self + tangent } ``` @@ -266,11 +240,6 @@ struct GenericWrapper: Differentiable { // var y: U.TangentVector // ... // } - // struct CotangentVector: Differentiable, AdditiveArithmetic { - // var x: T.CotangentVector - // var y: U.CotangentVector - // ... - // } // struct AllDifferentiableVariables: Differentiable { // var x: T.AllDifferentiableVariables // var y: U.AllDifferentiableVariables @@ -280,10 +249,6 @@ struct GenericWrapper: Differentiable { // get { return AllDifferentiableVariables(x: x, y: y) } // set { x = newValue.x; y = newValue.y } // } - // func tangentVector(from cotangent: CotangentVector) -> TangentVector { - // return TangentVector(x: x.tangentVector(from: cotangent.x), - // y: y.tangentVector(from: cotangent.y)) - // } // func moved(along tangent: TangentVector) -> Foo { // return GenericWrapper(x: x.moved(along: tangent.x) // y: y.moved(along: tangent.y))