diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 3e4f3113f9e..7292cb0ef1f 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -236,6 +236,7 @@ var codeActionProducers = [...]codeActionProducer{ {kind: settings.RefactorExtractFunction, fn: refactorExtractFunction}, {kind: settings.RefactorExtractMethod, fn: refactorExtractMethod}, {kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile}, + {kind: settings.RefactorExtractAllOccursOfExpr, fn: refactorExtractAllOccursOfExpr}, {kind: settings.RefactorExtractVariable, fn: refactorExtractVariable}, {kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true}, {kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote}, @@ -458,6 +459,20 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error return nil } +// refactorExtractAllOccursOfExpr produces "Extract all occcurrances of expression" code action. +// See [extractAllOccursOfExpr] for command implementation. +func refactorExtractAllOccursOfExpr(ctx context.Context, req *codeActionsRequest) error { + // Don't suggest if only one expr is found, + // otherwise will duplicate with [refactorExtractVariable] + if exprs, ok, _ := canExtractExprs(req.start, req.end, req.pgf.File); ok && len(exprs) > 1 { + startOffset := req.pgf.Tok.Offset(exprs[0].Pos()) + endOffset := req.pgf.Tok.Offset(exprs[0].End()) + expr := req.pgf.Src[startOffset:endOffset] + req.addApplyFixAction(fmt.Sprintf("Extract %d occcurrances of %s", len(exprs), expr), fixExtractAllOccursOfExpr, req.loc) + } + return nil +} + // refactorExtractToNewFile produces "Extract declarations to new file" code actions. // See [server.commandHandler.ExtractToNewFile] for command implementation. func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error { diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 6ea011e220e..75938fc8b86 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -36,14 +36,14 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file // TODO: stricter rules for selectorExpr. case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0) lhsNames = append(lhsNames, lhsName) case *ast.CallExpr: tup, ok := info.TypeOf(expr).(*types.Tuple) if !ok { // If the call expression only has one return value, we can treat it the // same as our standard extract variable case. - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0) lhsNames = append(lhsNames, lhsName) break } @@ -51,7 +51,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file for i := 0; i < tup.Len(); i++ { // Generate a unique variable for each return value. var lhsName string - lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx) + lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", idx) lhsNames = append(lhsNames, lhsName) } default: diff --git a/gopls/internal/golang/extractexprs.go b/gopls/internal/golang/extractexprs.go new file mode 100644 index 00000000000..129b2fa2ea0 --- /dev/null +++ b/gopls/internal/golang/extractexprs.go @@ -0,0 +1,371 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package golang + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/token" + "go/types" + "sort" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/ast/astutil" + "golang.org/x/tools/gopls/internal/util/safetoken" + "golang.org/x/tools/internal/analysisinternal" +) + +// extractAllOccursOfExpr replaces all occurrences of a specified expression within the same function +// with newVar. Its position is determined by the deepest common scope accessible to all occurrences. +func extractAllOccursOfExpr(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { + tokFile := fset.File(file.Pos()) + exprs, _, err := canExtractExprs(start, end, file) + if err != nil { + return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err) + } + + scopes := make([][]*types.Scope, len(exprs)) + for i, e := range exprs { + path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End()) + scopes[i] = CollectScopes(info, path, e.Pos()) + } + + // Where should the newVar live. + commonScope, err := findDeepestCommonScope(scopes) + if err != nil { + return nil, nil, fmt.Errorf("extractVariable: %v", err) + } + + var innerScopes []*types.Scope + for _, scope := range scopes { + for _, s := range scope { + if s != nil { + innerScopes = append(innerScopes, s) + break + } + } + } + if len(innerScopes) != len(exprs) { + return nil, nil, fmt.Errorf("extractVariable: nil scope") + } + // So the largest scope's name won't conflict too. + innerScopes = append(innerScopes, commonScope) + + // Create new AST node for extracted code. + var lhsNames []string + switch expr := exprs[0].(type) { + // TODO: stricter rules for selectorExpr. + case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, + *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: + lhsName, _ := generateAvailableIdentifierByScopes(innerScopes, "newVar", 0) + lhsNames = append(lhsNames, lhsName) + case *ast.CallExpr: + tup, ok := info.TypeOf(expr).(*types.Tuple) + if !ok { + // If the call expression only has one return value, we can treat it the + // same as our standard extract variable case. + lhsName, _ := generateAvailableIdentifierByScopes(innerScopes, "newVar", 0) + lhsNames = append(lhsNames, lhsName) + break + } + idx := 0 + for i := 0; i < tup.Len(); i++ { + // Generate a unique variable for each return value. + var lhsName string + lhsName, idx = generateAvailableIdentifierByScopes(innerScopes, "newVar", idx) + lhsNames = append(lhsNames, lhsName) + } + default: + return nil, nil, fmt.Errorf("cannot extract %T", expr) + } + + var validPath []ast.Node + if commonScope != innerScopes[0] { + // This means the first expr within function body is not the largest scope, + // we need to find the scope immediately follow the common + // scope where we will insert the statement before. + child := innerScopes[0] + for p := child; p != nil; p = p.Parent() { + if p == commonScope { + break + } + child = p + } + validPath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End()) + } else { + // Just insert before the first expr. + validPath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End()) + } + fn := func() ast.Expr { + return &ast.UnaryExpr{} + } +label: + switch r1 := fn().(type) { + case *ast.CallExpr: + fn() + // __AUTO_GENERATED_PRINT_VAR_START__ + fmt.Println(fmt.Sprintf("replaceAllOccursOfExpr r1: %v", r1)) // __AUTO_GENERATED_PRINT_VAR_END__ + break label + } + // + // TODO: There is a bug here: for a variable declared in a labeled + // switch/for statement it returns the for/switch statement itself + // which produces the below code which is a compiler error e.g. + // label: + // switch r1 := r() { ... break label ... } + // On extracting "r()" to a variable + // label: + // x := r() + // switch r1 := x { ... break label ... } // compiler error + // + insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(validPath) + if insertBeforeStmt == nil { + return nil, nil, fmt.Errorf("cannot find location to insert extraction") + } + indent, err := calculateIndentation(src, tokFile, insertBeforeStmt) + if err != nil { + return nil, nil, err + } + newLineIndent := "\n" + indent + + lhs := strings.Join(lhsNames, ", ") + assignStmt := &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(lhs)}, + Tok: token.DEFINE, + Rhs: []ast.Expr{exprs[0]}, + } + var buf bytes.Buffer + if err := format.Node(&buf, fset, assignStmt); err != nil { + return nil, nil, err + } + assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent + var textEdits []analysis.TextEdit + textEdits = append(textEdits, analysis.TextEdit{ + Pos: insertBeforeStmt.Pos(), + End: insertBeforeStmt.Pos(), + NewText: []byte(assignment), + }) + for _, e := range exprs { + textEdits = append(textEdits, analysis.TextEdit{ + Pos: e.Pos(), + End: e.End(), + NewText: []byte(lhs), + }) + } + return fset, &analysis.SuggestedFix{ + TextEdits: textEdits, + }, nil +} + +// findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains. +// Each scope chain represents the scopes of an expression from innermost to outermost. +// If no common scope is found, it returns an error. +func findDeepestCommonScope(scopeChains [][]*types.Scope) (*types.Scope, error) { + if len(scopeChains) == 0 { + return nil, fmt.Errorf("no scopes provided") + } + // Get the first scope chain as the reference. + referenceChain := scopeChains[0] + + // Iterate from innermost to outermost scope. + for i := 0; i < len(referenceChain); i++ { + candidateScope := referenceChain[i] + if candidateScope == nil { + continue + } + isCommon := true + // See if other exprs' chains all have candidateScope as a common ancestor. + for _, chain := range scopeChains[1:] { + found := false + for j := 0; j < len(chain); j++ { + if chain[j] == candidateScope { + found = true + break + } + } + if !found { + isCommon = false + break + } + } + if isCommon { + return candidateScope, nil + } + } + return nil, fmt.Errorf("no common scope found") +} + +// canExtractExprs reports whether the code in the given range can be +// extracted to a variable. It finds all occurrences of an expression +// within the same function. +func canExtractExprs(start, end token.Pos, file *ast.File) ([]ast.Expr, bool, error) { + if start == end { + return nil, false, fmt.Errorf("start and end are equal") + } + path, _ := astutil.PathEnclosingInterval(file, start, end) + if len(path) == 0 { + return nil, false, fmt.Errorf("no path enclosing interval") + } + for _, n := range path { + if _, ok := n.(*ast.ImportSpec); ok { + return nil, false, fmt.Errorf("cannot extract variable in an import block") + } + } + node := path[0] + if start != node.Pos() || end != node.End() { + return nil, false, fmt.Errorf("range does not map to an AST node") + } + expr, ok := node.(ast.Expr) + if !ok { + return nil, false, fmt.Errorf("node is not an expression") + } + + var exprs []ast.Expr + exprs = append(exprs, expr) + if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok { + ast.Inspect(funcDecl, func(n ast.Node) bool { + if e, ok := n.(ast.Expr); ok && e != expr { + if exprIdentical(e, expr) { + exprs = append(exprs, e) + } + } + return true + }) + } + sort.Slice(exprs, func(i, j int) bool { + return exprs[i].Pos() < exprs[j].Pos() + }) + + switch expr.(type) { + case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr, + *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: + return exprs, true, nil + } + return nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) +} + +// generateAvailableIdentifierByScopes adjusts the new identifier name +// until there are no collisions in any of the provided scopes. +func generateAvailableIdentifierByScopes(scopes []*types.Scope, prefix string, idx int) (string, int) { + name := prefix + for { + collision := false + for _, scope := range scopes { + if scope.Lookup(name) != nil { + collision = true + break + } + } + if !collision { + return name, idx + } + idx++ + name = fmt.Sprintf("%s%d", prefix, idx) + } +} + +// exprIdentical recursively compares two ast.Expr nodes for structural equality, +// ignoring position fields. +func exprIdentical(x, y ast.Expr) bool { + if x == nil || y == nil { + return x == y + } + switch x := x.(type) { + case *ast.BasicLit: + y, ok := y.(*ast.BasicLit) + return ok && x.Kind == y.Kind && x.Value == y.Value + case *ast.CompositeLit: + y, ok := y.(*ast.CompositeLit) + if !ok || len(x.Elts) != len(y.Elts) || !exprIdentical(x.Type, y.Type) { + return false + } + for i := range x.Elts { + if !exprIdentical(x.Elts[i], y.Elts[i]) { + return false + } + } + return true + case *ast.ArrayType: + y, ok := y.(*ast.ArrayType) + return ok && exprIdentical(x.Len, y.Len) && exprIdentical(x.Elt, y.Elt) + case *ast.Ellipsis: + y, ok := y.(*ast.Ellipsis) + return ok && exprIdentical(x.Elt, y.Elt) + case *ast.FuncLit: + y, ok := y.(*ast.FuncLit) + return ok && exprIdentical(x.Type, y.Type) + case *ast.IndexExpr: + y, ok := y.(*ast.IndexExpr) + return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Index, y.Index) + case *ast.IndexListExpr: + y, ok := y.(*ast.IndexListExpr) + if !ok || len(x.Indices) != len(y.Indices) || !exprIdentical(x.X, y.X) { + return false + } + for i := range x.Indices { + if !exprIdentical(x.Indices[i], y.Indices[i]) { + return false + } + } + return true + case *ast.SliceExpr: + y, ok := y.(*ast.SliceExpr) + return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Low, y.Low) && exprIdentical(x.High, y.High) && exprIdentical(x.Max, y.Max) && x.Slice3 == y.Slice3 + case *ast.TypeAssertExpr: + y, ok := y.(*ast.TypeAssertExpr) + return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Type, y.Type) + case *ast.StarExpr: + y, ok := y.(*ast.StarExpr) + return ok && exprIdentical(x.X, y.X) + case *ast.KeyValueExpr: + y, ok := y.(*ast.KeyValueExpr) + return ok && exprIdentical(x.Key, y.Key) && exprIdentical(x.Value, y.Value) + case *ast.UnaryExpr: + y, ok := y.(*ast.UnaryExpr) + return ok && x.Op == y.Op && exprIdentical(x.X, y.X) + case *ast.MapType: + y, ok := y.(*ast.MapType) + return ok && exprIdentical(x.Value, y.Value) && exprIdentical(x.Key, y.Key) + case *ast.ChanType: + y, ok := y.(*ast.ChanType) + return ok && exprIdentical(x.Value, y.Value) && x.Dir == y.Dir + case *ast.BinaryExpr: + y, ok := y.(*ast.BinaryExpr) + return ok && x.Op == y.Op && + exprIdentical(x.X, y.X) && + exprIdentical(x.Y, y.Y) + case *ast.Ident: + y, ok := y.(*ast.Ident) + return ok && x.Name == y.Name + case *ast.ParenExpr: + y, ok := y.(*ast.ParenExpr) + return ok && exprIdentical(x.X, y.X) + case *ast.SelectorExpr: + y, ok := y.(*ast.SelectorExpr) + return ok && + exprIdentical(x.X, y.X) && + exprIdentical(x.Sel, y.Sel) + case *ast.CallExpr: + y, ok := y.(*ast.CallExpr) + if !ok || len(x.Args) != len(y.Args) { + return false + } + if !exprIdentical(x.Fun, y.Fun) { + return false + } + for i := range x.Args { + if !exprIdentical(x.Args[i], y.Args[i]) { + return false + } + } + return true + default: + // For unhandled expression types, consider them unequal. + return false + } +} diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index 119ca390ced..86852aad28b 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -59,6 +59,7 @@ func singleFile(fixer1 singleFileFixer) fixer { // Names of ApplyFix.Fix created directly by the CodeAction handler. const ( fixExtractVariable = "extract_variable" + fixExtractAllOccursOfExpr = "extract_all_occurs_of_expr" fixExtractFunction = "extract_function" fixExtractMethod = "extract_method" fixInlineCall = "inline_call" @@ -106,6 +107,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fixExtractFunction: singleFile(extractFunction), fixExtractMethod: singleFile(extractMethod), fixExtractVariable: singleFile(extractVariable), + fixExtractAllOccursOfExpr: singleFile(extractAllOccursOfExpr), fixInlineCall: inlineCall, fixInvertIfCondition: singleFile(invertIfCondition), fixSplitLines: singleFile(splitLines), diff --git a/gopls/internal/settings/codeactionkind.go b/gopls/internal/settings/codeactionkind.go index 16a2eecb2cb..ea14eec45ca 100644 --- a/gopls/internal/settings/codeactionkind.go +++ b/gopls/internal/settings/codeactionkind.go @@ -97,10 +97,11 @@ const ( RefactorInlineCall protocol.CodeActionKind = "refactor.inline.call" // refactor.extract - RefactorExtractFunction protocol.CodeActionKind = "refactor.extract.function" - RefactorExtractMethod protocol.CodeActionKind = "refactor.extract.method" - RefactorExtractVariable protocol.CodeActionKind = "refactor.extract.variable" - RefactorExtractToNewFile protocol.CodeActionKind = "refactor.extract.toNewFile" + RefactorExtractFunction protocol.CodeActionKind = "refactor.extract.function" + RefactorExtractMethod protocol.CodeActionKind = "refactor.extract.method" + RefactorExtractVariable protocol.CodeActionKind = "refactor.extract.variable" + RefactorExtractAllOccursOfExpr protocol.CodeActionKind = "refactor.extract.variable.all" + RefactorExtractToNewFile protocol.CodeActionKind = "refactor.extract.toNewFile" // Note: add new kinds to: // - the SupportedCodeActions map in default.go diff --git a/gopls/internal/test/integration/misc/codeactions_test.go b/gopls/internal/test/integration/misc/codeactions_test.go index 7e5ac9aba62..c71dac30d93 100644 --- a/gopls/internal/test/integration/misc/codeactions_test.go +++ b/gopls/internal/test/integration/misc/codeactions_test.go @@ -69,6 +69,7 @@ func g() {} settings.GoFreeSymbols, settings.GoplsDocFeatures, settings.RefactorExtractVariable, + settings.RefactorExtractAllOccursOfExpr, settings.RefactorInlineCall) check("gen/a.go", settings.GoAssembly, diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt index 259b84a09a3..b2471cc22f9 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt @@ -25,5 +25,5 @@ func main() { -- @type_switch_func_call/extract_switch.go -- @@ -10 +10,2 @@ - switch r := f().(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) -+ x := f() -+ switch r := x.(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) ++ newVar := f() ++ switch r := newVar.(type) { //@codeactionedit("f()", "refactor.extract.variable", type_switch_func_call) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt index 8c500d02c1e..e372064df52 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt @@ -15,13 +15,13 @@ func _() { -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ - var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -+ x := 1 -+ var _ = x + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) ++ newVar := 1 ++ var _ = newVar + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ - var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -+ x := 3 + 4 -+ var _ = x //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) ++ newVar := 3 + 4 ++ var _ = newVar //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -- func_call.go -- package extract @@ -36,13 +36,13 @@ func _() { -- @func_call1/func_call.go -- @@ -6 +6,2 @@ - x0 := append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) -+ x := append([]int{}, 1) -+ x0 := x //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) ++ newVar := append([]int{}, 1) ++ x0 := newVar //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ - b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) -+ x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) ++ newVar, newVar1 := strconv.Atoi(str) ++ b, err := newVar, newVar1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) -- scope.go -- package extract @@ -61,10 +61,10 @@ func _() { -- @scope1/scope.go -- @@ -8 +8,2 @@ - y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) -+ x := ast.CompositeLit{} -+ y := x //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) ++ newVar := ast.CompositeLit{} ++ y := newVar //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ - x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) -+ x := !false -+ x1 := x //@codeactionedit("!false", "refactor.extract.variable", scope2) ++ newVar := !false ++ x1 := newVar //@codeactionedit("!false", "refactor.extract.variable", scope2) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt index b3a9a67059f..3b1dc24687c 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt @@ -26,13 +26,13 @@ func _() { -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ - var _ = 1 + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -+ x := 1 -+ var _ = x + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) ++ newVar := 1 ++ var _ = newVar + 2 //@codeactionedit("1", "refactor.extract.variable", basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ - var _ = 3 + 4 //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -+ x := 3 + 4 -+ var _ = x //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) ++ newVar := 3 + 4 ++ var _ = newVar //@codeactionedit("3 + 4", "refactor.extract.variable", basic_lit2) -- func_call.go -- package extract @@ -47,13 +47,13 @@ func _() { -- @func_call1/func_call.go -- @@ -6 +6,2 @@ - x0 := append([]int{}, 1) //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) -+ x := append([]int{}, 1) -+ x0 := x //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) ++ newVar := append([]int{}, 1) ++ x0 := newVar //@codeactionedit("append([]int{}, 1)", "refactor.extract.variable", func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ - b, err := strconv.Atoi(str) //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) -+ x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) ++ newVar, newVar1 := strconv.Atoi(str) ++ b, err := newVar, newVar1 //@codeactionedit("strconv.Atoi(str)", "refactor.extract.variable", func_call2) -- scope.go -- package extract @@ -72,10 +72,10 @@ func _() { -- @scope1/scope.go -- @@ -8 +8,2 @@ - y := ast.CompositeLit{} //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) -+ x := ast.CompositeLit{} -+ y := x //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) ++ newVar := ast.CompositeLit{} ++ y := newVar //@codeactionedit("ast.CompositeLit{}", "refactor.extract.variable", scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ - x1 := !false //@codeactionedit("!false", "refactor.extract.variable", scope2) -+ x := !false -+ x1 := x //@codeactionedit("!false", "refactor.extract.variable", scope2) ++ newVar := !false ++ x1 := newVar //@codeactionedit("!false", "refactor.extract.variable", scope2) diff --git a/gopls/internal/test/marker/testdata/codeaction/replace_expressions.txt b/gopls/internal/test/marker/testdata/codeaction/replace_expressions.txt new file mode 100644 index 00000000000..97f1308192e --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/replace_expressions.txt @@ -0,0 +1,188 @@ +This test checks the behavior of the 'replace occurrences of expression' code action. +See replace_expressions_resolve.txt for the same test with resolve support. + +-- flags -- +-ignore_extra_diags + +-- basic_lit.go -- +package extract_all + +func _() { + var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) + var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +} +-- @basic_lit/basic_lit.go -- +@@ -4,2 +4,3 @@ +- var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +- var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) ++ newVar := 1 + 2 ++ var _ = newVar + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) ++ var _ = newVar + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +-- nested_scope.go -- +package extract_all + +func _() { + if true { + x := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) + } + if true { + if false { + y := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) + } + } + z := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +} +-- @nested_scope/nested_scope.go -- +@@ -4 +4 @@ ++ newVar := 1 + 2 + 3 +@@ -5 +6 @@ +- x := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ x := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +@@ -9 +10 @@ +- y := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ y := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +@@ -12 +13 @@ +- z := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ z := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +-- function_call.go -- +package extract_all + +import "fmt" + +func _() { + result := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) + if result != "" { + anotherResult := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) + _ = anotherResult + } +} +-- @replace_func_call/function_call.go -- +@@ -6 +6,2 @@ +- result := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) ++ newVar := fmt.Sprintf("%d", 42) ++ result := newVar //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) +@@ -8 +9 @@ +- anotherResult := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) ++ anotherResult := newVar //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) +-- composite_literals.go -- +package extract_all + +func _() { + data := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) + processData(data) + moreData := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) + processData(moreData) +} + +func processData(d []int) {} +-- @composite/composite_literals.go -- +@@ -4 +4,2 @@ +- data := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) ++ newVar := []int{1, 2, 3} ++ data := newVar //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) +@@ -6 +7 @@ +- moreData := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) ++ moreData := newVar //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) +-- selector.go -- +package extract_all + +type MyStruct struct { + Value int +} + +func _() { + s := MyStruct{Value: 10} + v := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) + if v > 0 { + w := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) + _ = w + } +} +-- @sel/selector.go -- +@@ -9 +9,2 @@ +- v := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) ++ newVar := s.Value ++ v := newVar //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) +@@ -11 +12 @@ +- w := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) ++ w := newVar //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) +-- index.go -- +package extract_all + +func _() { + arr := []int{1, 2, 3} + val := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) + val2 := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +} +-- @index/index.go -- +@@ -5,2 +5,3 @@ +- val := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +- val2 := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) ++ newVar := arr[0] ++ val := newVar //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) ++ val2 := newVar //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +-- slice_expr.go -- +package extract_all + +func _() { + data := []int{1, 2, 3, 4, 5} + part := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) + anotherPart := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +} +-- @slice/slice_expr.go -- +@@ -5,2 +5,3 @@ +- part := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +- anotherPart := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) ++ newVar := data[1:3] ++ part := newVar //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) ++ anotherPart := newVar //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +-- nested_func.go -- +package extract_all + +func outer() { + inner := func() { + val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) + _ = val + } + inner() + val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) + _ = val +} +-- @nested/nested_func.go -- +@@ -4 +4 @@ ++ newVar := 100 + 200 +@@ -5 +6 @@ +- val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) ++ val := newVar //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) +@@ -9 +10 @@ +- val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) ++ val := newVar //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) +-- switch.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + case 2: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + default: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + } +} +-- @switch/switch.go -- +@@ -5 +5 @@ ++ newVar := value * 10 +@@ -7 +8 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) +@@ -10 +11 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) +@@ -13 +14 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) diff --git a/gopls/internal/test/marker/testdata/codeaction/replace_expressions_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/replace_expressions_resolve.txt new file mode 100644 index 00000000000..51457d0bd31 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/replace_expressions_resolve.txt @@ -0,0 +1,199 @@ +This test checks the behavior of the 'replace occurrences of expression' code action. +See extract_variable_resolve.txt for the same test with resolve support. + +-- capabilities.json -- +{ + "textDocument": { + "codeAction": { + "dataSupport": true, + "resolveSupport": { + "properties": ["edit"] + } + } + } +} +-- flags -- +-ignore_extra_diags + +-- basic_lit.go -- +package extract_all + +func _() { + var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) + var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +} +-- @basic_lit/basic_lit.go -- +@@ -4,2 +4,3 @@ +- var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +- var _ = 1 + 2 + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) ++ newVar := 1 + 2 ++ var _ = newVar + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) ++ var _ = newVar + 3 //@codeactionedit("1 + 2", "refactor.extract.variable.all", basic_lit) +-- nested_scope.go -- +package extract_all + +func _() { + if true { + x := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) + } + if true { + if false { + y := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) + } + } + z := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +} +-- @nested_scope/nested_scope.go -- +@@ -4 +4 @@ ++ newVar := 1 + 2 + 3 +@@ -5 +6 @@ +- x := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ x := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +@@ -9 +10 @@ +- y := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ y := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +@@ -12 +13 @@ +- z := 1 + 2 + 3 //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) ++ z := newVar //@codeactionedit("1 + 2 + 3", "refactor.extract.variable.all", nested_scope) +-- function_call.go -- +package extract_all + +import "fmt" + +func _() { + result := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) + if result != "" { + anotherResult := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) + _ = anotherResult + } +} +-- @replace_func_call/function_call.go -- +@@ -6 +6,2 @@ +- result := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) ++ newVar := fmt.Sprintf("%d", 42) ++ result := newVar //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) +@@ -8 +9 @@ +- anotherResult := fmt.Sprintf("%d", 42) //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) ++ anotherResult := newVar //@codeactionedit(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable.all", replace_func_call) +-- composite_literals.go -- +package extract_all + +func _() { + data := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) + processData(data) + moreData := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) + processData(moreData) +} + +func processData(d []int) {} +-- @composite/composite_literals.go -- +@@ -4 +4,2 @@ +- data := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) ++ newVar := []int{1, 2, 3} ++ data := newVar //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) +@@ -6 +7 @@ +- moreData := []int{1, 2, 3} //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) ++ moreData := newVar //@codeactionedit("[]int{1, 2, 3}", "refactor.extract.variable.all", composite) +-- selector.go -- +package extract_all + +type MyStruct struct { + Value int +} + +func _() { + s := MyStruct{Value: 10} + v := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) + if v > 0 { + w := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) + _ = w + } +} +-- @sel/selector.go -- +@@ -9 +9,2 @@ +- v := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) ++ newVar := s.Value ++ v := newVar //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) +@@ -11 +12 @@ +- w := s.Value //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) ++ w := newVar //@codeactionedit("s.Value", "refactor.extract.variable.all", sel) +-- index.go -- +package extract_all + +func _() { + arr := []int{1, 2, 3} + val := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) + val2 := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +} +-- @index/index.go -- +@@ -5,2 +5,3 @@ +- val := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +- val2 := arr[0] //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) ++ newVar := arr[0] ++ val := newVar //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) ++ val2 := newVar //@codeactionedit("arr[0]", "refactor.extract.variable.all", index) +-- slice_expr.go -- +package extract_all + +func _() { + data := []int{1, 2, 3, 4, 5} + part := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) + anotherPart := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +} +-- @slice/slice_expr.go -- +@@ -5,2 +5,3 @@ +- part := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +- anotherPart := data[1:3] //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) ++ newVar := data[1:3] ++ part := newVar //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) ++ anotherPart := newVar //@codeactionedit("data[1:3]", "refactor.extract.variable.all", slice) +-- nested_func.go -- +package extract_all + +func outer() { + inner := func() { + val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) + _ = val + } + inner() + val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) + _ = val +} +-- @nested/nested_func.go -- +@@ -4 +4 @@ ++ newVar := 100 + 200 +@@ -5 +6 @@ +- val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) ++ val := newVar //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) +@@ -9 +10 @@ +- val := 100 + 200 //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) ++ val := newVar //@codeactionedit("100 + 200", "refactor.extract.variable.all", nested) +-- switch.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + case 2: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + default: + result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) + _ = result + } +} +-- @switch/switch.go -- +@@ -5 +5 @@ ++ newVar := value * 10 +@@ -7 +8 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) +@@ -10 +11 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) +@@ -13 +14 @@ +- result := value * 10 //@codeactionedit("value * 10", "refactor.extract.variable.all", switch) ++ result := newVar //@codeactionedit("value * 10", "refactor.extract.variable.all", switch)