Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ IDENTIFIER(zero)
IDENTIFIER(Scalar)
// Differentiable
IDENTIFIER(AllDifferentiableVariables)
IDENTIFIER(CotangentVector)
IDENTIFIER(TangentVector)
IDENTIFIER(Derivative)
IDENTIFIER(Gradient)
IDENTIFIER(allDifferentiableVariables)
IDENTIFIER(moved)
IDENTIFIER(tangentVector)
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4157,10 +4157,10 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
Identifier associatedTypeIdentifier;
switch (kind) {
case AutoDiffAssociatedVectorSpaceKind::Tangent:
associatedTypeIdentifier = ctx.Id_TangentVector;
associatedTypeIdentifier = ctx.Id_Derivative;
break;
case AutoDiffAssociatedVectorSpaceKind::Cotangent:
associatedTypeIdentifier = ctx.Id_CotangentVector;
associatedTypeIdentifier = ctx.Id_Gradient;
break;
}
auto associatedTypeLookup =
Expand Down
54 changes: 27 additions & 27 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
// Nominal type must not customize `TangentVector`, `CotangentVector`, or
// `AllDifferentiableVariables` to anything other than `Self`.
// Otherwise, synthesis is semantically unsupported.
auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector);
auto cotangentDecls = nominal->lookupDirect(C.Id_CotangentVector);
auto tangentDecls = nominal->lookupDirect(C.Id_Derivative);
auto cotangentDecls = nominal->lookupDirect(C.Id_Gradient);
auto allDiffableVarsDecls =
nominal->lookupDirect(C.Id_AllDifferentiableVariables);
auto nominalTypeInContext =
Expand Down Expand Up @@ -435,7 +435,7 @@ static ValueDecl *deriveDifferentiable_moved(DerivedConformance &derived) {
auto parentDC = derived.getConformanceContext();
auto selfInterfaceType = parentDC->getDeclaredInterfaceType();

auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_TangentVector);
auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_Derivative);
auto tangentType = tangentDecl->getDeclaredInterfaceType();

return deriveDifferentiable_method(
Expand All @@ -450,10 +450,10 @@ deriveDifferentiable_tangentVector(DerivedConformance &derived) {
auto parentDC = derived.getConformanceContext();
auto &C = derived.TC.Context;

auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_TangentVector);
auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_Derivative);
auto tangentType = tangentDecl->getDeclaredInterfaceType();

auto *cotangentDecl = getAssociatedStructDecl(parentDC, C.Id_CotangentVector);
auto *cotangentDecl = getAssociatedStructDecl(parentDC, C.Id_Gradient);
auto cotangentType = cotangentDecl->getDeclaredInterfaceType();

