Skip to content

Commit

Permalink
internal/lsp: add support for extracting non-nested returns
Browse files Browse the repository at this point in the history
If there is a return statement that is guaranteed to execute in
the selection to extract to function, then the result of calling
the extracted function can be directly returned.

Updates golang/go#37170

Change-Id: I6454e4107d670e4a1bc9048b2e1073fc80fc78ab
Reviewed-on: https://go-review.googlesource.com/c/tools/+/312469
Trust: Suzy Mueller <[email protected]>
Run-TryBot: Suzy Mueller <[email protected]>
gopls-CI: kokoro <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Rebecca Stambler <[email protected]>
  • Loading branch information
suzmue committed Apr 27, 2021
1 parent d0768c9 commit 9ff8648
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 37 deletions.
2 changes: 1 addition & 1 deletion internal/lsp/analysis/unusedparams/unusedparams.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
fieldList, body = f.Type.Params, f.Body
}
// If there are no arguments or the function is empty, then return.
if fieldList.NumFields() == 0 || len(body.List) == 0 {
if fieldList.NumFields() == 0 || body == nil || len(body.List) == 0 {
return
}

Expand Down
111 changes: 76 additions & 35 deletions internal/lsp/source/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,9 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
return nil, fmt.Errorf("extractFunction: package scope is empty")
}

// TODO: Support non-nested return statements.
// A return statement is non-nested if its parent node is equal to the parent node
// of the first node in the selection. These cases must be handled separately because
// non-nested return statements are guaranteed to execute. Our control flow does not
// properly consider these situations yet.
// non-nested return statements are guaranteed to execute.
var retStmts []*ast.ReturnStmt
var hasNonNestedReturn bool
startParent := findParent(outer, start)
Expand All @@ -216,14 +214,10 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
}
if findParent(outer, n) == startParent {
hasNonNestedReturn = true
return false
}
retStmts = append(retStmts, ret)
return false
})
if hasNonNestedReturn {
return nil, fmt.Errorf("extractFunction: selected block contains non-nested return")
}
containsReturnStatement := len(retStmts) > 0

// Now that we have determined the correct range for the selection block,
Expand Down Expand Up @@ -396,23 +390,54 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
// in the original function. If the condition is met, the original function should
// return a value, mimicking the functionality of the original return statement(s)
// in the selection.
//
// If there is a return that is guaranteed to execute (hasNonNestedReturns=true), then
// we don't need to include this additional condition check and can simply return.
//
// Before:
//
// func _() int {
// a := 1
// b := 2
// **if a == b {
// return a
// }
// return b**
// }
//
// After:
//
// func _() int {
// a := 1
// b := 2
// return x0(a, b)
// }
//
// func x0(a int, b int) int {
// if a == b {
// return a
// }
// return b
// }

var retVars []*returnVariable
var ifReturn *ast.IfStmt
if containsReturnStatement {
// The selected block contained return statements, so we have to modify the
// signature of the extracted function as described above. Adjust all of
// the return statements in the extracted function to reflect this change in
// signature.
if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
pkg, extractedBlock); err != nil {
return nil, err
if !hasNonNestedReturn {
// The selected block contained return statements, so we have to modify the
// signature of the extracted function as described above. Adjust all of
// the return statements in the extracted function to reflect this change in
// signature.
if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
pkg, extractedBlock); err != nil {
return nil, err
}
}
// Collect the additional return values and types needed to accommodate return
// statements in the selection. Update the type signature of the extracted
// function and construct the if statement that will be inserted in the enclosing
// function.
retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start, hasNonNestedReturn)
if err != nil {
return nil, err
}
Expand All @@ -421,8 +446,10 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
// Add a return statement to the end of the new function. This return statement must include
// the values for the types of the original extracted function signature and (if a return
// statement is present in the selection) enclosing function signature.
// This only needs to be done if the selections does not have a non-nested return, otherwise
// it already terminates with a return statement.
hasReturnValues := len(returns)+len(retVars) > 0
if hasReturnValues {
if hasReturnValues && !hasNonNestedReturn {
extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
Results: append(returns, getZeroVals(retVars)...),
})
Expand All @@ -439,7 +466,7 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.
sym = token.DEFINE
}
funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
extractedFunCall := generateFuncCall(hasReturnValues, params,
extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
append(returns, getNames(retVars)...), funName, sym)

