diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 4f40cdc2b04bc..e9b72bb7c55b6 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -52,6 +52,7 @@ TYPE_ATTR(convention) TYPE_ATTR(noescape) TYPE_ATTR(escaping) TYPE_ATTR(differentiable) +TYPE_ATTR(noDerivative) // SIL-specific attributes TYPE_ATTR(block_storage) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 36b0e01566ac8..829b4821b5a97 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -3926,6 +3926,11 @@ ERROR(opaque_type_in_protocol_requirement,none, "'some' type cannot be the return type of a protocol requirement; did you mean to add an associated type?", ()) +// Function differentiability +ERROR(attr_only_on_parameters_of_differentiable,none, + "'%0' may only be used on parameters of '@differentiable' function " + "types", (StringRef)) + // SIL ERROR(opened_non_protocol,none, "@opened cannot be applied to non-protocol type %0", (Type)) diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 8943ea372446b..c03a523d0ae40 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -1808,8 +1808,8 @@ class ParameterTypeFlags { NonEphemeral = 1 << 2, OwnershipShift = 3, Ownership = 7 << OwnershipShift, - - NumBits = 6 + NoDerivative = 1 << 7, + NumBits = 7 }; OptionSet value; static_assert(NumBits < 8*sizeof(OptionSet), "overflowed"); @@ -1823,15 +1823,17 @@ class ParameterTypeFlags { } ParameterTypeFlags(bool variadic, bool autoclosure, bool nonEphemeral, - ValueOwnership ownership) + ValueOwnership ownership, bool noDerivative) : value((variadic ? Variadic : 0) | (autoclosure ? AutoClosure : 0) | (nonEphemeral ? NonEphemeral : 0) | - uint8_t(ownership) << OwnershipShift) {} + uint8_t(ownership) << OwnershipShift | + (noDerivative ? NoDerivative : 0)) {} /// Create one from what's present in the parameter type inline static ParameterTypeFlags fromParameterType(Type paramTy, bool isVariadic, bool isAutoClosure, - bool isNonEphemeral, ValueOwnership ownership); + bool isNonEphemeral, ValueOwnership ownership, + bool isNoDerivative); bool isNone() const { return !value; } bool isVariadic() const { return value.contains(Variadic); } @@ -1840,6 +1842,7 @@ class ParameterTypeFlags { bool isInOut() const { return getValueOwnership() == ValueOwnership::InOut; } bool isShared() const { return getValueOwnership() == ValueOwnership::Shared;} bool isOwned() const { return getValueOwnership() == ValueOwnership::Owned; } + bool isNoDerivative() const { return value.contains(NoDerivative); } ValueOwnership getValueOwnership() const { return ValueOwnership((value.toRaw() & Ownership) >> OwnershipShift); @@ -1882,6 +1885,12 @@ class ParameterTypeFlags { : value - ParameterTypeFlags::NonEphemeral); } + ParameterTypeFlags withNoDerivative(bool noDerivative) const { + return ParameterTypeFlags(noDerivative + ? value | ParameterTypeFlags::NoDerivative + : value - ParameterTypeFlags::NoDerivative); + } + bool operator ==(const ParameterTypeFlags &other) const { return value.toRaw() == other.value.toRaw(); } @@ -1948,8 +1957,8 @@ class YieldTypeFlags { ParameterTypeFlags asParamFlags() const { return ParameterTypeFlags(/*variadic*/ false, /*autoclosure*/ false, - /*nonEphemeral*/ false, - getValueOwnership()); + /*nonEphemeral*/ false, getValueOwnership(), + /*noDerivative*/ false); } bool operator ==(const YieldTypeFlags &other) const { @@ -2821,6 +2830,9 @@ class AnyFunctionType : public TypeBase { /// Whether the parameter is marked '@_nonEphemeral' bool isNonEphemeral() const { return Flags.isNonEphemeral(); } + /// Whether the parameter is marked '@noDerivative'. + bool isNoDerivative() const { return Flags.isNoDerivative(); } + ValueOwnership getValueOwnership() const { return Flags.getValueOwnership(); } @@ -5818,10 +5830,9 @@ inline TupleTypeElt TupleTypeElt::getWithType(Type T) const { } /// Create one from what's present in the parameter decl and type -inline ParameterTypeFlags -ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic, - bool isAutoClosure, bool isNonEphemeral, - ValueOwnership ownership) { +inline ParameterTypeFlags ParameterTypeFlags::fromParameterType( + Type paramTy, bool isVariadic, bool isAutoClosure, bool isNonEphemeral, + ValueOwnership ownership, bool isNoDerivative) { // FIXME(Remove InOut): The last caller that needs this is argument // decomposition. Start by enabling the assertion there and fixing up those // callers, then remove this, then remove @@ -5831,7 +5842,7 @@ ParameterTypeFlags::fromParameterType(Type paramTy, bool isVariadic, ownership == ValueOwnership::InOut); ownership = ValueOwnership::InOut; } - return {isVariadic, isAutoClosure, isNonEphemeral, ownership}; + return {isVariadic, isAutoClosure, isNonEphemeral, ownership, isNoDerivative}; } inline const Type *BoundGenericType::getTrailingObjectsPointer() const { diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 805f472be89d4..db8035eadacb0 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -2893,9 +2893,10 @@ void AnyFunctionType::decomposeInput( } default: - result.emplace_back(type->getInOutObjectType(), Identifier(), - ParameterTypeFlags::fromParameterType( - type, false, false, false, ValueOwnership::Default)); + result.emplace_back( + type->getInOutObjectType(), Identifier(), + ParameterTypeFlags::fromParameterType(type, false, false, false, + ValueOwnership::Default, false)); return; } } diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index e0212538f1261..e4e038947fc01 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -2502,6 +2502,8 @@ static void printParameterFlags(ASTPrinter &printer, PrintOptions options, ParameterTypeFlags flags, bool escaping) { if (!options.excludeAttrKind(TAK_autoclosure) && flags.isAutoClosure()) printer << "@autoclosure "; + if (!options.excludeAttrKind(TAK_noDerivative) && flags.isNoDerivative()) + printer << "@noDerivative "; switch (flags.getValueOwnership()) { case ValueOwnership::Default: diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 5cf8d55f85146..64abf260a60ba 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -6039,11 +6039,10 @@ AnyFunctionType::Param ParamDecl::toFunctionParam(Type type) const { type = ParamDecl::getVarargBaseTy(type); auto label = getArgumentName(); - auto flags = ParameterTypeFlags::fromParameterType(type, - isVariadic(), - isAutoClosure(), - isNonEphemeral(), - getValueOwnership()); + auto flags = ParameterTypeFlags::fromParameterType( + type, isVariadic(), isAutoClosure(), isNonEphemeral(), + getValueOwnership(), + /*isNoDerivative*/ false); return AnyFunctionType::Param(type, label, flags); } diff --git a/lib/AST/TypeRepr.cpp b/lib/AST/TypeRepr.cpp index 200f3aa5604e4..7340027e31c8d 100644 --- a/lib/AST/TypeRepr.cpp +++ b/lib/AST/TypeRepr.cpp @@ -298,6 +298,8 @@ void AttributedTypeRepr::printAttrs(ASTPrinter &Printer, Printer.printSimpleAttr("@autoclosure") << " "; if (hasAttr(TAK_escaping)) Printer.printSimpleAttr("@escaping") << " "; + if (hasAttr(TAK_noDerivative)) + Printer.printSimpleAttr("@noDerivative") << " "; if (hasAttr(TAK_differentiable)) { if (Attrs.isLinear()) { diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index c3616371deeec..3158e3be8ef1c 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -1793,6 +1793,7 @@ namespace { resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr, TypeResolutionOptions options, bool requiresMappingOut, + DifferentiabilityKind diffKind, SmallVectorImpl &ps); Type resolveSILFunctionType(FunctionTypeRepr *repr, @@ -2026,6 +2027,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs, // Remember whether this is a function parameter. bool isParam = options.is(TypeResolverContext::FunctionInput); + // Remember whether this is a variadic function parameter. + bool isVariadicFunctionParam = + options.is(TypeResolverContext::VariadicFunctionInput) && + !options.hasBase(TypeResolverContext::EnumElementDecl); + // The type we're working with, in case we want to build it differently // based on the attributes we see. Type ty; @@ -2370,6 +2376,21 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs, attrs.ConventionArguments = None; } + if (attrs.has(TAK_noDerivative)) { + if (!Context.LangOpts.EnableExperimentalDifferentiableProgramming) { + diagnose(attrs.getLoc(TAK_noDerivative), + diag::experimental_differentiable_programming_disabled); + } else if (!isParam) { + // @noDerivative is only valid on parameters. + diagnose(attrs.getLoc(TAK_noDerivative), + (isVariadicFunctionParam + ? diag::attr_not_on_variadic_parameters + : diag::attr_only_on_parameters_of_differentiable), + "@noDerivative"); + } + attrs.clearAttribute(TAK_noDerivative); + } + // In SIL, handle @opened (n), which creates an existential archetype. if (attrs.has(TAK_opened)) { if (!ty->isExistentialType()) { @@ -2422,7 +2443,7 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs, bool TypeResolver::resolveASTFunctionTypeParams( TupleTypeRepr *inputRepr, TypeResolutionOptions options, - bool requiresMappingOut, + bool requiresMappingOut, DifferentiabilityKind diffKind, SmallVectorImpl &elements) { elements.reserve(inputRepr->getNumElements()); @@ -2486,8 +2507,24 @@ bool TypeResolver::resolveASTFunctionTypeParams( ownership = ValueOwnership::Default; break; } + + bool noDerivative = false; + if (auto *attrTypeRepr = dyn_cast(eltTypeRepr)) { + if (attrTypeRepr->getAttrs().has(TAK_noDerivative)) { + if (diffKind == DifferentiabilityKind::NonDifferentiable && + Context.LangOpts.EnableExperimentalDifferentiableProgramming) + diagnose(eltTypeRepr->getLoc(), + diag::attr_only_on_parameters_of_differentiable, + "@noDerivative") + .highlight(eltTypeRepr->getSourceRange()); + else + noDerivative = true; + } + } + auto paramFlags = ParameterTypeFlags::fromParameterType( - ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership); + ty, variadic, autoclosure, /*isNonEphemeral*/ false, ownership, + noDerivative); elements.emplace_back(ty, Identifier(), paramFlags); } @@ -2541,7 +2578,8 @@ Type TypeResolver::resolveASTFunctionType( SmallVector params; if (resolveASTFunctionTypeParams(repr->getArgsTypeRepr(), options, - repr->getGenericEnvironment() != nullptr, params)) { + repr->getGenericEnvironment() != nullptr, + diffKind, params)) { return Type(); } diff --git a/lib/Serialization/Deserialization.cpp b/lib/Serialization/Deserialization.cpp index 31646a5219c2a..f260a588bec93 100644 --- a/lib/Serialization/Deserialization.cpp +++ b/lib/Serialization/Deserialization.cpp @@ -4770,12 +4770,11 @@ class swift::TypeDeserializer { IdentifierID labelID; TypeID typeID; - bool isVariadic, isAutoClosure, isNonEphemeral; + bool isVariadic, isAutoClosure, isNonEphemeral, isNoDerivative; unsigned rawOwnership; - decls_block::FunctionParamLayout::readRecord(scratch, labelID, typeID, - isVariadic, isAutoClosure, - isNonEphemeral, - rawOwnership); + decls_block::FunctionParamLayout::readRecord( + scratch, labelID, typeID, isVariadic, isAutoClosure, isNonEphemeral, + rawOwnership, isNoDerivative); auto ownership = getActualValueOwnership((serialization::ValueOwnership)rawOwnership); @@ -4786,10 +4785,10 @@ class swift::TypeDeserializer { if (!paramTy) return paramTy.takeError(); - params.emplace_back(paramTy.get(), - MF.getIdentifier(labelID), + params.emplace_back(paramTy.get(), MF.getIdentifier(labelID), ParameterTypeFlags(isVariadic, isAutoClosure, - isNonEphemeral, *ownership)); + isNonEphemeral, *ownership, + isNoDerivative)); } if (!isGeneric) { diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 284004efed076..b1c96011230b7 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 530; // @_implicitly_synthesizes_nested_requirement +const uint16_t SWIFTMODULE_VERSION_MINOR = 531; // function parameter noDerivative /// A standard hash seed used for all string hashes in a serialized module. /// @@ -905,12 +905,13 @@ namespace decls_block { using FunctionParamLayout = BCRecordLayout< FUNCTION_PARAM, - IdentifierIDField, // name - TypeIDField, // type - BCFixed<1>, // vararg? - BCFixed<1>, // autoclosure? - BCFixed<1>, // non-ephemeral? - ValueOwnershipField // inout, shared or owned? + IdentifierIDField, // name + TypeIDField, // type + BCFixed<1>, // vararg? + BCFixed<1>, // autoclosure? + BCFixed<1>, // non-ephemeral? + ValueOwnershipField, // inout, shared or owned? + BCFixed<1> // noDerivative? >; using MetatypeTypeLayout = BCRecordLayout< diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index ba2ba44d2bc60..3f9471b7eb565 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -4021,8 +4021,8 @@ class Serializer::TypeSerializer : public TypeVisitor { S.Out, S.ScratchRecord, abbrCode, S.addDeclBaseNameRef(param.getLabel()), S.addTypeRef(param.getPlainType()), paramFlags.isVariadic(), - paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(), - rawOwnership); + paramFlags.isAutoClosure(), paramFlags.isNonEphemeral(), rawOwnership, + paramFlags.isNoDerivative()); } } diff --git a/test/AutoDiff/ModuleInterface/differentiation.swift b/test/AutoDiff/ModuleInterface/differentiation.swift index 6b3a791fb9ca3..35ee0f9067891 100644 --- a/test/AutoDiff/ModuleInterface/differentiation.swift +++ b/test/AutoDiff/ModuleInterface/differentiation.swift @@ -6,3 +6,6 @@ public func a(f: @differentiable (Float) -> Float) {} public func b(f: @differentiable(linear) (Float) -> Float) {} // CHECK: public func b(f: @differentiable(linear) (Swift.Float) -> Swift.Float) + +public func c(f: @differentiable (Float, @noDerivative Float) -> Float) {} +// CHECK: public func c(f: @differentiable (Swift.Float, @noDerivative Swift.Float) -> Swift.Float) diff --git a/test/AutoDiff/Parse/differentiable_func_type.swift b/test/AutoDiff/Parse/differentiable_func_type.swift index 7ef91d3524fb9..e163a246a4fba 100644 --- a/test/AutoDiff/Parse/differentiable_func_type.swift +++ b/test/AutoDiff/Parse/differentiable_func_type.swift @@ -8,12 +8,25 @@ let b: @differentiable(linear) (Float) -> Float // okay // CHECK: (pattern_named 'b' // CHECK-NEXT: (type_attributed attrs=@differentiable(linear) -let c: @differentiable (Float) throws -> Float // okay +let c: @differentiable (Float, @noDerivative Float) -> Float // okay // CHECK: (pattern_named 'c' +// CHECK-NEXT: (type_attributed attrs=@differentiable +// CHECK-NEXT: (type_function +// CHECK-NEXT: (type_tuple +// CHECK-NEXT: (type_ident +// CHECK-NEXT: (component id='Float' bind=none)) +// CHECK-NEXT: (type_attributed attrs=@noDerivative +// CHECK-NEXT: (type_ident +// CHECK-NEXT: (component id='Float' bind=none))) +// CHECK-NEXT: (type_ident +// CHECK-NEXT: (component id='Float' bind=none))))) + +let d: @differentiable (Float) throws -> Float // okay +// CHECK: (pattern_named 'd' // CHECK-NEXT: (type_attributed attrs=@differentiable{{[^(]}} -let d: @differentiable(linear) (Float) throws -> Float // okay -// CHECK: (pattern_named 'd' +let e: @differentiable(linear) (Float) throws -> Float // okay +// CHECK: (pattern_named 'e' // CHECK-NEXT: (type_attributed attrs=@differentiable(linear) // Generic type test. diff --git a/test/AutoDiff/Sema/differentiable_features_disabled.swift b/test/AutoDiff/Sema/differentiable_features_disabled.swift index e8212a69258d8..1ba3ec05e8a80 100644 --- a/test/AutoDiff/Sema/differentiable_features_disabled.swift +++ b/test/AutoDiff/Sema/differentiable_features_disabled.swift @@ -3,6 +3,16 @@ // expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}} let _: @differentiable (Float) -> Float +// expected-error @+2 {{differentiable programming is an experimental feature that is currently disabled}} +// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}} +let _: @differentiable (Float, @noDerivative Float) -> Float + +// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}} +let _: (Float, @noDerivative Float) -> Float + +// expected-error @+1 {{differentiable programming is an experimental feature that is currently disabled}} +let _: @noDerivative Float + func id(_ x: Float) -> Float { return x } diff --git a/test/AutoDiff/Sema/differentiable_func_type.swift b/test/AutoDiff/Sema/differentiable_func_type.swift index 0ca2f8aa50296..0515dcc95d027 100644 --- a/test/AutoDiff/Sema/differentiable_func_type.swift +++ b/test/AutoDiff/Sema/differentiable_func_type.swift @@ -4,3 +4,24 @@ let _: @differentiable Float let _: @differentiable (Float) -> Float + +// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}} +let _: @noDerivative Float + +// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}} +let _: (Float) -> @noDerivative Float + +// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}} +let _: @differentiable (Float) -> @noDerivative Float + +// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}} +let _: (@noDerivative Float) -> Float + +// expected-error @+2 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}} +// expected-error @+1 {{'@noDerivative' must not be used on variadic parameters}} +let _: (Float, @noDerivative Float...) -> Float + +let _: @differentiable (@noDerivative Float, Float) -> Float + +// expected-error @+1 {{'@noDerivative' must not be used on variadic parameters}} +let _: @differentiable (Float, @noDerivative Float...) -> Float diff --git a/test/AutoDiff/Serialization/differentiation.swift b/test/AutoDiff/Serialization/differentiation.swift index d5f1276c3f494..3a17b5440b8a9 100644 --- a/test/AutoDiff/Serialization/differentiation.swift +++ b/test/AutoDiff/Serialization/differentiation.swift @@ -10,3 +10,6 @@ func a(_ f: @differentiable (Float) -> Float) {} func b(_ f: @differentiable(linear) (Float) -> Float) {} // CHECK: func b(_ f: @differentiable(linear) (Float) -> Float) + +func c(_ f: @differentiable (Float, @noDerivative Float) -> Float) {} +// CHECK: func c(_ f: @differentiable (Float, @noDerivative Float) -> Float)