Skip to content

Commit

Permalink
Merge pull request #527 from swiftwasm/master
Browse files Browse the repository at this point in the history
[pull] swiftwasm from master
  • Loading branch information
pull[bot] authored Mar 29, 2020
2 parents 58242d4 + 80e5a51 commit 2abceea
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 4 deletions.
174 changes: 174 additions & 0 deletions lib/SIL/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,19 @@ namespace {

RetTy visitSILFunctionType(CanSILFunctionType type,
AbstractionPattern origType) {
// Handle `@differentiable` and `@differentiable(linear)` functions.
switch (type->getDifferentiabilityKind()) {
case DifferentiabilityKind::Normal:
return asImpl().visitNormalDifferentiableSILFunctionType(
type, getNormalDifferentiableSILFunctionTypeRecursiveProperties(
type, origType));
case DifferentiabilityKind::Linear:
return asImpl().visitLinearDifferentiableSILFunctionType(
type, getLinearDifferentiableSILFunctionTypeRecursiveProperties(
type, origType));
case DifferentiabilityKind::NonDifferentiable:
break;
}
// Only escaping closures are references.
bool isSwiftEscaping = type->getExtInfo().isNoEscape() &&
type->getExtInfo().getRepresentation() ==
Expand All @@ -250,6 +263,53 @@ namespace {
return asImpl().handleTrivial(type);
}

RecursiveProperties
getNormalDifferentiableSILFunctionTypeRecursiveProperties(
CanSILFunctionType type, AbstractionPattern origType) {
auto &M = TC.M;
auto origTy = type->getWithoutDifferentiability();
// Pass the `AbstractionPattern` generic signature to
// `SILFunctionType:getAutoDiffDerivativeFunctionType` for correct type
// lowering.
auto jvpTy = origTy->getAutoDiffDerivativeFunctionType(
type->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0,
AutoDiffDerivativeFunctionKind::JVP, TC,
LookUpConformanceInModule(&M), CanGenericSignature());
auto vjpTy = origTy->getAutoDiffDerivativeFunctionType(
type->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0,
AutoDiffDerivativeFunctionKind::VJP, TC,
LookUpConformanceInModule(&M), CanGenericSignature());
RecursiveProperties props;
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
props.addSubobject(classifyType(origType, jvpTy, TC, Expansion));
props.addSubobject(classifyType(origType, vjpTy, TC, Expansion));
return props;
}

RecursiveProperties
getLinearDifferentiableSILFunctionTypeRecursiveProperties(
CanSILFunctionType type, AbstractionPattern origType) {
auto &M = TC.M;
auto origTy = type->getWithoutDifferentiability();
auto transposeTy = origTy->getAutoDiffTransposeFunctionType(
type->getDifferentiabilityParameterIndices(), TC,
LookUpConformanceInModule(&M), origType.getGenericSignatureOrNull());
RecursiveProperties props;
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
props.addSubobject(classifyType(origType, transposeTy, TC, Expansion));
return props;
}

RetTy visitNormalDifferentiableSILFunctionType(
CanSILFunctionType type, RecursiveProperties props) {
return handleAggregateByProperties(type, props);
}

RetTy visitLinearDifferentiableSILFunctionType(
CanSILFunctionType type, RecursiveProperties props) {
return handleAggregateByProperties(type, props);
}

RetTy visitLValueType(CanLValueType type,
AbstractionPattern origType) {
llvm_unreachable("shouldn't get an l-value type here");
Expand Down Expand Up @@ -960,6 +1020,106 @@ namespace {
}
};

/// A type lowering for `@differentiable` function types.
class NormalDifferentiableSILFunctionTypeLowering final
: public LoadableAggTypeLowering<
NormalDifferentiableSILFunctionTypeLowering,
NormalDifferentiableFunctionTypeComponent> {
public:
using LoadableAggTypeLowering::LoadableAggTypeLowering;

SILValue emitRValueProject(
SILBuilder &B, SILLocation loc, SILValue tupleValue,
NormalDifferentiableFunctionTypeComponent extractee,
const TypeLowering &eltLowering) const {
return B.createDifferentiableFunctionExtract(
loc, extractee, tupleValue);
}

SILValue rebuildAggregate(SILBuilder &B, SILLocation loc,
ArrayRef<SILValue> values) const override {
assert(values.size() == 3);
auto fnTy = getLoweredType().castTo<SILFunctionType>();
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
return B.createDifferentiableFunction(
loc, paramIndices, values[0], std::make_pair(values[1], values[2]));
}

void lowerChildren(TypeConverter &TC,
SmallVectorImpl<Child> &children) const override {
auto fnTy = getLoweredType().castTo<SILFunctionType>();
auto numDerivativeFns = 2;
children.reserve(numDerivativeFns + 1);
auto origFnTy = fnTy->getWithoutDifferentiability();
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
children.push_back(Child{
NormalDifferentiableFunctionTypeComponent::Original,
TC.getTypeLowering(origFnTy, getExpansionContext())
});
for (AutoDiffDerivativeFunctionKind kind :
{AutoDiffDerivativeFunctionKind::JVP,
AutoDiffDerivativeFunctionKind::VJP}) {
auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
paramIndices, 0, kind, TC,
LookUpConformanceInModule(&TC.M));
auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy);
NormalDifferentiableFunctionTypeComponent extractee(kind);
// Assert that we have the right extractee. A terrible bug in the past
// was caused by implicit conversions from `unsigned` to
// `NormalDifferentiableFunctionTypeComponent` which resulted into a
// wrong extractee.
assert(extractee.getAsDerivativeFunctionKind() == kind);
children.push_back(Child{
extractee, TC.getTypeLowering(silTy, getExpansionContext())});
}
assert(children.size() == 3);
}
};