return deriveDifferentiable_method(
Expand Down Expand Up @@ -657,7 +657,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
auto nominal = derived.Nominal;
auto &C = nominal->getASTContext();

assert(id == C.Id_TangentVector || id == C.Id_CotangentVector ||
assert(id == C.Id_Derivative || id == C.Id_Gradient ||
id == C.Id_AllDifferentiableVariables);

// If the associated struct already exists, return it.
Expand Down Expand Up @@ -689,7 +689,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,

// If the associated type is `TangentVector` or `CotangentVector`, make it
// also conform to `AdditiveArithmetic`.
if (id == C.Id_TangentVector || id == C.Id_CotangentVector)
if (id == C.Id_Derivative || id == C.Id_Gradient)
inherited.push_back(addArithType);

// Associated struct can derive `AdditiveArithmetic` if the associated types
Expand Down Expand Up @@ -919,14 +919,14 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived,
bool freshlySynthesized = allDiffableVarsStructSynthesis.second;

auto tangentStructSynthesis =
getOrSynthesizeSingleAssociatedStruct(derived, C.Id_TangentVector);
getOrSynthesizeSingleAssociatedStruct(derived, C.Id_Derivative);
auto *tangentStruct = tangentStructSynthesis.first;
if (!tangentStruct)
return nullptr;
freshlySynthesized |= tangentStructSynthesis.second;

auto cotangentStructSynthesis =
getOrSynthesizeSingleAssociatedStruct(derived, C.Id_CotangentVector);
getOrSynthesizeSingleAssociatedStruct(derived, C.Id_Gradient);
auto *cotangentStruct = cotangentStructSynthesis.first;
if (!cotangentStruct)
return nullptr;
Expand All @@ -940,18 +940,18 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived,
checkAndDiagnoseImplicitNoDerivative(TC, nominal, parentDC);

// Add associated typealiases for structs.
addAssociatedTypeAliasDecl(C.Id_TangentVector,
addAssociatedTypeAliasDecl(C.Id_Derivative,
tangentStruct, tangentStruct, TC);
addAssociatedTypeAliasDecl(C.Id_TangentVector,
addAssociatedTypeAliasDecl(C.Id_Derivative,
cotangentStruct, cotangentStruct, TC);
addAssociatedTypeAliasDecl(C.Id_TangentVector,
addAssociatedTypeAliasDecl(C.Id_Derivative,
allDiffableVarsStruct, tangentStruct, TC);

addAssociatedTypeAliasDecl(C.Id_CotangentVector,
addAssociatedTypeAliasDecl(C.Id_Gradient,
tangentStruct, cotangentStruct, TC);
addAssociatedTypeAliasDecl(C.Id_CotangentVector,
addAssociatedTypeAliasDecl(C.Id_Gradient,
cotangentStruct, tangentStruct, TC);
addAssociatedTypeAliasDecl(C.Id_CotangentVector,
addAssociatedTypeAliasDecl(C.Id_Gradient,
allDiffableVarsStruct, cotangentStruct, TC);

addAssociatedTypeAliasDecl(C.Id_AllDifferentiableVariables,
Expand Down Expand Up @@ -983,9 +983,9 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived,

// Return the requested associated struct type.
StructDecl *requestedStructDecl = nullptr;
if (id == C.Id_TangentVector)
if (id == C.Id_Derivative)
requestedStructDecl = tangentStruct;
else if (id == C.Id_CotangentVector)
else if (id == C.Id_Gradient)
requestedStructDecl = cotangentStruct;
else if (id == C.Id_AllDifferentiableVariables)
requestedStructDecl = allDiffableVarsStruct;
Expand Down Expand Up @@ -1074,9 +1074,9 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
bool allMembersAssocTypesEqualsSelf =
llvm::all_of(diffProperties, [&](VarDecl *member) {
auto tangentType =
getAssociatedType(member, parentDC, C.Id_TangentVector);
getAssociatedType(member, parentDC, C.Id_Derivative);
auto cotangentType =
getAssociatedType(member, parentDC, C.Id_CotangentVector);
getAssociatedType(member, parentDC, C.Id_Gradient);
auto allDiffableVarsType =
getAssociatedType(member, parentDC, C.Id_AllDifferentiableVariables);
return tangentType->isEqual(cotangentType) &&
Expand Down Expand Up @@ -1106,13 +1106,13 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived,
checkAndDiagnoseImplicitNoDerivative(TC, nominal, parentDC);
addAssociatedTypeAliasDecl(C.Id_AllDifferentiableVariables,
allDiffableVarsStruct, allDiffableVarsStruct, TC);
addAssociatedTypeAliasDecl(C.Id_TangentVector,
addAssociatedTypeAliasDecl(C.Id_Derivative,
allDiffableVarsStruct, allDiffableVarsStruct, TC);
addAssociatedTypeAliasDecl(C.Id_CotangentVector,
addAssociatedTypeAliasDecl(C.Id_Gradient,
allDiffableVarsStruct, allDiffableVarsStruct, TC);
addAssociatedTypeAliasDecl(C.Id_TangentVector,
addAssociatedTypeAliasDecl(C.Id_Derivative,
parentDC, allDiffableVarsStruct, TC);
addAssociatedTypeAliasDecl(C.Id_CotangentVector,
addAssociatedTypeAliasDecl(C.Id_Gradient,
parentDC, allDiffableVarsStruct, TC);
TC.validateDecl(allDiffableVarsStruct);
return parentDC->mapTypeIntoContext(
Expand All @@ -1135,12 +1135,12 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
}

Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
if (requirement->getBaseName() == TC.Context.Id_TangentVector)
if (requirement->getBaseName() == TC.Context.Id_Derivative)
return deriveDifferentiable_AssociatedStruct(
*this, TC.Context.Id_TangentVector);
if (requirement->getBaseName() == TC.Context.Id_CotangentVector)
*this, TC.Context.Id_Derivative);
if (requirement->getBaseName() == TC.Context.Id_Gradient)
return deriveDifferentiable_AssociatedStruct(
*this, TC.Context.Id_CotangentVector);
*this, TC.Context.Id_Gradient);
if (requirement->getBaseName() == TC.Context.Id_AllDifferentiableVariables)
return deriveDifferentiable_AssociatedStruct(
*this, TC.Context.Id_AllDifferentiableVariables);
Expand Down
8 changes: 4 additions & 4 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
return getRequirement(KnownProtocolKind::KeyPathIterable);

// SWIFT_ENABLE_TENSORFLOW
// Differentiable.TangentVector
// Differentiable.CotangentVector
// Differentiable.Derivative
// Differentiable.Gradient
// Differentiable.AllDifferentiableVariables
if (name.isSimpleName(ctx.Id_TangentVector) ||
name.isSimpleName(ctx.Id_CotangentVector) ||
if (name.isSimpleName(ctx.Id_Derivative) ||
name.isSimpleName(ctx.Id_Gradient) ||
name.isSimpleName(ctx.Id_AllDifferentiableVariables))
return getRequirement(KnownProtocolKind::__Differentiable);

Expand Down
4 changes: 2 additions & 2 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3035,10 +3035,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
}
if (funcResultElt.getName().str() == "differential") {
kind = AutoDiffAssociatedFunctionKind::JVP;
autoDiffAssocTyId = ctx.Id_TangentVector;
autoDiffAssocTyId = ctx.Id_Derivative;
} else if (funcResultElt.getName().str() == "pullback") {
kind = AutoDiffAssociatedFunctionKind::VJP;
autoDiffAssocTyId = ctx.Id_CotangentVector;
autoDiffAssocTyId = ctx.Id_Derivative;
} else {
TC.diagnose(attr->getLocation(),
diag::differentiating_attr_expected_result_tuple_func_label);
Expand Down
32 changes: 16 additions & 16 deletions stdlib/public/TensorFlow/Gradients.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,29 @@ public extension Differentiable {
@inlinable
func gradient<R : TensorFlowFloatingPoint>(
in f: @differentiable (Self) -> Tensor<R>
) -> CotangentVector {
) -> Gradient {
return self.pullback(in: f)(Tensor<R>(1))
}

@inlinable
func valueWithGradient<R : TensorFlowFloatingPoint>(
in f: @differentiable (Self) -> Tensor<R>
) -> (value: Tensor<R>, gradient: CotangentVector) {
) -> (value: Tensor<R>, gradient: Gradient) {
let (y, pb) = self.valueWithPullback(in: f)
return (y, pb(Tensor<R>(1)))
}

@inlinable
func gradient<T : Differentiable, R : TensorFlowFloatingPoint>(
at x: T, in f: @differentiable (Self, T) -> Tensor<R>
) -> (CotangentVector, T.CotangentVector) {
) -> (Gradient, T.Gradient) {
return self.pullback(at: x, in: f)(Tensor<R>(1))
}

@inlinable
func valueWithGradient<T : Differentiable, R : TensorFlowFloatingPoint>(
at x: T, in f: @differentiable (Self, T) -> Tensor<R>
) -> (value: Tensor<R>, gradient: (CotangentVector, T.CotangentVector)) {
) -> (value: Tensor<R>, gradient: (Gradient, T.Gradient)) {
let (y, pb) = self.valueWithPullback(at: x, in: f)
return (y, pb(Tensor<R>(1)))
}
Expand All @@ -85,7 +85,7 @@ public extension Differentiable {
@inlinable
public func valueWithGradient<T, R>(
at x: T, in f: @differentiable (T) -> Tensor<R>
) -> (value: Tensor<R>, gradient: T.CotangentVector)
) -> (value: Tensor<R>, gradient: T.Gradient)
where T : Differentiable, R : TensorFlowFloatingPoint {
let (y, pullback) = valueWithPullback(at: x, in: f)
return (y, pullback(Tensor<R>(1)))
Expand All @@ -94,7 +94,7 @@ where T : Differentiable, R : TensorFlowFloatingPoint {
@inlinable
public func valueWithGradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> Tensor<R>
) -> (value: Tensor<R>, gradient: (T.CotangentVector, U.CotangentVector))
) -> (value: Tensor<R>, gradient: (T.Gradient, U.Gradient))
where T : Differentiable, U : Differentiable,
R : TensorFlowFloatingPoint {
let (y, pullback) = valueWithPullback(at: x, y, in: f)
Expand All @@ -105,7 +105,7 @@ public func valueWithGradient<T, U, R>(
public func valueWithGradient<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> Tensor<R>
) -> (value: Tensor<R>,
gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector))
gradient: (T.Gradient, U.Gradient, V.Gradient))
where T : Differentiable, U : Differentiable, V : Differentiable,
R : TensorFlowFloatingPoint {
let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
Expand All @@ -117,7 +117,7 @@ public func valueWithGradient<T, U, V, R>(
@inlinable
public func valueWithGradient<T, R>(
of f: @escaping @differentiable (T) -> Tensor<R>
) -> (T) -> (value: Tensor<R>, gradient: T.CotangentVector)
) -> (T) -> (value: Tensor<R>, gradient: T.Gradient)
where T : Differentiable, R : TensorFlowFloatingPoint {
return { x in valueWithGradient(at: x, in: f) }
}
Expand All @@ -126,7 +126,7 @@ public func valueWithGradient<T, R>(
public func valueWithGradient<T, U, R>(
of f: @escaping @differentiable (T, U) -> Tensor<R>
) -> (T, U)
-> (value: Tensor<R>, gradient: (T.CotangentVector, U.CotangentVector))
-> (value: Tensor<R>, gradient: (T.Gradient, U.Gradient))
where T : Differentiable, U : Differentiable,
R : TensorFlowFloatingPoint {
return { x, y in valueWithGradient(at: x, y, in: f) }
Expand All @@ -137,7 +137,7 @@ public func valueWithGradient<T, U, V, R>(
of f: @escaping @differentiable (T, U, V) -> Tensor<R>
) -> (T, U, V)
-> (value: Tensor<R>,
gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector))
gradient: (T.Gradient, U.Gradient, V.Gradient))
where T : Differentiable, U : Differentiable, V : Differentiable,
R : TensorFlowFloatingPoint {
return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
Expand All @@ -148,15 +148,15 @@ public func valueWithGradient<T, U, V, R>(
@inlinable
public func gradient<T, R>(
at x: T, in f: @differentiable (T) -> Tensor<R>
) -> T.CotangentVector
) -> T.Gradient
where T : Differentiable, R : TensorFlowFloatingPoint {
return pullback(at: x, in: f)(Tensor<R>(1))
}

@inlinable
public func gradient<T, U, R>(
at x: T, _ y: U, in f: @differentiable (T, U) -> Tensor<R>
) -> (T.CotangentVector, U.CotangentVector)
) -> (T.Gradient, U.Gradient)
where T : Differentiable, U : Differentiable,
R : TensorFlowFloatingPoint {
return pullback(at: x, y, in: f)(Tensor<R>(1))
Expand All @@ -165,7 +165,7 @@ public func gradient<T, U, R>(
@inlinable
public func gradient<T, U, V, R>(
at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> Tensor<R>
) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector)
) -> (T.Gradient, U.Gradient, V.Gradient)
where T : Differentiable, U : Differentiable, V : Differentiable,
R : TensorFlowFloatingPoint {
return pullback(at: x, y, z, in: f)(Tensor<R>(1))
Expand All @@ -176,15 +176,15 @@ public func gradient<T, U, V, R>(
@inlinable
public func gradient<T, R>(
of f: @escaping @differentiable (T) -> Tensor<R>
) -> (T) -> T.CotangentVector
) -> (T) -> T.Gradient
where T : Differentiable, R : TensorFlowFloatingPoint {
return { x in gradient(at: x, in: f) }
}

@inlinable
public func gradient<T, U, R>(
of f: @escaping @differentiable (T, U) -> Tensor<R>
) -> (T, U) -> (T.CotangentVector, U.CotangentVector)
) -> (T, U) -> (T.Gradient, U.Gradient)
where T : Differentiable, U : Differentiable,
R : TensorFlowFloatingPoint {
return { x, y in gradient(at: x, y, in: f) }
Expand All @@ -193,7 +193,7 @@ public func gradient<T, U, R>(
@inlinable
public func gradient<T, U, V, R>(
of f: @escaping @differentiable (T, U, V) -> Tensor<R>
) -> (T, U, V) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector)
) -> (T, U, V) -> (T.Gradient, U.Gradient, V.Gradient)
where T : Differentiable, U : Differentiable, V : Differentiable,
R : TensorFlowFloatingPoint {
return { x, y, z in gradient(at: x, y, z, in: f) }
Expand Down
6 changes: 3 additions & 3 deletions stdlib/public/TensorFlow/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ extension Tensor : VectorNumeric where Scalar : Numeric {
extension Tensor : ShapedVectorNumeric where Scalar : Numeric {}

extension Tensor : Differentiable where Scalar : TensorFlowFloatingPoint {
public typealias TangentVector = Tensor
public typealias CotangentVector = Tensor
public typealias Derivative = Tensor
public typealias Gradient = Tensor
public typealias AllDifferentiableVariables = Tensor
@inlinable @inline(__always)
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
public func tangentVector(from cotangent: Gradient) -> Derivative {
return cotangent
}
}
Expand Down
Loading