Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
fb2d647
Merge in work from branch jvp-emitter.
bartchr808 Jul 9, 2019
34fcebe
Init commit.
bartchr808 Jul 10, 2019
571b5e8
WIP.
bartchr808 Jul 10, 2019
1651a78
Make DifferentialEmitter create differential.
bartchr808 Jul 10, 2019
2ce55ec
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 10, 2019
a06d54d
Typos and small bugs.
bartchr808 Jul 10, 2019
0ca031b
Get visitApplyInst and visitReturnInst working.
bartchr808 Jul 10, 2019
652a6d7
Remove multi block logic.
bartchr808 Jul 11, 2019
8fd3e44
WIP
bartchr808 Jul 12, 2019
5aeca34
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 12, 2019
d04b479
Accidentally removed mapper.
bartchr808 Jul 12, 2019
966c470
Move JVP lower and start adding AdjointValue.
bartchr808 Jul 12, 2019
5883668
Create differential builder.
bartchr808 Jul 15, 2019
f27a2bd
Get value mapping for seed params.
bartchr808 Jul 15, 2019
da8b1d6
PullbackInfo -> LinearMapInfo.
bartchr808 Jul 16, 2019
8aade2f
Get correct output, have retain memory leak.
bartchr808 Jul 16, 2019
17776b5
Move around code.
bartchr808 Jul 16, 2019
e313857
Fix release bug.
bartchr808 Jul 16, 2019
0d7c17c
More cleanup.
bartchr808 Jul 16, 2019
e60e478
More cleanup x2.
bartchr808 Jul 16, 2019
95c7321
Add tangent accumulation.
bartchr808 Jul 16, 2019
e20f3a6
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 16, 2019
f20b98a
PR feedback #1.
bartchr808 Jul 16, 2019
3ad5c96
WIP: start adding tests.
bartchr808 Jul 17, 2019
3fb255d
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 17, 2019
0973f3f
Fix SIL tests.
bartchr808 Jul 17, 2019
2ac6316
Remove tangent aggregation logic and cleanup.
bartchr808 Jul 19, 2019
75aa17c
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 19, 2019
bca2df3
Fix comments, add jvpNegate, add/modify tests, simplify lit flag.
bartchr808 Jul 19, 2019
71db0ca
Merge branch 'tensorflow' into differential-emitter
bartchr808 Jul 22, 2019
14eef90
PR feedback.
bartchr808 Jul 24, 2019
64e334d
Update new test due to upstream changes and existing tests now that J…
bartchr808 Jul 24, 2019
8eb5bf5
PR feedback.
bartchr808 Jul 24, 2019
0b1f430
Fix tests and PR feedback.
bartchr808 Jul 24, 2019
5a6cdae
PR feedback.
bartchr808 Jul 24, 2019
c8fac01
[NFC] Fix naming.
dan-zheng Jul 24, 2019
293dca1
PR feedback and additional 'Tracked<Float>' tests.
bartchr808 Jul 25, 2019
4881e7b
[NFC] Small name and spacing changes.
bartchr808 Jul 25, 2019
96a5c66
Merge branch 'tensorflow' into differential-emitter
bartchr808 Aug 5, 2019
6889345
Revamp linear map info struct creation.
bartchr808 Aug 9, 2019
9ad0596
Merge branch 'tensorflow' into differential-emitter
bartchr808 Aug 9, 2019
cbd0fa2
Merge branch 'tensorflow' into differential-emitter
bartchr808 Aug 12, 2019
54b97bd
WIP: throw fatal error in JVPs that aren't defined.
bartchr808 Aug 12, 2019
cbce4d4
Throw fatal error in JVPs that aren't defined.
bartchr808 Aug 12, 2019
6b82eee
Gardening.
bartchr808 Aug 12, 2019
af1f241
Merge branch 'differential-emitter' of https://github.com/bartchr808/…
bartchr808 Aug 19, 2019
e4025d3
Merge branch 'tensorflow' into differential-emitter
bartchr808 Aug 19, 2019
1883ac1
Make changes for ownership change.
bartchr808 Aug 19, 2019
fb68404
Style feedback.
bartchr808 Aug 20, 2019
129a3d5
Add more tests (classes, protocols) and PR feedback.
bartchr808 Aug 20, 2019
adaf91b
PR feedback.
bartchr808 Aug 20, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,460 changes: 1,137 additions & 323 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ public struct Tracked<T> {
}
private var handle: Box

