Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/DifferentiableProgramming.md
Original file line number Diff line number Diff line change
Expand Up @@ -1192,7 +1192,14 @@ extension Optional: Differentiable where Wrapped: Differentiable {

@noDerivative
public var zeroTangentVectorInitializer: () -> TangentVector {
{ TangentVector(.zero) }
switch self {
case nil:
return { TangentVector(nil) }
case let x?:
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
TangentVector(zeroTanInit())
}
}
}
}
```
Expand Down
6 changes: 6 additions & 0 deletions stdlib/public/Differentiation/ArrayDifferentiation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ where Element: Differentiable {
base[i].move(along: direction.base[i])
}
}

/// A closure that produces a `TangentVector` of zeros with the same
/// `count` as `self`.
public var zeroTangentVectorInitializer: () -> TangentVector {
return base.zeroTangentVectorInitializer
}
}

extension Array.DifferentiableView: Equatable
Expand Down
1 change: 1 addition & 0 deletions stdlib/public/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE
DifferentiationUtilities.swift
AnyDifferentiable.swift
ArrayDifferentiation.swift
OptionalDifferentiation.swift

GYB_SOURCES
FloatingPointDifferentiation.swift.gyb
Expand Down
83 changes: 83 additions & 0 deletions stdlib/public/Differentiation/OptionalDifferentiation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//===--- OptionalDifferentiation.swift ------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//

import Swift

extension Optional: Differentiable where Wrapped: Differentiable {
public struct TangentVector: Differentiable, AdditiveArithmetic {
public typealias TangentVector = Self

public var value: Wrapped.TangentVector?

public init(_ value: Wrapped.TangentVector?) {
self.value = value
}

public static var zero: Self {
return Self(.zero)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: it seems more efficient to use Optional.TangentVector(nil) as Optional.TangentVector.zero. We can explore the implications when working on differentiation support for Optional.

}

public static func + (lhs: Self, rhs: Self) -> Self {
switch (lhs.value, rhs.value) {
case (nil, nil): return Self(nil)
case let (x?, nil): return Self(x)
case let (nil, y?): return Self(y)
case let (x?, y?): return Self(x + y)
}
}

public static func - (lhs: Self, rhs: Self) -> Self {
switch (lhs.value, rhs.value) {
case (nil, nil): return Self(nil)
case let (x?, nil): return Self(x)
case let (nil, y?): return Self(.zero - y)
case let (x?, y?): return Self(x - y)
}
}

public mutating func move(along direction: TangentVector) {
if let value = direction.value {
self.value?.move(along: value)
}
}

@noDerivative
public var zeroTangentVectorInitializer: () -> TangentVector {
switch value {
case nil:
return { Self(nil) }
case let x?:
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
Self(zeroTanInit())
}
}
}
}

public mutating func move(along direction: TangentVector) {
if let value = direction.value {
self?.move(along: value)
}
}

@noDerivative
public var zeroTangentVectorInitializer: () -> TangentVector {
switch self {
case nil:
return { TangentVector(nil) }
case let x?:
return { [zeroTanInit = x.zeroTangentVectorInitializer] in
TangentVector(zeroTanInit())
}
}
}
}
73 changes: 73 additions & 0 deletions test/AutoDiff/stdlib/optional.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import _Differentiation
import StdlibUnittest

var OptionalDifferentiationTests = TestSuite("OptionalDifferentiation")

OptionalDifferentiationTests.test("Optional operations") {
// Differentiable.move(along:)
do {
var some: Float? = 2
some.move(along: .init(3))
expectEqual(5, some)

var none: Float? = nil
none.move(along: .init(3))
expectEqual(nil, none)
}

// Differentiable.zeroTangentVectorInitializer
do {
let some: [Float]? = [1, 2, 3]
expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer())

let none: [Float]? = nil
expectEqual(.init(nil), none.zeroTangentVectorInitializer())
}
}

OptionalDifferentiationTests.test("Optional.TangentVector operations") {
// Differentiable.move(along:)
do {
var some: Optional<Float>.TangentVector = .init(2)
some.move(along: .init(3))
expectEqual(5, some.value)

var none: Optional<Float>.TangentVector = .init(nil)
none.move(along: .init(3))
expectEqual(nil, none.value)
}

// Differentiable.zeroTangentVectorInitializer
do {
var some: [Float]? = [1, 2, 3]
expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer())

var none: [Float]? = nil
expectEqual(.init(nil), none.zeroTangentVectorInitializer())
}

// AdditiveArithmetic.zero
expectEqual(.init(Float.zero), Float?.TangentVector.zero)
expectEqual(.init([Float].TangentVector.zero), [Float]?.TangentVector.zero)

// AdditiveArithmetic.+, AdditiveArithmetic.-
do {
var some: Optional<Float>.TangentVector = .init(2)
var none: Optional<Float>.TangentVector = .init(nil)

expectEqual(.init(4), some + some)
expectEqual(.init(2), some + none)
expectEqual(.init(2), none + some)
expectEqual(.init(nil), none + none)

expectEqual(.init(0), some - some)
expectEqual(.init(2), some - none)
expectEqual(.init(-2), none - some)
expectEqual(.init(nil), none - none)
}
}

runAllTests()