diff --git a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp index fe92106003fa8..1ad347ac6304c 100644 --- a/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp +++ b/lib/SILOptimizer/SILCombiner/SILCombinerCastVisitors.cpp @@ -790,7 +790,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst( // // This unblocks the `thin_to_thick_function` peephole optimization below. if (auto *CFI = dyn_cast(Cvt->getOperand())) { - if (CFI->getSingleUse()) { + if (hasOneNonDebugUse(CFI)) { if (auto *TTTFI = dyn_cast(CFI->getOperand())) { if (TTTFI->getSingleUse()) { auto convertedThickType = CFI->getType().castTo(); @@ -836,7 +836,7 @@ SILInstruction *SILCombiner::visitConvertEscapeToNoEscapeInst( // %vjp' = convert_escape_to_noescape %vjp // %y = differentiable_function(%orig', %jvp', %vjp') if (auto *DFI = dyn_cast(Cvt->getOperand())) { - if (DFI->hasOneUse()) { + if (hasOneNonDebugUse(DFI)) { auto createConvertEscapeToNoEscape = [&](NormalDifferentiableFunctionTypeComponent extractee) { if (!DFI->hasExtractee(extractee)) @@ -1020,9 +1020,7 @@ SILCombiner::visitConvertFunctionInst(ConvertFunctionInst *cfi) { // %vjp' = convert_function %vjp // %y = differentiable_function(%orig', %jvp', %vjp') if (auto *DFI = dyn_cast(cfi->getOperand())) { - // Workaround for a problem with OSSA: https://github.com/swiftlang/swift/issues/78848 - // TODO: remove this if-statement once the underlying problem is fixed. - if (cfi->getFunction()->hasOwnership()) + if (!hasOneNonDebugUse(DFI)) return nullptr; auto createConvertFunctionOfComponent = diff --git a/test/AutoDiff/sil_combine.sil b/test/AutoDiff/sil_combine.sil index 741d226e63aaa..e69a29a95fe88 100644 --- a/test/AutoDiff/sil_combine.sil +++ b/test/AutoDiff/sil_combine.sil @@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float): // CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float // CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) // CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined' + +// MARK: `convert_function` hoisting + +// This should optimize down single partial_apply that escapes +sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float { +bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float): + %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float + + %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float + %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for + + %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative { + undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for , + undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for + } + + debug_value %diff_fn, let, name "f", argno 1 + + %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float + %conv_orig = differentiable_function_extract [original] %conv_diff + return %conv_orig +} + +// CHECK-LABEL: sil @differential_function_convert_single_use +// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) +// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float +// CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float +// CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float + +sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted (@in_guaranteed T) -> Float for ) -> () + +// differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there + +sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float { +bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float): + %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float + + %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float + %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for + + %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative { + undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for , + undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for + } + + debug_value %diff_fn, let, name "f", argno 1 + + %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float + %conv_orig = differentiable_function_extract [original] %conv_diff + + %blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted (@in_guaranteed T) -> Float for ) -> () + apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted (@in_guaranteed T) -> Float for ) -> () + + return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float +} + +// CHECK-LABEL: sil @differential_function_convert_multiple_use +// CHECK: convert_function +// CHECK: differentiable_function +// CHECK: convert_function +// CHECK: differentiable_function_extract + +// MARK: `convert_escape_to_noescape` hoisting + +sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> () + +// Here we should be able to unfold partial_apply down to direct function call + +sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () { +bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float): + %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float + + %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float + + %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative { + undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float), + undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float) + } + + debug_value %diff_fn, let, name "f", argno 1 + + %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float + %conv_orig = differentiable_function_extract [original] %conv_diff + + %arg = alloc_stack $Float + apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float + + dealloc_stack %arg : $*Float + strong_release %pa + + %res = tuple () + return %res : $() +} + +// CHECK-LABEL: sil @differential_function_noescape_single_use +// CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) +// CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float +// CHECK: %[[ARG:.*]] = alloc_stack $Float +// CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float + + +// differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there + +sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () { +bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float): + %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float + + %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float + + %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative { + undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float), + undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float) + } + + debug_value %diff_fn, let, name "f", argno 1 + + %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float + %conv_orig = differentiable_function_extract [original] %conv_diff + + %arg = alloc_stack $Float + apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float + + %blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> () + apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> () + + dealloc_stack %arg : $*Float + strong_release %pa + + %res = tuple () + return %res : $() +} + +// CHECK-LABEL: sil @differential_function_noescape_multiple_use +// CHECK: differentiable_function +// CHECK: convert_escape_to_noescape +// CHECK: differentiable_function_extract