@differentiable(vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
@differentiable(jvp: _jvpInit, vjp: _vjpInit where T : Differentiable, T == T.TangentVector)
public init(_ value: T) {
self.handle = Box(value)
}

@differentiable(vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
@differentiable(jvp: _jvpValue, vjp: _vjpValue where T : Differentiable, T == T.TangentVector)
public var value: T {
get { handle.value }
set { handle.value = newValue }
Expand Down Expand Up @@ -177,10 +177,21 @@ extension Tracked where T : Differentiable, T == T.TangentVector {
return (Tracked(value), { v in v.value })
}

/// JVP of `init(_:)`: the primal wraps `value` in `Tracked`; the differential
/// lifts a tangent of `T` into a `Tracked` tangent by wrapping it the same way.
@usableFromInline
internal static func _jvpInit(_ value: T)
  -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) {
  let primal = Tracked(value)
  return (primal, { Tracked($0) })
}

/// VJP of the `value` getter: unwraps the stored value; the pullback wraps an
/// incoming tangent of `T` back into a `Tracked` tangent.
@usableFromInline
internal func _vjpValue() -> (T, (T.TangentVector) -> Self.TangentVector) {
  return (self.value, { Tracked($0) })
}

/// JVP of the `value` getter: unwraps the stored value; the differential
/// projects a `Tracked` tangent down to its underlying `T` tangent.
@usableFromInline
internal func _jvpValue() -> (T, (Self.TangentVector) -> T.TangentVector) {
  return (self.value, { $0.value })
}
}

extension Tracked where T : Differentiable, T == T.TangentVector {
Expand All @@ -197,6 +208,20 @@ extension Tracked where T : Differentiable, T == T.TangentVector {
-> (value: Self, pullback: (Self) -> (Self, Self)) {
return (lhs - rhs, { v in (v, .zero - v) })
}

@usableFromInline
// NOTE(review): despite the `_vjp` prefix, this returns a `differential` and
// is registered via `@differentiating(+)` — i.e. it is a JVP, not a VJP.
// Consider renaming to `_jvpAdd`; verify no inlinable clients reference the
// `@usableFromInline` symbol by name before renaming.
@differentiating(+)
// Addition is linear, so the differential is the sum of the tangents.
internal static func _vjpAdd(lhs: Self, rhs: Self)
-> (value: Self, differential: (Self, Self) -> (Self)) {
return (lhs + rhs, { (dx, dy) in dx + dy })
}

@usableFromInline
// NOTE(review): misnomer — this is a JVP (returns a `differential`), not a
// VJP. Consider renaming to `_jvpSubtract`; verify no inlinable clients
// reference the `@usableFromInline` symbol by name before renaming.
@differentiating(-)
// Subtraction is linear, so the differential is the difference of tangents.
internal static func _vjpSubtract(lhs: Self, rhs: Self)
-> (value: Self, differential: (Self, Self) -> (Self)) {
return (lhs - rhs, { (dx, dy) in dx - dy })
}
}

extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
Expand All @@ -207,6 +232,13 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude,
-> (value: Self, pullback: (Self) -> (Self, Self)) {
return (lhs * rhs, { v in (v * rhs, v * lhs) })
}

@usableFromInline
// NOTE(review): misnomer — this is a JVP (returns a `differential`), not a
// VJP. Consider renaming to `_jvpMultiply`; verify no inlinable clients
// reference the `@usableFromInline` symbol by name before renaming.
@differentiating(*)
// Product rule: d(lhs * rhs) = dlhs * rhs + drhs * lhs.
internal static func _vjpMultiply(lhs: Self, rhs: Self)
-> (value: Self, differential: (Self, Self) -> (Self)) {
return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs })
}
}

extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector {
Expand All @@ -216,6 +248,13 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector
-> (value: Self, pullback: (Self) -> (Self, Self)) {
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
}

@usableFromInline
// NOTE(review): misnomer — this is a JVP (returns a `differential`), not a
// VJP. Consider renaming to `_jvpDivide`; verify no inlinable clients
// reference the `@usableFromInline` symbol by name before renaming.
@differentiating(/)
// Quotient rule: d(lhs / rhs) = dlhs / rhs - lhs / rhs^2 * drhs.
internal static func _vjpDivide(lhs: Self, rhs: Self)
-> (value: Self, differential: (Self, Self) -> (Self)) {
return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy })
}
}

