Skip to content

Commit

Permalink
Merge pull request #460 from swiftwasm/master
Browse files Browse the repository at this point in the history
[pull] swiftwasm from master
  • Loading branch information
pull[bot] authored Mar 21, 2020
2 parents 7248331 + 24445dd commit 5c3bf1c
Show file tree
Hide file tree
Showing 35 changed files with 1,584 additions and 79 deletions.
19 changes: 19 additions & 0 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,25 @@ class ASTMangler : public Mangler {
Type SelfType,
ModuleDecl *Module);

/// Mangle the derivative function (JVP/VJP) for the given:
/// - Mangled original function name.
/// - Derivative function kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string
mangleAutoDiffDerivativeFunctionHelper(StringRef name,
AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config);

/// Mangle the linear map (differential/pullback) for the given:
/// - Mangled original function name.
/// - Linear map kind.
/// - Derivative function configuration: parameter/result indices and
/// derivative generic signature.
std::string mangleAutoDiffLinearMapHelper(StringRef name,
AutoDiffLinearMapKind kind,
AutoDiffConfig config);

/// Mangle a SIL differentiability witness key:
/// - Mangled original function name.
/// - Parameter indices.
Expand Down
76 changes: 76 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "swift/AST/TypeAlignments.h"
#include "swift/Basic/Range.h"
#include "swift/Basic/SourceLoc.h"
#include "llvm/ADT/StringExtras.h"