// Build the extracted function.
Expand Down Expand Up @@ -951,15 +978,17 @@ func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
// signature of the extracted function. We prepare names, signatures, and "zero values" that
// represent the new variables. We also use this information to construct the if statement that
// is inserted below the call to the extracted function.
func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) {
// Generate information for the added bool value.
cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
retVars := []*returnVariable{
{
func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) {
var retVars []*returnVariable
var cond *ast.Ident
if !hasNonNestedReturns {
// Generate information for the added bool value.
cond = &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
retVars = append(retVars, &returnVariable{
name: cond,
decl: &ast.Field{Type: ast.NewIdent("bool")},
zeroVal: ast.NewIdent("false"),
},
})
}
// Generate information for the values in the return signature of the enclosing function.
if enclosing.Results != nil {
Expand All @@ -982,13 +1011,16 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.
})
}
}
// Create the return statement for the enclosing function. We must exclude the variable
// for the condition of the if statement (cond) from the return statement.
ifReturn := &ast.IfStmt{
Cond: cond,
Body: &ast.BlockStmt{
List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
},
var ifReturn *ast.IfStmt
if !hasNonNestedReturns {
// Create the return statement for the enclosing function. We must exclude the variable
// for the condition of the if statement (cond) from the return statement.
ifReturn = &ast.IfStmt{
Cond: cond,
Body: &ast.BlockStmt{
List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
},
}
}
return retVars, ifReturn, nil
}
Expand Down Expand Up @@ -1034,17 +1066,26 @@ func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]

// generateFuncCall constructs a call expression for the extracted function, described by the
// given parameters and return variables.
func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
var replace ast.Node
if hasReturnVals {
callExpr := &ast.CallExpr{
Fun: ast.NewIdent(name),
Args: params,
}
replace = &ast.AssignStmt{
Lhs: returns,
Tok: token,
Rhs: []ast.Expr{callExpr},
if hasNonNestedReturn {
// Create a return statement that returns the result of the function call.
replace = &ast.ReturnStmt{
Return: 0,
Results: []ast.Expr{callExpr},
}
} else {
// Assign the result of the function call.
replace = &ast.AssignStmt{
Lhs: returns,
Tok: token,
Rhs: []ast.Expr{callExpr},
}
}
} else {
replace = &ast.CallExpr{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package extract

func _() bool {
x := 1 //@mark(exSt13, "x")
if x == 0 {
return true
}
return false //@mark(exEn13, "false")
//@extractfunc(exSt13, exEn13)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
-- functionextraction_extract_return_basic_nonnested_4_2 --
package extract

func _() bool {
return fn0() //@mark(exEn13, "false")
//@extractfunc(exSt13, exEn13)
}

func fn0() bool {
x := 1
if x == 0 {
return true
}
return false
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package extract

import "fmt"

func _() (int, string, error) {
x := 1
y := "hello"
z := "bye" //@mark(exSt10, "z")
if y == z {
return x, y, fmt.Errorf("same")
} else {
z = "hi"
return x, z, nil
}
return x, z, nil //@mark(exEn10, "nil")
//@extractfunc(exSt10, exEn10)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
-- functionextraction_extract_return_complex_nonnested_8_2 --
package extract

import "fmt"

func _() (int, string, error) {
x := 1
y := "hello"
return fn0(y, x) //@mark(exEn10, "nil")
//@extractfunc(exSt10, exEn10)
}

func fn0(y string, x int) (int, string, error) {
z := "bye"
if y == z {
return x, y, fmt.Errorf("same")
} else {
z = "hi"
return x, z, nil
}
return x, z, nil
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package extract

import "go/ast"

func _() {
ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool {
if n == nil { //@mark(exSt11, "if")
return true
}
return false //@mark(exEn11, "false")
})
//@extractfunc(exSt11, exEn11)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- functionextraction_extract_return_func_lit_nonnested_7_3 --
package extract

import "go/ast"

func _() {
ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool {
return fn0(n) //@mark(exEn11, "false")
})
//@extractfunc(exSt11, exEn11)
}

func fn0(n ast.Node) bool {
if n == nil {
return true
}
return false
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package extract

func _() string {
x := 1
if x == 0 { //@mark(exSt12, "if")
x = 3
return "a"
}
x = 2
return "b" //@mark(exEn12, "\"b\"")
//@extractfunc(exSt12, exEn12)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- functionextraction_extract_return_init_nonnested_5_2 --
package extract

func _() string {
x := 1
return fn0(x) //@mark(exEn12, "\"b\"")
//@extractfunc(exSt12, exEn12)
}

func fn0(x int) string {
if x == 0 {
x = 3
return "a"
}
x = 2
return "b"
}

2 changes: 1 addition & 1 deletion internal/lsp/testdata/summary.txt.golden
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ FormatCount = 6
ImportCount = 8
SemanticTokenCount = 3
SuggestedFixCount = 40
FunctionExtractionCount = 13
FunctionExtractionCount = 17
DefinitionsCount = 95
TypeDefinitionsCount = 10
HighlightsCount = 69
Expand Down

0 comments on commit 9ff8648

Please sign in to comment.