From d69e892f039d9d2347cda50de7d173ca9d4c6bbb Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Mon, 6 Jan 2020 15:32:21 -0800 Subject: [PATCH 1/2] [AutoDiff upstream] forbid @derivative of protocol req --- lib/Sema/TypeCheckAttr.cpp | 3 ++ .../Sema/derivative_attr_type_checking.swift | 34 +++++++++---------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 3dae0bf29bf89..86b8fa6f90761 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3503,6 +3503,9 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D, }; auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { + // TODO(TF-982): Allow derivatives on protocol requirements. + if (isa(originalCandidate->getDeclContext())) + return false; return checkFunctionSignature( cast(originalFnType->getCanonicalType()), originalCandidate->getInterfaceType()->getCanonicalType(), diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 1e1b70c7cd341..27db38dc34c4e 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -185,6 +185,11 @@ protocol StaticMethod: Differentiable { static func generic(_ x: T) -> T } +extension StaticMethod { + static func foo(_ x: Float) -> Float { x } + static func generic(_ x: T) -> T { x } +} + extension StaticMethod { @derivative(of: foo) static func jvpFoo(x: Float) -> (value: Float, differential: (Float) -> Float) @@ -215,11 +220,16 @@ extension StaticMethod { // Test instance methods. protocol InstanceMethod: Differentiable { - // expected-note @+1 {{'foo' defined here}} func foo(_ x: Self) -> Self + func generic(_ x: T) -> Self +} + +extension InstanceMethod { + // expected-note @+1 {{'foo' defined here}} + func foo(_ x: Self) -> Self { x } // expected-note @+1 {{'generic' defined here}} - func generic(_ x: T) -> Self + func generic(_ x: T) -> Self { self } } extension InstanceMethod { @@ -539,24 +549,14 @@ extension HasStoredProperty { // Test cross-file derivative registration. Currently unsupported. // TODO(TF-1021): Lift this restriction. -extension AdditiveArithmetic where Self: Differentiable { +extension FloatingPoint where Self: Differentiable { // expected-error @+1 {{derivative not in the same file as the original function}} - @derivative(of: +) - static func vjpPlus(x: Self, y: Self) -> ( + @derivative(of: rounded) + func vjpRounded() -> ( value: Self, - pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector) - ) { - return (x + y, { v in (v, v) }) - } -} - -extension FloatingPoint where Self: Differentiable, Self == Self.TangentVector { - // expected-error @+1 {{derivative not in the same file as the original function}} - @derivative(of: +) - static func vjpPlus(x: Self, y: Self) -> ( - value: Self, pullback: (Self) -> (Self, Self) + pullback: (Self.TangentVector) -> (Self.TangentVector) ) { - return (x + y, { v in (v, v) }) + fatalError() } } From 29465f82add36c3a964e2f8bd1557c7d35934f81 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Mon, 6 Jan 2020 18:29:26 -0800 Subject: [PATCH 2/2] Add standalone test. --- .../Sema/derivative_attr_type_checking.swift | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index 27db38dc34c4e..babb8fad87aa5 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -546,6 +546,23 @@ extension HasStoredProperty { } } +// Test derivative registration for protocol requirements. Currently unsupported. +// TODO(TF-982): Lift this restriction and add proper support. + +protocol ProtocolRequirementDerivative { + func requirement(_ x: Float) -> Float +} +extension ProtocolRequirementDerivative { + // NOTE: the error is misleading because `findAbstractFunctionDecl` in + // TypeCheckAttr.cpp is not setup to show customized error messages for + // invalid original function candidates. + // expected-error @+1 {{could not find function 'requirement' with expected type ' (Self) -> (Float) -> Float'}} + @derivative(of: requirement) + func vjpRequirement(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { + fatalError() + } +} + // Test cross-file derivative registration. Currently unsupported. // TODO(TF-1021): Lift this restriction.