namespace swift {

Expand Down Expand Up @@ -95,6 +96,45 @@ struct DifferentiabilityWitnessFunctionKind {
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// SIL-level automatic differentiation indices. Consists of:
/// - Parameter indices: indices of parameters to differentiate with respect to.
/// - Result index: index of the result to differentiate from.
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
// `AutoDiffConfig` supports multiple result indices.
struct SILAutoDiffIndices {
/// The index of the dependent result to differentiate from.
unsigned source;
/// The indices for independent parameters to differentiate with respect to.
IndexSubset *parameters;

/*implicit*/ SILAutoDiffIndices(unsigned source, IndexSubset *parameters)
: source(source), parameters(parameters) {}

bool operator==(const SILAutoDiffIndices &other) const;

bool operator!=(const SILAutoDiffIndices &other) const {
return !(*this == other);
};

/// Returns true if `parameterIndex` is a differentiability parameter index.
bool isWrtParameter(unsigned parameterIndex) const {
return parameterIndex < parameters->getCapacity() &&
parameters->contains(parameterIndex);
}

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;

std::string mangle() const {
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
interleave(
parameters->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
}
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
Expand All @@ -110,6 +150,11 @@ struct AutoDiffConfig {
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}

/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
// TODO(TF-913): This is a temporary shim for incremental removal of
// `SILAutoDiffIndices`. Eventually remove this.
SILAutoDiffIndices getSILAutoDiffIndices() const;

void print(llvm::raw_ostream &s = llvm::outs()) const;
SWIFT_DEBUG_DUMP;
};
Expand Down Expand Up @@ -282,6 +327,37 @@ void getFunctionSemanticResultTypes(
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
GenericEnvironment *genericEnv = nullptr);

/// Returns the lowered SIL parameter indices for the given AST parameter
/// indices and `AnyfunctionType`.
///
/// Notable lowering-related changes:
/// - AST tuple parameter types are exploded when lowered to SIL.
/// - AST curried `Self` parameter types become the last parameter when lowered
/// to SIL.
///
/// Examples:
///
/// AST function type: (A, B, C) -> R
/// AST parameter indices: 101, {A, C}
/// Lowered SIL function type: $(A, B, C) -> R
/// Lowered SIL parameter indices: 101
///
/// AST function type: (Self) -> (A, B, C) -> R
/// AST parameter indices: 1010, {Self, B}
/// Lowered SIL function type: $(A, B, C, Self) -> R
/// Lowered SIL parameter indices: 0101
///
/// AST function type: (A, (B, C), D) -> R
/// AST parameter indices: 110, {A, (B, C)}
/// Lowered SIL function type: $(A, B, C, D) -> R
/// Lowered SIL parameter indices: 1110
///
/// Note:
/// - The AST function type must not be curried unless it is a method.
/// Otherwise, the behavior is undefined.
IndexSubset *getLoweredParameterIndices(IndexSubset *astParameterIndices,
AnyFunctionType *functionType);

/// "Constrained" derivative generic signatures require all differentiability
/// parameters to conform to the `Differentiable` protocol.
///
Expand Down
12 changes: 12 additions & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5191,6 +5191,18 @@ class VarDecl : public AbstractStorageDecl {
/// a suitable `init(initialValue:)`.
bool isPropertyMemberwiseInitializedWithWrappedType() const;

/// Whether the innermost property wrapper's initializer's 'wrappedValue' parameter
/// is marked with '@autoclosure' and '@escaping'.
bool isInnermostPropertyWrapperInitUsesEscapingAutoClosure() const;

/// Return the interface type of the value used for the 'wrappedValue:'
/// parameter when initializing a property wrapper.
///
/// If the property has an attached property wrapper and the 'wrappedValue:'
/// parameter is an autoclosure, return a function type returning the stored
/// value. Otherwise, return the interface type of the stored value.
Type getPropertyWrapperInitValueInterfaceType() const;

/// If this property is the backing storage for a property with an attached
/// property wrapper, return the original property.
///
Expand Down
3 changes: 0 additions & 3 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1658,9 +1658,6 @@ ERROR(redundant_class_requirement,none,
"redundant 'class' requirement", ())
ERROR(late_class_requirement,none,
"'class' must come first in the requirement list", ())
ERROR(where_toplevel_nongeneric,none,
"'where' clause cannot be attached to non-generic "
"top-level declaration", ())
ERROR(where_inside_brackets,none,
"'where' clause next to generic parameters is obsolete, "
"must be written following the declaration's type", ())
Expand Down
12 changes: 7 additions & 5 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -1575,7 +1575,7 @@ NOTE(unstable_mangled_name_add_objc,none,
"for compatibility with existing archives, use '@objc' "
"to record the Swift 3 runtime name", ())

// Generic types
// Generic declarations
ERROR(unsupported_type_nested_in_generic_function,none,
"type %0 cannot be nested in generic function %1",
(Identifier, DeclName))
Expand All @@ -1591,6 +1591,12 @@ ERROR(unsupported_type_nested_in_protocol_extension,none,
ERROR(unsupported_nested_protocol,none,
"protocol %0 cannot be nested inside another declaration",
(Identifier))
ERROR(where_nongeneric_ctx,none,
"'where' clause on non-generic member declaration requires a "
"generic context", ())
ERROR(where_nongeneric_toplevel,none,
"'where' clause cannot be applied to a non-generic top-level "
"declaration", ())

// Type aliases
ERROR(type_alias_underlying_type_access,none,
Expand Down Expand Up @@ -2755,10 +2761,6 @@ ERROR(dynamic_self_stored_property_init,none,
ERROR(dynamic_self_default_arg,none,
"covariant 'Self' type cannot be referenced from a default argument expression", ())

ERROR(where_nongeneric_ctx,none,
"'where' clause on non-generic member declaration requires a "
"generic context", ())

//------------------------------------------------------------------------------
// MARK: Type Check Attributes
//------------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/PropertyWrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ struct PropertyWrapperTypeInfo {
HasInitialValueInit
} wrappedValueInit = NoWrappedValueInit;

/// Whether the init(wrappedValue:), if it exists, has the wrappedValue
/// argument as an escaping autoclosure.
bool isWrappedValueInitUsingEscapingAutoClosure = false;

/// The initializer that will be called to default-initialize a
/// value with an attached property wrapper.
enum {
Expand Down
51 changes: 51 additions & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,57 @@ std::string ASTMangler::mangleReabstractionThunkHelper(
return finalize();
}

std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP:
Buffer << "_jvp_";
break;
case AutoDiffDerivativeFunctionKind::VJP:
Buffer << "_vjp_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
}

std::string ASTMangler::mangleAutoDiffLinearMapHelper(
StringRef name, AutoDiffLinearMapKind kind, AutoDiffConfig config) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
beginManglingWithoutPrefix();

Buffer << "AD__" << name << '_';
switch (kind) {
case AutoDiffLinearMapKind::Differential:
Buffer << "_differential_";
break;
case AutoDiffLinearMapKind::Pullback:
Buffer << "_pullback_";
break;
}
Buffer << config.getSILAutoDiffIndices().mangle();
if (config.derivativeGenericSignature) {
Buffer << '_';
appendGenericSignature(config.derivativeGenericSignature);
}

auto result = Storage.str().str();
Storage.clear();
return result;
}

std::string ASTMangler::mangleSILDifferentiabilityWitnessKey(
SILDifferentiabilityWitnessKey key) {
// TODO(TF-20): Make the mangling scheme robust. Support demangling.
Expand Down
8 changes: 6 additions & 2 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,12 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
if (doIt(Inherit))
return true;
}

if (auto *ATD = dyn_cast<AssociatedTypeDecl>(TPD)) {

if (const auto ATD = dyn_cast<AssociatedTypeDecl>(TPD)) {
if (const auto DefaultTy = ATD->getDefaultDefinitionTypeRepr())
if (doIt(DefaultTy))
return true;

if (auto *WhereClause = ATD->getTrailingWhereClause()) {
for (auto &Req: WhereClause->getRequirements()) {
if (doIt(Req))
Expand Down
54 changes: 54 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
}
}

void SILAutoDiffIndices::print(llvm::raw_ostream &s) const {
s << "(source=" << source << " parameters=(";
interleave(
parameters->getIndices(), [&s](unsigned p) { s << p; },
[&s] { s << ' '; });
s << "))";
}

void SILAutoDiffIndices::dump() const {
print(llvm::errs());
llvm::errs() << '\n';
}

SILAutoDiffIndices AutoDiffConfig::getSILAutoDiffIndices() const {
assert(resultIndices->getNumIndices() == 1);
return SILAutoDiffIndices(*resultIndices->begin(), parameterIndices);
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down Expand Up @@ -138,6 +156,42 @@ void autodiff::getFunctionSemanticResultTypes(
}
}

// TODO(TF-874): Simplify this helper. See TF-874 for WIP.
IndexSubset *
autodiff::getLoweredParameterIndices(IndexSubset *parameterIndices,
AnyFunctionType *functionType) {
SmallVector<AnyFunctionType *, 2> curryLevels;
unwrapCurryLevels(functionType, curryLevels);

// Compute the lowered sizes of all AST parameter types.
SmallVector<unsigned, 8> paramLoweredSizes;
unsigned totalLoweredSize = 0;
auto addLoweredParamInfo = [&](Type type) {
unsigned paramLoweredSize = countNumFlattenedElementTypes(type);
paramLoweredSizes.push_back(paramLoweredSize);
totalLoweredSize += paramLoweredSize;
};
for (auto *curryLevel : llvm::reverse(curryLevels))
for (auto &param : curryLevel->getParams())
addLoweredParamInfo(param.getPlainType());

// Build lowered SIL parameter indices by setting the range of bits that
// corresponds to each "set" AST parameter.
llvm::SmallVector<unsigned, 8> loweredSILIndices;
unsigned currentBitIndex = 0;
for (unsigned i : range(parameterIndices->getCapacity())) {
auto paramLoweredSize = paramLoweredSizes[i];
if (parameterIndices->contains(i)) {
auto indices = range(currentBitIndex, currentBitIndex + paramLoweredSize);
loweredSILIndices.append(indices.begin(), indices.end());
}
currentBitIndex += paramLoweredSize;
}

return IndexSubset::get(functionType->getASTContext(), totalLoweredSize,
loweredSILIndices);
}

GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
Expand Down
Loading

0 comments on commit 5c3bf1c

Please sign in to comment.