/// A type lowering for `@differentiable(linear)` function types.
class LinearDifferentiableSILFunctionTypeLowering final
: public LoadableAggTypeLowering<
LinearDifferentiableSILFunctionTypeLowering,
LinearDifferentiableFunctionTypeComponent> {
public:
using LoadableAggTypeLowering::LoadableAggTypeLowering;

SILValue emitRValueProject(
SILBuilder &B, SILLocation loc, SILValue tupleValue,
LinearDifferentiableFunctionTypeComponent component,
const TypeLowering &eltLowering) const {
return B.createLinearFunctionExtract(loc, component, tupleValue);
}

SILValue rebuildAggregate(SILBuilder &B, SILLocation loc,
ArrayRef<SILValue> values) const override {
assert(values.size() == 2);
auto fnTy = getLoweredType().castTo<SILFunctionType>();
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
return B.createLinearFunction(loc, paramIndices, values[0], values[1]);
}

void lowerChildren(TypeConverter &TC,
SmallVectorImpl<Child> &children) const override {
auto fnTy = getLoweredType().castTo<SILFunctionType>();
children.reserve(2);
auto origFnTy = fnTy->getWithoutDifferentiability();
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
children.push_back(Child{
LinearDifferentiableFunctionTypeComponent::Original,
TC.getTypeLowering(origFnTy, getExpansionContext())
});
auto transposeFnTy = origFnTy->getAutoDiffTransposeFunctionType(
paramIndices, TC, LookUpConformanceInModule(&TC.M));
auto transposeSILFnTy = SILType::getPrimitiveObjectType(transposeFnTy);
children.push_back(Child{
LinearDifferentiableFunctionTypeComponent::Transpose,
TC.getTypeLowering(transposeSILFnTy, getExpansionContext())
});
assert(children.size() == 2);
}
};

class LeafLoadableTypeLowering : public NonTrivialLoadableTypeLowering {
public:
LeafLoadableTypeLowering(SILType type, RecursiveProperties properties,
Expand Down Expand Up @@ -1358,6 +1518,20 @@ namespace {
properties);
}

TypeLowering *
visitNormalDifferentiableSILFunctionType(CanSILFunctionType type,
RecursiveProperties props) {
return handleAggregateByProperties
<NormalDifferentiableSILFunctionTypeLowering>(type, props);
}

TypeLowering *
visitLinearDifferentiableSILFunctionType(CanSILFunctionType type,
RecursiveProperties props) {
return handleAggregateByProperties
<LinearDifferentiableSILFunctionTypeLowering>(type, props);
}

template <class LoadableLoweringClass>
TypeLowering *handleAggregateByProperties(CanType type,
RecursiveProperties props) {
Expand Down
1 change: 0 additions & 1 deletion lib/SILGen/SILGenPoly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3905,7 +3905,6 @@ ManagedValue SILGenFunction::getThunkedAutoDiffLinearMap(
return getThunkedResult();
}

// SWIFT_ENABLE_TENSORFLOW
SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
SILFunction *customDerivativeFn, SILFunction *originalFn,
const AutoDiffConfig &config, AutoDiffDerivativeFunctionKind kind) {
Expand Down
12 changes: 9 additions & 3 deletions test/AutoDiff/SIL/Serialization/differentiable_function.swift
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
// RUN: %empty-directory(%t)
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiation -enable-experimental-differentiable-programming
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiation -enable-experimental-differentiable-programming
// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiation -emit-sorted-sil -enable-experimental-differentiable-programming | %FileCheck %s
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -emit-sib -o %t/tmp.sib -module-name main
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp.sib -o %t/tmp.sil -module-name main
// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files.
// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s

// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround.
// REQUIRES: shell

sil_stage raw

import Swift
import _Differentiation

sil @a : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
bb0(%0 : $@differentiable (Float) -> Float):
Expand Down

0 comments on commit 2abceea

Please sign in to comment.