Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
//
// This unblocks the `thin_to_thick_function` peephole optimization below.
if (auto *CFI = dyn_cast<ConvertFunctionInst>(Cvt->getOperand())) {
if (CFI->getSingleUse()) {
if (hasOneNonDebugUse(CFI)) {
if (auto *TTTFI = dyn_cast<ThinToThickFunctionInst>(CFI->getOperand())) {
if (TTTFI->getSingleUse()) {
auto convertedThickType = CFI->getType().castTo<SILFunctionType>();
Expand Down Expand Up @@ -836,7 +836,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst(
// %vjp' = convert_escape_to_noescape %vjp
// %y = differentiable_function(%orig', %jvp', %vjp')
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(Cvt->getOperand())) {
if (DFI->hasOneUse()) {
if (hasOneNonDebugUse(DFI)) {
auto createConvertEscapeToNoEscape =
[&](NormalDifferentiableFunctionTypeComponent extractee) {
if (!DFI->hasExtractee(extractee))
Expand Down Expand Up @@ -1020,9 +1020,7 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) {
// %vjp' = convert_function %vjp
// %y = differentiable_function(%orig', %jvp', %vjp')
if (auto *DFI = dyn_cast<DifferentiableFunctionInst>(cfi->getOperand())) {
// Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78848
// TODO: remove this if-statement once the underlying problem is fixed.
if (cfi->getFunction()->hasOwnership())
if (!hasOneNonDebugUse(DFI))
return nullptr;

auto createConvertFunctionOfComponent =
Expand Down
136 changes: 136 additions & 0 deletions test/AutoDiff/sil_combine.sil
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float):
// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'

// MARK: `convert_function` hoisting

// This should optimize down single partial_apply that escapes
sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float

%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>

%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
}

debug_value %diff_fn, let, name "f", argno 1

%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff
return %conv_orig
}

// CHECK-LABEL: sil @differential_function_convert_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
// CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float

sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()

// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there

sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float

%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
%conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>

%diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
}

debug_value %diff_fn, let, name "f", argno 1

%conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff

%blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()

return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
}

// CHECK-LABEL: sil @differential_function_convert_multiple_use
// CHECK: convert_function
// CHECK: differentiable_function
// CHECK: convert_function
// CHECK: differentiable_function_extract

// MARK: `convert_escape_to_noescape` hoisting

sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()

// Here we should be able to unfold partial_apply down to direct function call

sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float

%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float

%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
}

debug_value %diff_fn, let, name "f", argno 1

%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff

%arg = alloc_stack $Float
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float

dealloc_stack %arg : $*Float
strong_release %pa

%res = tuple ()
return %res : $()
}

// CHECK-LABEL: sil @differential_function_noescape_single_use
// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
// CHECK: %[[ARG:.*]] = alloc_stack $Float
// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float


// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there

sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
%thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float

%pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float

%diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
}

debug_value %diff_fn, let, name "f", argno 1

%conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
%conv_orig = differentiable_function_extract [original] %conv_diff

%arg = alloc_stack $Float
apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float

%blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()

dealloc_stack %arg : $*Float
strong_release %pa

%res = tuple ()
return %res : $()
}

// CHECK-LABEL: sil @differential_function_noescape_multiple_use
// CHECK: differentiable_function
// CHECK: convert_escape_to_noescape
// CHECK: differentiable_function_extract