// Differential operators for `Tracked<Float>`.
Expand Down
17 changes: 14 additions & 3 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -479,15 +479,15 @@ public func pullback<T, U, V, R>(

/// Returns the derivative of `f` at `x` using forward-mode differentiation.
///
/// Computed by applying the differential of `f` at `x` to the unit seed
/// `T(1)`.
///
/// - Parameters:
///   - x: The point at which to evaluate the derivative.
///   - f: A differentiable function of a single floating-point argument.
/// - Returns: The derivative of `f` evaluated at `x`.
@inlinable
public func derivative<T: FloatingPoint, R>(
  at x: T, in f: @differentiable (T) -> R
) -> R.TangentVector
  where T.TangentVector == T {
  // NOTE: the scraped diff fused the old (`@escaping`) and new parameter
  // lines; this keeps only the post-change signature without `@escaping`.
  return differential(at: x, in: f)(T(1))
}

@inlinable
public func derivative<T: FloatingPoint, U: FloatingPoint, R>(
at x: T, _ y: U, in f: @escaping @differentiable (T, U) -> R
at x: T, _ y: U, in f: @differentiable (T, U) -> R
) -> R.TangentVector
where T.TangentVector == T,
U.TangentVector == U {
Expand All @@ -496,7 +496,7 @@ public func derivative<T: FloatingPoint, U: FloatingPoint, R>(

@inlinable
public func derivative<T: FloatingPoint, U: FloatingPoint, V: FloatingPoint, R>(
at x: T, _ y: U, _ z: V, in f: @escaping @differentiable (T, U, V) -> R
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R
) -> R.TangentVector
where T.TangentVector == T,
U.TangentVector == U,
Expand Down Expand Up @@ -995,3 +995,14 @@ public extension Array where Element: Differentiable {
return (value: values, pullback: pullback)
}
}

//===----------------------------------------------------------------------===//
// JVP Diagnostics
//===----------------------------------------------------------------------===//
// Runtime trap with a stable symbol name (`_printJVPErrorAndExit`) —
// presumably referenced by compiler-generated code when a required JVP was
// never emitted (see message below); confirm against the emitter. Never
// returns.
@_silgen_name("_printJVPErrorAndExit")
public func _printJVPErrorAndExit() -> Never {
fatalError("""
JVP does not exist. Differential-first differentiation APIs are \
experimental and should not be used.
""")
}
45 changes: 45 additions & 0 deletions stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,15 @@ extension ${Self} {
-> (value: ${Self}, pullback: (${Self}) -> ${Self}) {
return (-x, { v in -v })
}

// SWIFT_ENABLE_TENSORFLOW
/// JVP of prefix `-`: negation is linear, so the differential negates the
/// tangent.
@usableFromInline
@_transparent
@differentiating(-)
static func _jvpNegate(x: ${Self})
  -> (value: ${Self}, differential: (${Self}) -> ${Self}) {
  return (-x, { -$0 })
}
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1797,6 +1806,15 @@ extension ${Self} {
return (lhs + rhs, { v in (v, v) })
}

/// JVP of `+`: addition is linear, so the differential adds the tangents.
@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(+)
static func _jvpAdd(
  lhs: ${Self}, rhs: ${Self}
) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
  return (lhs + rhs, { $0 + $1 })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(-)
Expand All @@ -1806,6 +1824,15 @@ extension ${Self} {
return (lhs - rhs, { v in (v, -v) })
}

/// JVP of `-`: subtraction is linear, so the differential subtracts the
/// tangents.
@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(-)
static func _jvpSubtract(
  lhs: ${Self}, rhs: ${Self}
) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
  return (lhs - rhs, { $0 - $1 })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(*)
Expand All @@ -1815,6 +1842,15 @@ extension ${Self} {
return (lhs * rhs, { v in (rhs * v, lhs * v) })
}

/// JVP of `*` via the product rule: d(lhs * rhs) = dlhs * rhs + drhs * lhs.
@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(*)
static func _jvpMultiply(
  lhs: ${Self}, rhs: ${Self}
) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
  return (lhs * rhs, { (dl, dr) in dl * rhs + dr * lhs })
}

@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(/)
Expand All @@ -1823,6 +1859,15 @@ extension ${Self} {
) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) {
return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) })
}

