diff --git a/docs/DifferentiableFunctions.md b/docs/DifferentiableFunctions.md
index 64d81b96e75..c10c109eb95 100644
--- a/docs/DifferentiableFunctions.md
+++ b/docs/DifferentiableFunctions.md
@@ -259,8 +259,8 @@ extension Layer {
/// gradients at the layer and at the input, respectively.
func appliedForBackpropagation(to input: Input)
-> (output: Output,
- backpropagator: (_ direction: Output.CotangentVector)
- -> (layerGradient: CotangentVector, inputGradient: Input.CotangentVector)) {
+ backpropagator: (_ direction: Output.TangentVector)
+ -> (layerGradient: TangentVector, inputGradient: Input.TangentVector)) {
let (out, pullback) = valueWithPullback(at: input) { layer, input in
return layer(input)
}
@@ -347,13 +347,13 @@ Internally, `differentiableFunction(from:)` is defined just using the
```swift
/// Returns a differentiable function given its derivative.
public func differentiableFunction(
- from vjp: @escaping (T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector)
+ from vjp: @escaping (T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
) -> @differentiable (T) -> R {
func original(_ x: T) -> R {
return vjp(x).value
}
@differentiating(original)
- func derivative(_ x: T) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) {
+ func derivative(_ x: T) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) {
return vjp(x)
}
return original
@@ -431,8 +431,8 @@ defined:
// functions. It simply returns `pullback(1)`.
func gradient(
at x: T, in f: @differentiable (T) -> R
-) -> T.CotangentVector
- where T: Differentiable, R: FloatingPoint & Differentiable, R.CotangentVector == R
+) -> T.TangentVector
+ where T: Differentiable, R: FloatingPoint & Differentiable, R.TangentVector == R
{
let (value, pullback) = valueWithPullback(at: x, in: f)
return pullback(R(1))
diff --git a/docs/DifferentiableTypes.md b/docs/DifferentiableTypes.md
index 01f9c8d39e5..266165a662e 100644
--- a/docs/DifferentiableTypes.md
+++ b/docs/DifferentiableTypes.md
@@ -51,7 +51,7 @@ print(𝛁v)
// Vector(x: 2.0, y: 0.0, z: 0.0)
```
-A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` struct types.
+A `Differentiable`-conforming type may have stored properties that are not meant to have a derivative with respect to `self`. Use the `@noDerivative` attribute to mark those properties; they will not have a corresponding entry in the synthesized `TangentVector` and `AllDifferentiableVariables` struct types.
Here’s an example deep learning layer with some `@noDerivative` properties:
@@ -103,20 +103,12 @@ public protocol Differentiable {
/// The tangent bundle of this differentiable manifold.
associatedtype TangentVector: AdditiveArithmetic & Differentiable
where TangentVector.TangentVector == TangentVector,
- TangentVector.CotangentVector == CotangentVector,
TangentVector.AllDifferentiableVariables == TangentVector
- /// The cotangent bundle of this differentiable manifold.
- associatedtype CotangentVector: AdditiveArithmetic & Differentiable
- where CotangentVector.TangentVector == CotangentVector,
- CotangentVector.CotangentVector == TangentVector,
- CotangentVector.AllDifferentiableVariables == CotangentVector
-
/// The type of all differentiable variables in this type.
associatedtype AllDifferentiableVariables: Differentiable
where AllDifferentiableVariables.AllDifferentiableVariables == AllDifferentiableVariables,
- AllDifferentiableVariables.TangentVector == TangentVector,
- AllDifferentiableVariables.CotangentVector == CotangentVector
+ AllDifferentiableVariables.TangentVector == TangentVector
/// All differentiable variables in this type.
var allDifferentiableVariables: AllDifferentiableVariables { get }
@@ -124,9 +116,6 @@ public protocol Differentiable {
/// Returns `self` moved along the value space towards the given tangent vector.
/// In Riemannian geometry (mathematics), this represents exponential map.
func moved(along direction: TangentVector) -> Self
-
- /// Converts a cotangent vector to its corresponding tangent vector.
- func tangentVector(from cotangent: CotangentVector) -> TangentVector
}
```
@@ -141,20 +130,15 @@ Mathematically, `Differentiable` represents a [differentiable manifold]: this is
Here is a detailed explanation of the `Differentiable` protocol:
-* `associatedtype TangentVector` represents the type of directional derivatives computed via forward-mode differentiation.
-* `associatedtype CotangentVector` represents the type of gradient values computed via reverse-mode differentiation.
- * `CotangentVector` types are used and produced by differential operators like `gradient` and `pullback`.
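+* `associatedtype TangentVector` represents the type of derivatives: both directional derivatives computed via forward-mode differentiation and gradient values computed via reverse-mode differentiation (for example, the values used and produced by `gradient` and `pullback`).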
* `var allDifferentiableVariables: AllDifferentiableVariables` represents all differentiable variables in an instance of the conforming type, where `associatedtype AllDifferentiableVariables` is the type of all differentiable variables.
* The motivation/design behind "all differentiable variables" is enabling key-path-based parameter optimization by making parameters and their gradients have the same type. Read the [synthesis rules](#compiler-synthesized-implementations) below and the [parameter optimization document][parameter-optimization] for more information.
-* `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables` are closely related.
+* `TangentVector` and `AllDifferentiableVariables` are closely related.
- * All three associated types must themselves conform to `Differentiable`.
+ * Both associated types must themselves conform to `Differentiable`.
* The `Differentiable` protocol associated types of the associated types themselves are defined to be mathematically correct.
* `Foo.TangentVector.TangentVector` is `Foo.TangentVector` itself.
- * `Foo.CotangentVector.TangentVector` is `Foo.CotangentVector` itself.
- * `Foo.TangentVector.CotangentVector` is `Foo.CotangentVector`.
- * `Foo.CotangentVector.CotangentVector` is `Foo.TangentVector`.
- * `Foo.AllDifferentiableVariables` has the same `TangentVector` and `CotangentVector` as `Foo`.
- * Additionally, `TangentVector` and `CotangentVector` must conform to `AdditiveArithmetic`, so that they can be zero-initialized and accumulated via addition. These are necessary to perform the chain rule of differentiation.
+ * `Foo.AllDifferentiableVariables` has the same `TangentVector` as `Foo`.
+ * Additionally, `TangentVector` must conform to `AdditiveArithmetic`, so that tangent vectors can be zero-initialized and accumulated via addition. Both operations are necessary to perform the chain rule of differentiation.
* Manifold operations.
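- * These currently involve `tangentVector(from:)` and `moved(along:)`. These operations can be useful for implementing manifold-related algorithms, like optimization on manifolds, but are not relevant for simple differentiation use cases.
+ * This currently involves only `moved(along:)`. This operation can be useful for implementing manifold-related algorithms, like optimization on manifolds, but is not relevant for simple differentiation use cases.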
@@ -163,7 +147,6 @@ The standard library defines conformances to the `Differentiable` protocol for `
```swift
extension Float: Differentiable {
public typealias TangentVector = Float
- public typealias CotangentVector = Float
public typealias AllDifferentiableVariables = Float
}
// Conformances for `Double` and `Float80` are defined similarly.
@@ -171,7 +154,6 @@ extension Float: Differentiable {
// `Tensor` is defined in the TensorFlow library and represents a multidimensional array.
extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
public typealias TangentVector = Tensor
- public typealias CotangentVector = Tensor
public typealias AllDifferentiableVariables = Tensor
}
```
@@ -190,16 +172,16 @@ The synthesis behavior is explained below.
### Associated type synthesis
-Here are the synthesis rules for the three `Differentiable` associated types: `TangentVector`, `CotangentVector`, and `AllDifferentiableVariables`.
+Here are the synthesis rules for the two `Differentiable` associated types: `TangentVector` and `AllDifferentiableVariables`.
Let "differentiation properties" refer to all stored properties of the conforming type that are not marked with `@noDerivative`. These stored properties are guaranteed by the synthesis condition to all conform to `Differentiable`.
The synthesis rules are:
* Set associated types to `Self`, if possible.
- * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`.
-* Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` and `CotangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs.
+ * If the conforming type conforms to `AdditiveArithmetic`, and no `@noDerivative` stored properties exist, and all stored properties satisfy `Self == Self.TangentVector == Self.AllDifferentiableVariables`, then all associated types can be set to typealiases of `Self`.
+* Synthesize a single `AllDifferentiableVariables` member struct. Set `TangentVector` to `AllDifferentiableVariables` if possible; otherwise synthesize more member structs.
* Regarding member struct synthesis: for each "differentiation property" in the conforming type, a corresponding stored property is synthesized in the member structs, with type equal to the property’s associated type.
- * `TangentVector` and `CotangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.CotangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type.
+ * `TangentVector` can be set to `AllDifferentiableVariables` if all differentiation properties conform to `AdditiveArithmetic` and satisfy `Self.TangentVector == Self.AllDifferentiableVariables`. This is useful because it prevents redundant struct synthesis. Also, this enables [key-path-based parameter optimization][parameter-optimization] because parameters and gradients have the same type.
A memberwise initializer is synthesized for the conforming type itself, in addition to all associated structs. This is important for differentiating struct properties accesses and synthesizing manifold operation requirements.
@@ -229,21 +211,13 @@ Manifold operations are synthesized to forward the same operation defined on dif
```swift
// Let `Foo` be the name of the type conforming to `Differentiable`.
-func tangentVector(from cotangent: CotangentVector) -> TangentVector {
- return TangentVector(x: x.tangentVector(from: cotangent.x), ...)
-}
func moved(along tangent: TangentVector) -> Foo {
- return Foo(x: x.moved(along: tangent.x), ...)
+ Foo(x: x.moved(along: tangent.x), ...)
}
-// Potential shortcuts for synthesis:
-// When `TangentVector == CotangentVector`:
-func tangentVector(from cotangent: CotangentVector) -> TangentVector {
- return cotangent
-}
-// When `Foo == TangentVector`:
+// Potential shortcut for synthesis, when `Foo == TangentVector`:
func moved(along tangent: TangentVector) -> Foo {
- return tangent
+ self + tangent
}
```
@@ -266,11 +240,6 @@ struct GenericWrapper: Differentiable {
// var y: U.TangentVector
// ...
// }
- // struct CotangentVector: Differentiable, AdditiveArithmetic {
- // var x: T.CotangentVector
- // var y: U.CotangentVector
- // ...
- // }
// struct AllDifferentiableVariables: Differentiable {
// var x: T.AllDifferentiableVariables
// var y: U.AllDifferentiableVariables
@@ -280,10 +249,6 @@ struct GenericWrapper: Differentiable {
// get { return AllDifferentiableVariables(x: x, y: y) }
// set { x = newValue.x; y = newValue.y }
// }
- // func tangentVector(from cotangent: CotangentVector) -> TangentVector {
- // return TangentVector(x: x.tangentVector(from: cotangent.x),
- // y: y.tangentVector(from: cotangent.y))
- // }
// func moved(along tangent: TangentVector) -> Foo {
// return GenericWrapper(x: x.moved(along: tangent.x)
// y: y.moved(along: tangent.y))