Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3151,6 +3151,8 @@ ERROR(autodiff_attr_original_multiple_semantic_results,none,
ERROR(autodiff_attr_result_not_differentiable,none,
"can only differentiate functions with results that conform to "
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
ERROR(autodiff_attr_opaque_result_type_unsupported,none,
"cannot differentiate functions returning opaque result types", ())

// differentiation `wrt` parameters clause
ERROR(diff_function_no_parameters,none,
Expand Down
27 changes: 26 additions & 1 deletion lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4240,6 +4240,15 @@ IndexSubset *DifferentiableAttributeTypeCheckRequest::evaluate(

auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();

// Diagnose if original function has opaque result types.
if (auto *opaqueResultTypeDecl = original->getOpaqueResultTypeDecl()) {
diags.diagnose(
attr->getLocation(),
diag::autodiff_attr_opaque_result_type_unsupported);
attr->setInvalid();
return nullptr;
}

// Diagnose if original function is an invalid class member.
bool isOriginalClassMember = original->getDeclContext() &&
original->getDeclContext()->getSelfClassDecl();
Expand Down Expand Up @@ -4532,6 +4541,16 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
return true;
}
}

// Diagnose if original function has opaque result types.
if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) {
diags.diagnose(
attr->getLocation(),
diag::autodiff_attr_opaque_result_type_unsupported);
attr->setInvalid();
return true;
}

// Diagnose if original function is an invalid class member.
bool isOriginalClassMember =
originalAFD->getDeclContext() &&
Expand Down Expand Up @@ -5083,9 +5102,15 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
attr->setInvalid();
return;
}

attr->setOriginalFunction(originalAFD);

// Diagnose if original function has opaque result types.
if (auto *opaqueResultTypeDecl = originalAFD->getOpaqueResultTypeDecl()) {
diagnose(attr->getLocation(), diag::autodiff_attr_opaque_result_type_unsupported);
attr->setInvalid();
return;
}

// Get the linearity parameter types.
SmallVector<AnyFunctionType::Param, 4> linearParams;
expectedOriginalFnType->getSubsetParameters(linearParamIndices, linearParams,
Expand Down
12 changes: 11 additions & 1 deletion test/AutoDiff/Sema/derivative_attr_type_checking.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend-typecheck -verify %s
// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s

import _Differentiation

Expand Down Expand Up @@ -1124,3 +1124,13 @@ extension Float {
fatalError()
}
}

// Test original function with opaque result type.

func opaqueResult(_ x: Float) -> some Differentiable { x }

// expected-error @+1 {{could not find function 'opaqueResult' with expected type '(Float) -> Float'}}
@derivative(of: opaqueResult)
func vjpOpaqueResult(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
fatalError()
}
6 changes: 5 additions & 1 deletion test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %target-swift-frontend-typecheck -verify %s
// RUN: %target-swift-frontend-typecheck -verify -disable-availability-checking %s

import _Differentiation

Expand Down Expand Up @@ -697,3 +697,7 @@ struct Accessors: Differentiable {
_modify { yield &stored }
}
}

// expected-error @+1 {{cannot differentiate functions returning opaque result types}}
@differentiable
func opaqueResult(_ x: Float) -> some Differentiable { x }