/// JVP of `/` via the quotient rule:
/// d(lhs / rhs) = dlhs / rhs - lhs / rhs^2 * drhs.
@inlinable // FIXME(sil-serialize-all)
@_transparent
@differentiating(/)
static func _jvpDivide(
  lhs: ${Self}, rhs: ${Self}
) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) {
  return (lhs / rhs, { (dl, dr) in dl / rhs - lhs / (rhs * rhs) * dr })
}
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 0 additions & 1 deletion test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ func activeInoutArg(_ x: Float) -> Float {
// expected-error @+1 {{function is not differentiable}}
_ = pullback(at: .zero, in: activeInoutArg(_:))


func activeInoutArgTuple(_ x: Float) -> Float {
var tuple = (x, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@ _ = pullback(at: Wrapper(1)) { x in x + x * x }
// CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper
// CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:)
// CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper

95 changes: 95 additions & 0 deletions test/AutoDiff/forward_mode_diagnostics.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// RUN: %target-swift-frontend -Xllvm -run-jvp-generation -emit-sil -verify %s

// Diagnostics tests for forward-mode (JVP/differential) differentiation.
// Comment lines here are `-verify` expectations anchored with `@+N` relative
// line offsets — do not insert lines between an expectation and its target.

// TODO: move these tests back into `autodiff_diagnostics.swift` once
// forward mode reaches feature parity with reverse mode.

//===----------------------------------------------------------------------===//
// Basic function
//===----------------------------------------------------------------------===//

func one_to_one_0(_ x: Float) -> Float {
return x + 2
}

_ = derivative(at: 0, in: one_to_one_0) // okay!

//===----------------------------------------------------------------------===//
// Function composition
//===----------------------------------------------------------------------===//

func base(_ x: Float) -> Float {
// expected-error @+2 2 {{expression is not differentiable}}
// expected-note @+1 2 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
return Float(Int(x))
}

// TODO: Fix nested differentiation diagnostics. Need to fix indirect differentiation invokers.
// NOTE: the `xpected-` (sic) prefixes below deliberately disable these
// expectations until the TODO above is resolved; `-verify` ignores them.
func nested(_ x: Float) -> Float {
// xpected-note @+1 {{when differentiating this function call}}
return base(x)
}

func middle(_ x: Float) -> Float {
// xpected-note @+1 {{when differentiating this function call}}
return nested(x)
}

func middle2(_ x: Float) -> Float {
// xpected-note @+1 {{when differentiating this function call}}
return middle(x)
}

func func_to_diff(_ x: Float) -> Float {
// xpected-note @+1 {{expression is not differentiable}}
return middle2(x)
}

func calls_diff_of_nested(_ x: Float) -> Float {
// xpected-error @+1 {{function is not differentiable}}
return derivative(at: x, in: func_to_diff)
}

//===----------------------------------------------------------------------===//
// Inout arguments
//===----------------------------------------------------------------------===//

func activeInoutArg(_ x: Float) -> Float {
var a = x
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
a += x
return a
}
// expected-error @+1 {{function is not differentiable}}
_ = differential(at: .zero, in: activeInoutArg(_:))

func activeInoutArgTuple(_ x: Float) -> Float {
var tuple = (x, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
tuple.0 *= x
return x * tuple.0
}
// expected-error @+1 {{function is not differentiable}}
_ = differential(at: .zero, in: activeInoutArgTuple(_:))

//===----------------------------------------------------------------------===//
// Non-varied results
//===----------------------------------------------------------------------===//

func one() -> Float {
return 1
}
@differentiable
func nonVariedResult(_ x: Float) -> Float {
// expected-warning @+1 2 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}}
return one()
}

//===----------------------------------------------------------------------===//
// Subset parameters
//===----------------------------------------------------------------------===//

func nondiff(_ f: @differentiable (Float, @nondiff Float) -> Float) -> Float {
// expected-note @+2 {{cannot differentiate with respect to a '@nondiff' parameter}}
// expected-error @+1 {{function is not differentiable}}
return derivative(at: 2, 3) { (x, y) in f(x * x, y) }
}
Loading