diff --git a/gopls/doc/assets/extract-val-all-before.png b/gopls/doc/assets/extract-val-all-before.png new file mode 100644 index 00000000000..1791283f30f Binary files /dev/null and b/gopls/doc/assets/extract-val-all-before.png differ diff --git a/gopls/doc/assets/extract-var-all-after.png b/gopls/doc/assets/extract-var-all-after.png new file mode 100644 index 00000000000..0340e4c6e7b Binary files /dev/null and b/gopls/doc/assets/extract-var-all-after.png differ diff --git a/gopls/doc/features/transformation.md b/gopls/doc/features/transformation.md index 7d32c24ca93..37015da2983 100644 --- a/gopls/doc/features/transformation.md +++ b/gopls/doc/features/transformation.md @@ -76,6 +76,7 @@ Gopls supports the following code actions: - [`refactor.extract.method`](#extract) - [`refactor.extract.toNewFile`](#extract.toNewFile) - [`refactor.extract.variable`](#extract) +- [`refactor.extract.variable-all`](#extract) - [`refactor.inline.call`](#refactor.inline.call) - [`refactor.rewrite.changeQuote`](#refactor.rewrite.changeQuote) - [`refactor.rewrite.fillStruct`](#refactor.rewrite.fillStruct) @@ -364,14 +365,22 @@ newly created declaration that contains the selected code: will be a method of the same receiver type. - **`refactor.extract.variable`** replaces an expression by a reference to a new - local variable named `x` initialized by the expression: + local variable named `newVar` initialized by the expression: ![Before extracting a var](../assets/extract-var-before.png) ![After extracting a var](../assets/extract-var-after.png) - **`refactor.extract.constant** does the same thing for a constant expression, introducing a local const declaration. +- **`refactor.extract.variable-all`** replaces all occurrences of the selected expression +within the function with a reference to a new local variable named `newVar`. +This extracts the expression once and reuses it wherever it appears in the function. + ![Before extracting all occurrences of EXPR](../assets/extract-var-all-before.png) + ![After extracting all occurrences of EXPR](../assets/extract-var-all-after.png) + + - **`refactor.extract.constant-all** does the same thing for a constant + expression, introducing a local const declaration. If the default name for the new declaration is already in use, gopls generates a fresh name. @@ -387,10 +396,8 @@ number of cases where it falls short, including: - https://github.com/golang/go/issues/66289 - https://github.com/golang/go/issues/65944 -- https://github.com/golang/go/issues/64821 - https://github.com/golang/go/issues/63394 - https://github.com/golang/go/issues/61496 -- https://github.com/golang/go/issues/50851 The following Extract features are planned for 2024 but not yet supported: diff --git a/gopls/doc/release/v0.18.0.md b/gopls/doc/release/v0.18.0.md index f80eeea5929..8b3d5d7d5b4 100644 --- a/gopls/doc/release/v0.18.0.md +++ b/gopls/doc/release/v0.18.0.md @@ -26,3 +26,8 @@ func (C[T]) Pop() (T, bool) { ... } var _ Stack[int] = C[int]{} ``` +## Extract all occurrences of the same expression under selection + +When you have multiple instances of the same expression in a function, +you can use this code action to extract it into a variable. +All occurrences of the expression will be replaced with a reference to the new variable. diff --git a/gopls/internal/golang/codeaction.go b/gopls/internal/golang/codeaction.go index 0a778ba758b..097d6e95a12 100644 --- a/gopls/internal/golang/codeaction.go +++ b/gopls/internal/golang/codeaction.go @@ -238,6 +238,8 @@ var codeActionProducers = [...]codeActionProducer{ {kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile}, {kind: settings.RefactorExtractConstant, fn: refactorExtractVariable, needPkg: true}, {kind: settings.RefactorExtractVariable, fn: refactorExtractVariable, needPkg: true}, + {kind: settings.RefactorExtractConstantAll, fn: refactorExtractVariableAll, needPkg: true}, + {kind: settings.RefactorExtractVariableAll, fn: refactorExtractVariableAll, needPkg: true}, {kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true}, {kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote}, {kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true}, @@ -467,14 +469,15 @@ func refactorExtractMethod(ctx context.Context, req *codeActionsRequest) error { // See [extractVariable] for command implementation. func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error { info := req.pkg.TypesInfo() - if expr, _, err := canExtractVariable(info, req.pgf.File, req.start, req.end); err == nil { + if exprs, err := canExtractVariable(info, req.pgf.File, req.start, req.end, false); err == nil { // Offer one of refactor.extract.{constant,variable} // based on the constness of the expression; this is a // limitation of the codeActionProducers mechanism. // Beware that future evolutions of the refactorings // may make them diverge to become non-complementary, // for example because "if const x = ...; y {" is illegal. - constant := info.Types[expr].Value != nil + // Same as [refactorExtractVariableAll]. + constant := info.Types[exprs[0]].Value != nil if (req.kind == settings.RefactorExtractConstant) == constant { title := "Extract variable" if constant { @@ -486,6 +489,35 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error return nil } +// refactorExtractVariableAll produces "Extract N occurrences of EXPR" code action. +// See [extractAllOccursOfExpr] for command implementation. +func refactorExtractVariableAll(ctx context.Context, req *codeActionsRequest) error { + info := req.pkg.TypesInfo() + // Don't suggest if only one expr is found, + // otherwise it will duplicate with [refactorExtractVariable] + if exprs, err := canExtractVariable(info, req.pgf.File, req.start, req.end, true); err == nil && len(exprs) > 1 { + start, end, err := req.pgf.NodeOffsets(exprs[0]) + if err != nil { + return err + } + desc := string(req.pgf.Src[start:end]) + if len(desc) >= 40 || strings.Contains(desc, "\n") { + desc = astutil.NodeDescription(exprs[0]) + } + constant := info.Types[exprs[0]].Value != nil + if (req.kind == settings.RefactorExtractConstantAll) == constant { + var title string + if constant { + title = fmt.Sprintf("Extract %d occurrences of const expression: %s", len(exprs), desc) + } else { + title = fmt.Sprintf("Extract %d occurrences of %s", len(exprs), desc) + } + req.addApplyFixAction(title, fixExtractVariableAll, req.loc) + } + } + return nil +} + // refactorExtractToNewFile produces "Extract declarations to new file" code actions. // See [server.commandHandler.ExtractToNewFile] for command implementation. func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error { diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 72d718c2faf..458e7b155dc 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -20,6 +20,7 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" + goplsastutil "golang.org/x/tools/gopls/internal/util/astutil" "golang.org/x/tools/gopls/internal/util/bug" "golang.org/x/tools/gopls/internal/util/safetoken" "golang.org/x/tools/internal/analysisinternal" @@ -27,24 +28,54 @@ import ( ) // extractVariable implements the refactor.extract.{variable,constant} CodeAction command. -func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { +func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { + return extractExprs(fset, start, end, src, file, info, false) +} + +// extractVariableAll implements the refactor.extract.{variable,constant}-all CodeAction command. +func extractVariableAll(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) { + return extractExprs(fset, start, end, src, file, info, true) +} + +// extractExprs replaces occurrence(s) of a specified expression within the same function +// with newVar. If 'all' is true, it replaces all occurrences of the same expression; +// otherwise, it only replaces the selected expression. +// +// The new variable/constant is declared as close as possible to the first found expression +// within the deepest common scope accessible to all candidate occurrences. +func extractExprs(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, info *types.Info, all bool) (*token.FileSet, *analysis.SuggestedFix, error) { tokFile := fset.File(file.FileStart) - expr, path, err := canExtractVariable(info, file, start, end) + exprs, err := canExtractVariable(info, file, start, end, all) if err != nil { - return nil, nil, fmt.Errorf("cannot extract %s: %v", safetoken.StartPosition(fset, start), err) + return nil, nil, fmt.Errorf("cannot extract: %v", err) + } + + // innermost scope enclosing ith expression + exprScopes := make([]*types.Scope, len(exprs)) + for i, e := range exprs { + exprScopes[i] = info.Scopes[file].Innermost(e.Pos()) } - constant := info.Types[expr].Value != nil + + hasCollision := func(name string) bool { + for _, scope := range exprScopes { + if s, _ := scope.LookupParent(name, token.NoPos); s != nil { + return true + } + } + return false + } + constant := info.Types[exprs[0]].Value != nil // Generate name(s) for new declaration. - baseName := cond(constant, "k", "x") + baseName := cond(constant, "newConst", "newVar") var lhsNames []string - switch expr := expr.(type) { + switch expr := exprs[0].(type) { case *ast.CallExpr: tup, ok := info.TypeOf(expr).(*types.Tuple) if !ok { // conversion or single-valued call: // treat it the same as our standard extract variable case. - name, _ := freshName(info, file, expr.Pos(), baseName, 0) + name, _ := generateName(0, baseName, hasCollision) lhsNames = append(lhsNames, name) } else { @@ -53,17 +84,55 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file for range tup.Len() { // Generate a unique variable for each result. var name string - name, idx = freshName(info, file, expr.Pos(), baseName, idx) + name, idx = generateName(idx, baseName, hasCollision) lhsNames = append(lhsNames, name) } } default: // TODO: stricter rules for selectorExpr. - name, _ := freshName(info, file, expr.Pos(), baseName, 0) + name, _ := generateName(0, baseName, hasCollision) lhsNames = append(lhsNames, name) } + // Where all the extractable positions can see variable being declared. + var commonScope *types.Scope + counter := make(map[*types.Scope]int) +Outer: + for _, scope := range exprScopes { + for s := scope; s != nil; s = s.Parent() { + counter[s]++ + if counter[s] == len(exprScopes) { + // A scope whose count is len(scopes) is common to all ancestor paths. + // Stop at the first (innermost) one. + commonScope = s + break Outer + } + } + } + + var visiblePath []ast.Node + if commonScope != exprScopes[0] { + // This means the first expr within function body is not the largest scope, + // we need to find the scope immediately follow the common + // scope where we will insert the statement before. + child := exprScopes[0] + for p := child; p != nil; p = p.Parent() { + if p == commonScope { + break + } + child = p + } + visiblePath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End()) + } else { + // Insert newVar inside commonScope before the first occurrence of the expression. + visiblePath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End()) + } + variables, err := collectFreeVars(info, file, exprs[0].Pos(), exprs[0].End(), exprs[0]) + if err != nil { + return nil, nil, err + } + // TODO: There is a bug here: for a variable declared in a labeled // switch/for statement it returns the for/switch statement itself // which produces the below code which is a compiler error. e.g. @@ -74,26 +143,16 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file // x := r() // switch r1 := x { ... break label ... } // compiler error // - // TODO(golang/go#70563): Another bug: extracting the - // expression to the recommended place may cause it to migrate - // across one or more declarations that it references. - // - // Before: - // if x := 1; cond { - // } else if y := «x + 2»; cond { - // } - // - // After: - // x1 := x + 2 // error: undefined x - // if x := 1; cond { - // } else if y := x1; cond { - // } var ( insertPos token.Pos indentation string stmtOK bool // ok to use ":=" instead of var/const decl? ) - if before := analysisinternal.StmtToInsertVarBefore(path); before != nil { + if funcDecl, ok := visiblePath[len(visiblePath)-2].(*ast.FuncDecl); ok && goplsastutil.NodeContains(funcDecl.Body, start) { + before, err := stmtToInsertVarBefore(visiblePath, variables) + if err != nil { + return nil, nil, fmt.Errorf("cannot find location to insert extraction: %v", err) + } // Within function: compute appropriate statement indentation. indent, err := calculateIndentation(src, tokFile, before) if err != nil { @@ -116,7 +175,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file } else { // Outside any statement: insert before the current // declaration, without indentation. - currentDecl := path[len(path)-2] + currentDecl := visiblePath[len(visiblePath)-2] insertPos = currentDecl.Pos() indentation = "\n" } @@ -152,7 +211,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file Specs: []ast.Spec{ &ast.ValueSpec{ Names: names, - Values: []ast.Expr{expr}, + Values: []ast.Expr{exprs[0]}, }, }, } @@ -166,7 +225,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file newNode = &ast.AssignStmt{ Tok: token.DEFINE, Lhs: lhs, - Rhs: []ast.Expr{expr}, + Rhs: []ast.Expr{exprs[0]}, } } @@ -177,50 +236,248 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file } // TODO(adonovan): not sound for `...` string literals containing newlines. assignment := strings.ReplaceAll(buf.String(), "\n", indentation) + indentation - + textEdits := []analysis.TextEdit{{ + Pos: insertPos, + End: insertPos, + NewText: []byte(assignment), + }} + for _, e := range exprs { + textEdits = append(textEdits, analysis.TextEdit{ + Pos: e.Pos(), + End: e.End(), + NewText: []byte(strings.Join(lhsNames, ", ")), + }) + } return fset, &analysis.SuggestedFix{ - TextEdits: []analysis.TextEdit{ - { - Pos: insertPos, - End: insertPos, - NewText: []byte(assignment), - }, - { - Pos: start, - End: end, - NewText: []byte(strings.Join(lhsNames, ", ")), - }, - }, + TextEdits: textEdits, }, nil } +// stmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable, +// and ensures that the new declaration is inserted at a point where all free variables are declared before. +// Some examples: +// +// Basic Example: +// +// z := 1 +// y := z + x +// +// If x is undeclared, then this function would return `y := z + x`, so that we +// can insert `x := ` on the line before `y := z + x`. +// +// valid IfStmt example: +// +// if z == 1 { +// } else if z == y {} +// +// If y is undeclared, then this function would return `if z == 1 {`, because we cannot +// insert a statement between an if and an else if statement. As a result, we need to find +// the top of the if chain to insert `y := ` before. +// +// invalid IfStmt example: +// +// if x := 1; true { +// } else if y := x + 1; true { //apply refactor.extract.variable to x +// } +// +// `x` is a free variable defined in the IfStmt, we should not insert +// the extracted expression outside the IfStmt scope, instead, return an error. +func stmtToInsertVarBefore(path []ast.Node, variables []*variable) (ast.Stmt, error) { + enclosingIndex := -1 // index in path of enclosing stmt + for i, p := range path { + if _, ok := p.(ast.Stmt); ok { + enclosingIndex = i + break + } + } + if enclosingIndex == -1 { + return nil, fmt.Errorf("no enclosing statement") + } + enclosingStmt := path[enclosingIndex].(ast.Stmt) + + // hasFreeVar reports if any free variables is defined inside stmt (which may be nil). + // If true, indicates that the insertion point will sit before the variable declaration. + hasFreeVar := func(stmt ast.Stmt) bool { + if stmt == nil { + return false + } + for _, v := range variables { + if goplsastutil.NodeContains(stmt, v.obj.Pos()) { + return true + } + } + return false + } + + // baseIfStmt walks up the if/else-if chain until we get to + // the top of the current if chain. + baseIfStmt := func(index int) (ast.Stmt, error) { + stmt := path[index] + for _, node := range path[index+1:] { + ifStmt, ok := node.(*ast.IfStmt) + if !ok || ifStmt.Else != stmt { + break + } + if hasFreeVar(ifStmt.Init) { + return nil, fmt.Errorf("Else's init statement has free variable declaration") + } + stmt = ifStmt + } + return stmt.(ast.Stmt), nil + } + + switch enclosingStmt := enclosingStmt.(type) { + case *ast.IfStmt: + if hasFreeVar(enclosingStmt.Init) { + return nil, fmt.Errorf("IfStmt's init statement has free variable declaration") + } + // The enclosingStmt is inside of the if declaration, + // We need to check if we are in an else-if stmt and + // get the base if statement. + return baseIfStmt(enclosingIndex) + case *ast.CaseClause: + // Get the enclosing switch stmt if the enclosingStmt is + // inside of the case statement. + for _, node := range path[enclosingIndex+1:] { + switch stmt := node.(type) { + case *ast.SwitchStmt: + if hasFreeVar(stmt.Init) { + return nil, fmt.Errorf("SwitchStmt's init statement has free variable declaration") + } + return stmt, nil + case *ast.TypeSwitchStmt: + if hasFreeVar(stmt.Init) { + return nil, fmt.Errorf("TypeSwitchStmt's init statement has free variable declaration") + } + return stmt, nil + } + } + } + // Check if the enclosing statement is inside another node. + switch parent := path[enclosingIndex+1].(type) { + case *ast.IfStmt: + if hasFreeVar(parent.Init) { + return nil, fmt.Errorf("IfStmt's init statement has free variable declaration") + } + return baseIfStmt(enclosingIndex + 1) + case *ast.ForStmt: + if parent.Init == enclosingStmt || parent.Post == enclosingStmt { + return parent, nil + } + case *ast.SwitchStmt: + if hasFreeVar(parent.Init) { + return nil, fmt.Errorf("SwitchStmt's init statement has free variable declaration") + } + return parent, nil + case *ast.TypeSwitchStmt: + if hasFreeVar(parent.Init) { + return nil, fmt.Errorf("TypeSwitchStmt's init statement has free variable declaration") + } + return parent, nil + } + return enclosingStmt.(ast.Stmt), nil +} + // canExtractVariable reports whether the code in the given range can be -// extracted to a variable (or constant). -func canExtractVariable(info *types.Info, file *ast.File, start, end token.Pos) (ast.Expr, []ast.Node, error) { +// extracted to a variable (or constant). It returns the selected expression or, if 'all', +// all structurally equivalent expressions within the same function body, in lexical order. +func canExtractVariable(info *types.Info, file *ast.File, start, end token.Pos, all bool) ([]ast.Expr, error) { if start == end { - return nil, nil, fmt.Errorf("empty selection") + return nil, fmt.Errorf("empty selection") } path, exact := astutil.PathEnclosingInterval(file, start, end) if !exact { - return nil, nil, fmt.Errorf("selection is not an expression") + return nil, fmt.Errorf("selection is not an expression") } if len(path) == 0 { - return nil, nil, bug.Errorf("no path enclosing interval") + return nil, bug.Errorf("no path enclosing interval") } for _, n := range path { if _, ok := n.(*ast.ImportSpec); ok { - return nil, nil, fmt.Errorf("cannot extract variable or constant in an import block") + return nil, fmt.Errorf("cannot extract variable or constant in an import block") } } expr, ok := path[0].(ast.Expr) if !ok { - return nil, nil, fmt.Errorf("selection is not an expression") // e.g. statement + return nil, fmt.Errorf("selection is not an expression") // e.g. statement } if tv, ok := info.Types[expr]; !ok || !tv.IsValue() || tv.Type == nil || tv.HasOk() { // e.g. type, builtin, x.(type), 2-valued m[k], or ill-typed - return nil, nil, fmt.Errorf("selection is not a single-valued expression") + return nil, fmt.Errorf("selection is not a single-valued expression") + } + + var exprs []ast.Expr + if !all { + exprs = append(exprs, expr) + } else if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok { + // Find all expressions in the same function body that + // are equal to the selected expression. + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { + if e, ok := n.(ast.Expr); ok { + if goplsastutil.Equal(e, expr, func(x, y *ast.Ident) bool { + return x.Name == y.Name && info.Uses[x] == info.Uses[y] + }) { + exprs = append(exprs, e) + } + } + return true + }) + } else { + return nil, fmt.Errorf("node %T is not inside a function", expr) } - return expr, path, nil + + // Disallow any expr that sits in lhs of an AssignStmt or ValueSpec for now. + // + // TODO(golang/go#70784): In such cases, exprs are operated in "variable" mode (L-value mode in C). + // In contrast, exprs in the RHS operate in "value" mode (R-value mode in C). + // L-value mode refers to exprs that represent storage locations, + // while R-value mode refers to exprs that represent values. + // There are a number of expressions that may have L-value mode, given by: + // + // lvalue = ident -- Ident such that info.Uses[id] is a *Var + // | '(' lvalue ') ' -- ParenExpr + // | lvalue '[' expr ']' -- IndexExpr + // | lvalue '.' ident -- SelectorExpr. + // + // For example: + // + // type foo struct { + // bar int + // } + // f := foo{bar: 1} + // x := f.bar + 1 // f.bar operates in "value" mode. + // f.bar = 2 // f.bar operates in "variable" mode. + // + // When extracting exprs in variable mode, we must be cautious. Any such extraction + // may require capturing the address of the expression and replacing its uses with dereferenced access. + // The type checker records this information in info.Types[id].{IsValue,Addressable}(). + // The correct result should be: + // + // newVar := &f.bar + // x := *newVar + 1 + // *newVar = 2 + for _, e := range exprs { + path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End()) + for _, n := range path { + if assignment, ok := n.(*ast.AssignStmt); ok { + for _, lhs := range assignment.Lhs { + if lhs == e { + return nil, fmt.Errorf("node %T is in LHS of an AssignStmt", expr) + } + } + break + } + if value, ok := n.(*ast.ValueSpec); ok { + for _, name := range value.Names { + if name == e { + return nil, fmt.Errorf("node %T is in LHS of a ValueSpec", expr) + } + } + break + } + } + } + return exprs, nil } // Calculate indentation for insertion. @@ -331,14 +588,6 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte safetoken.StartPosition(fset, start), err) } tok, path, start, end, outer, node := p.tok, p.path, p.start, p.end, p.outer, p.node - fileScope := info.Scopes[file] - if fileScope == nil { - return nil, nil, fmt.Errorf("%s: file scope is empty", errorPrefix) - } - pkgScope := fileScope.Parent() - if pkgScope == nil { - return nil, nil, fmt.Errorf("%s: package scope is empty", errorPrefix) - } // A return statement is non-nested if its parent node is equal to the parent node // of the first node in the selection. These cases must be handled separately because @@ -373,7 +622,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte // we must determine the signature of the extracted function. We will then replace // the block with an assignment statement that calls the extracted function with // the appropriate parameters and return values. - variables, err := collectFreeVars(info, file, fileScope, pkgScope, start, end, path[0]) + variables, err := collectFreeVars(info, file, start, end, path[0]) if err != nil { return nil, nil, err } @@ -922,7 +1171,15 @@ type variable struct { // variables will be used as arguments in the extracted function. It also returns a // list of identifiers that may need to be returned by the extracted function. // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. -func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, start, end token.Pos, node ast.Node) ([]*variable, error) { +func collectFreeVars(info *types.Info, file *ast.File, start, end token.Pos, node ast.Node) ([]*variable, error) { + fileScope := info.Scopes[file] + if fileScope == nil { + return nil, bug.Errorf("file scope is empty") + } + pkgScope := fileScope.Parent() + if pkgScope == nil { + return nil, bug.Errorf("package scope is empty") + } // id returns non-nil if n denotes an object that is referenced by the span // and defined either within the span or in the lexical environment. The bool // return value acts as an indicator for where it was defined. diff --git a/gopls/internal/golang/fix.go b/gopls/internal/golang/fix.go index f88343f029c..7e83c1d6700 100644 --- a/gopls/internal/golang/fix.go +++ b/gopls/internal/golang/fix.go @@ -59,6 +59,7 @@ func singleFile(fixer1 singleFileFixer) fixer { // Names of ApplyFix.Fix created directly by the CodeAction handler. const ( fixExtractVariable = "extract_variable" // (or constant) + fixExtractVariableAll = "extract_variable_all" fixExtractFunction = "extract_function" fixExtractMethod = "extract_method" fixInlineCall = "inline_call" @@ -106,6 +107,7 @@ func ApplyFix(ctx context.Context, fix string, snapshot *cache.Snapshot, fh file fixExtractFunction: singleFile(extractFunction), fixExtractMethod: singleFile(extractMethod), fixExtractVariable: singleFile(extractVariable), + fixExtractVariableAll: singleFile(extractVariableAll), fixInlineCall: inlineCall, fixInvertIfCondition: singleFile(invertIfCondition), fixSplitLines: singleFile(splitLines), diff --git a/gopls/internal/golang/undeclared.go b/gopls/internal/golang/undeclared.go index 3d9954639b4..ef32e949588 100644 --- a/gopls/internal/golang/undeclared.go +++ b/gopls/internal/golang/undeclared.go @@ -17,7 +17,6 @@ import ( "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/gopls/internal/util/typesutil" - "golang.org/x/tools/internal/analysisinternal" "golang.org/x/tools/internal/typesinternal" ) @@ -126,9 +125,9 @@ func CreateUndeclared(fset *token.FileSet, start, end token.Pos, content []byte, return nil, nil, fmt.Errorf("no identifier found") } p, _ := astutil.PathEnclosingInterval(file, firstRef.Pos(), firstRef.Pos()) - insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(p) - if insertBeforeStmt == nil { - return nil, nil, fmt.Errorf("could not locate insertion point") + insertBeforeStmt, err := stmtToInsertVarBefore(p, nil) + if err != nil { + return nil, nil, fmt.Errorf("could not locate insertion point: %v", err) } indent, err := calculateIndentation(content, fset.File(file.FileStart), insertBeforeStmt) if err != nil { diff --git a/gopls/internal/settings/codeactionkind.go b/gopls/internal/settings/codeactionkind.go index 7bc4f4e4d66..0daf3cb5999 100644 --- a/gopls/internal/settings/codeactionkind.go +++ b/gopls/internal/settings/codeactionkind.go @@ -99,11 +99,13 @@ const ( RefactorInlineCall protocol.CodeActionKind = "refactor.inline.call" // refactor.extract - RefactorExtractConstant protocol.CodeActionKind = "refactor.extract.constant" - RefactorExtractFunction protocol.CodeActionKind = "refactor.extract.function" - RefactorExtractMethod protocol.CodeActionKind = "refactor.extract.method" - RefactorExtractVariable protocol.CodeActionKind = "refactor.extract.variable" - RefactorExtractToNewFile protocol.CodeActionKind = "refactor.extract.toNewFile" + RefactorExtractConstant protocol.CodeActionKind = "refactor.extract.constant" + RefactorExtractConstantAll protocol.CodeActionKind = "refactor.extract.constant-all" + RefactorExtractFunction protocol.CodeActionKind = "refactor.extract.function" + RefactorExtractMethod protocol.CodeActionKind = "refactor.extract.method" + RefactorExtractVariable protocol.CodeActionKind = "refactor.extract.variable" + RefactorExtractVariableAll protocol.CodeActionKind = "refactor.extract.variable-all" + RefactorExtractToNewFile protocol.CodeActionKind = "refactor.extract.toNewFile" // Note: add new kinds to: // - the SupportedCodeActions map in default.go diff --git a/gopls/internal/settings/default.go b/gopls/internal/settings/default.go index 0354101f045..3048f90bd3b 100644 --- a/gopls/internal/settings/default.go +++ b/gopls/internal/settings/default.go @@ -62,9 +62,11 @@ func DefaultOptions(overrides ...func(*Options)) *Options { RefactorRewriteSplitLines: true, RefactorInlineCall: true, RefactorExtractConstant: true, + RefactorExtractConstantAll: true, RefactorExtractFunction: true, RefactorExtractMethod: true, RefactorExtractVariable: true, + RefactorExtractVariableAll: true, RefactorExtractToNewFile: true, // Not GoTest: it must be explicit in CodeActionParams.Context.Only }, diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt index fabbbee99d3..96c09cd0246 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-67905.txt @@ -25,5 +25,5 @@ func main() { -- @type_switch_func_call/extract_switch.go -- @@ -10 +10,2 @@ - switch r := f().(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) -+ x := f() -+ switch r := x.(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) ++ newVar := f() ++ switch r := newVar.(type) { //@codeaction("f()", "refactor.extract.variable", edit=type_switch_func_call) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-70563.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-70563.txt new file mode 100644 index 00000000000..1317815ea32 --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-70563.txt @@ -0,0 +1,50 @@ +This test verifies the fix for golang/go#70563: refactor.extract.variable +inserts new statement before the scope of its free symbols. + +-- flags -- +-ignore_extra_diags + +-- inside_else.go -- +package extract + +func _() { + if x := 1; true { + + } else if y := x + 1; true { //@codeaction("x + 1", "refactor.extract.variable", err=re"Else's init statement has free variable declaration") + + } +} +-- inside_case.go -- +package extract + +func _() { + switch x := 1; x { + case x + 1: //@codeaction("x + 1", "refactor.extract.variable-all", err=re"SwitchStmt's init statement has free variable declaration") + y := x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"SwitchStmt's init statement has free variable declaration") + _ = y + case 3: + y := x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"SwitchStmt's init statement has free variable declaration") + _ = y + } +} +-- parent_if.go -- +package extract + +func _() { + if x := 1; x > 0 { + y = x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"IfStmt's init statement has free variable declaration") + } else { + y = x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"IfStmt's init statement has free variable declaration") + } +} +-- parent_switch.go -- +package extract + +func _() { + switch x := 1; x { + case 1: + y = x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"SwitchStmt's init statement has free variable declaration") + case 3: + y = x + 1 //@codeaction("x + 1", "refactor.extract.variable-all", err=re"SwitchStmt's init statement has free variable declaration") + } +} diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt index ab9d76b8602..fdc00d3bf8f 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-if.txt @@ -29,13 +29,13 @@ func variable(y int) { -- @constant/a.go -- @@ -4 +4 @@ -+ const k = 1 + 2 ++ const newConst = 1 + 2 @@ -5 +6 @@ - } else if 1 + 2 > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) -+ } else if k > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) ++ } else if newConst > 0 { //@ codeaction("1 + 2", "refactor.extract.constant", edit=constant) -- @variable/a.go -- @@ -10 +10 @@ -+ x := y + y ++ newVar := y + y @@ -11 +12 @@ - } else if y + y > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) -+ } else if x > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) ++ } else if newVar > 0 { //@ codeaction("y + y", "refactor.extract.variable", edit=variable) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt index 1781b3ce6af..5ddff1182f6 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-inexact.txt @@ -17,20 +17,20 @@ func _(ptr *int) { -- @spaces/a.go -- @@ -4 +4,2 @@ - var _ = 1 + 2 + 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) -+ const k = 1 + 2 -+ var _ = k+ 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) ++ const newConst = 1 + 2 ++ var _ = newConst + 3 //@codeaction("1 + 2 ", "refactor.extract.constant", edit=spaces) -- @funclit/a.go -- @@ -5 +5,2 @@ - var _ = func() {} //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) -+ x := func() {} -+ var _ = x //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) ++ newVar := func() {} ++ var _ = newVar //@codeaction("func() {}", "refactor.extract.variable", edit=funclit) -- @ptr/a.go -- @@ -6 +6,2 @@ - var _ = *ptr //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) -+ x := *ptr -+ var _ = x //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) ++ newVar := *ptr ++ var _ = newVar //@codeaction("*ptr", "refactor.extract.variable", edit=ptr) -- @paren/a.go -- @@ -7 +7,2 @@ - var _ = (ptr) //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) -+ x := (ptr) -+ var _ = x //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) ++ newVar := (ptr) ++ var _ = newVar //@codeaction("(ptr)", "refactor.extract.variable", edit=paren) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt index b9166c6299d..d41fee42c9f 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable-toplevel.txt @@ -15,37 +15,37 @@ func f([2]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtype -- @lenhello/a.go -- @@ -3 +3,2 @@ -const length = len("hello") + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) -+const k = len("hello") -+const length = k + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) ++const newConst = len("hello") ++const length = newConst + 2 //@codeaction(`len("hello")`, "refactor.extract.constant", edit=lenhello) -- @sliceliteral/a.go -- @@ -5 +5,2 @@ -var slice = append([]int{}, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) -+var x = []int{} -+var slice = append(x, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) ++var newVar = []int{} ++var slice = append(newVar, 1, 2, 3) //@codeaction("[]int{}", "refactor.extract.variable", edit=sliceliteral) -- @arraylen/a.go -- @@ -7 +7,2 @@ -type SHA256 [32]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) -+const k = 32 -+type SHA256 [k]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) ++const newConst = 32 ++type SHA256 [newConst]byte //@codeaction("32", "refactor.extract.constant", edit=arraylen) -- @paramtypearraylen/a.go -- @@ -9 +9,2 @@ -func f([2]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) -+const k = 2 -+func f([k]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) ++const newConst = 2 ++func f([newConst]int) {} //@codeaction("2", "refactor.extract.constant", edit=paramtypearraylen) -- b/b.go -- package b // Check that package- and file-level name collisions are avoided. -import x3 "errors" +import newVar3 "errors" -var x, x1, x2 any // these names are taken already -var _ = x3.New("") +var newVar, newVar1, newVar2 any // these names are taken already +var _ = newVar3.New("") var a, b int var c = a + b //@codeaction("a + b", "refactor.extract.variable", edit=fresh) -- @fresh/b/b.go -- @@ -10 +10,2 @@ -var c = a + b //@codeaction("a + b", "refactor.extract.variable", edit=fresh) -+var x4 = a + b -+var c = x4 //@codeaction("a + b", "refactor.extract.variable", edit=fresh) ++var newVar4 = a + b ++var c = newVar4 //@codeaction("a + b", "refactor.extract.variable", edit=fresh) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt index c14fb732978..9dd0f766e05 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable.txt @@ -15,13 +15,13 @@ func _() { -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ - var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -+ const k = 1 -+ var _ = k + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) ++ const newConst = 1 ++ var _ = newConst + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ - var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -+ const k = 3 + 4 -+ var _ = k //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) ++ const newConst = 3 + 4 ++ var _ = newConst //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -- func_call.go -- package extract @@ -36,13 +36,13 @@ func _() { -- @func_call1/func_call.go -- @@ -6 +6,2 @@ - x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -+ x := append([]int{}, 1) -+ x0 := x //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) ++ newVar := append([]int{}, 1) ++ x0 := newVar //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ - b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -+ x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) ++ newVar, newVar1 := strconv.Atoi(str) ++ b, err := newVar, newVar1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -- scope.go -- package extract @@ -61,10 +61,10 @@ func _() { -- @scope1/scope.go -- @@ -8 +8,2 @@ - y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -+ x := ast.CompositeLit{} -+ y := x //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) ++ newVar := ast.CompositeLit{} ++ y := newVar //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ - x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) -+ const k = !false -+ x := k //@codeaction("!false", "refactor.extract.constant", edit=scope2) ++ const newConst = !false ++ x := newConst //@codeaction("!false", "refactor.extract.constant", edit=scope2) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable_all.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable_all.txt new file mode 100644 index 00000000000..700990238fb --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable_all.txt @@ -0,0 +1,212 @@ +This test checks the behavior of the 'replace all occurrences of expression' code action, with resolve support. +See extract_expressions.txt for the same test without resolve support. + +-- flags -- +-ignore_extra_diags + +-- basic_lit.go -- +package extract_all + +func _() { + var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) + var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +} +-- @basic_lit/basic_lit.go -- +@@ -4,2 +4,3 @@ +- var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +- var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) ++ const newConst = 1 + 2 ++ var _ = newConst + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) ++ var _ = newConst + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +-- nested_scope.go -- +package extract_all + +func _() { + newConst1 := 0 + if true { + x := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) + } + if true { + newConst := 0 + if false { + y := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) + } + } + z := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +} +-- @nested_scope/nested_scope.go -- +@@ -5 +5 @@ ++ const newConst2 = 1 + 2 + 3 +@@ -6 +7 @@ +- x := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ x := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +@@ -11 +12 @@ +- y := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ y := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +@@ -14 +15 @@ +- z := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ z := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +-- function_call.go -- +package extract_all + +import "fmt" + +func _() { + result := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) + if result != "" { + anotherResult := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) + _ = anotherResult + } +} +-- @replace_func_call/function_call.go -- +@@ -6 +6,2 @@ +- result := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) ++ newVar := fmt.Sprintf("%d", 42) ++ result := newVar //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) +@@ -8 +9 @@ +- anotherResult := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) ++ anotherResult := newVar //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) +-- composite_literals.go -- +package extract_all + +func _() { + data := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) + processData(data) + moreData := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) + processData(moreData) +} + +func processData(d []int) {} +-- @composite/composite_literals.go -- +@@ -4 +4,2 @@ +- data := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) ++ newVar := []int{1, 2, 3} ++ data := newVar //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) +@@ -6 +7 @@ +- moreData := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) ++ moreData := newVar //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) +-- selector.go -- +package extract_all + +type MyStruct struct { + Value int +} + +func _() { + s := MyStruct{Value: 10} + v := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) + if v > 0 { + w := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) + _ = w + } +} +-- @sel/selector.go -- +@@ -9 +9,2 @@ +- v := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) ++ newVar := s.Value ++ v := newVar //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) +@@ -11 +12 @@ +- w := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) ++ w := newVar //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) +-- index.go -- +package extract_all + +func _() { + arr := []int{1, 2, 3} + val := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) + val2 := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +} +-- @index/index.go -- +@@ -5,2 +5,3 @@ +- val := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +- val2 := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) ++ newVar := arr[0] ++ val := newVar //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) ++ val2 := newVar //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +-- slice_expr.go -- +package extract_all + +func _() { + data := []int{1, 2, 3, 4, 5} + part := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) + anotherPart := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +} +-- @slice/slice_expr.go -- +@@ -5,2 +5,3 @@ +- part := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +- anotherPart := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) ++ newVar := data[1:3] ++ part := newVar //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) ++ anotherPart := newVar //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +-- nested_func.go -- +package extract_all + +func outer() { + inner := func() { + val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) + _ = val + } + inner() + val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) + _ = val +} +-- @nested/nested_func.go -- +@@ -4 +4 @@ ++ const newConst = 100 + 200 +@@ -5 +6 @@ +- val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) ++ val := newConst //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) +@@ -9 +10 @@ +- val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) ++ val := newConst //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) +-- switch.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + case 2: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + default: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + } +} +-- @switch/switch.go -- +@@ -5 +5 @@ ++ newVar := value * 10 +@@ -7 +8 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +@@ -10 +11 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +@@ -13 +14 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +-- switch_single.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 + _ = result + case 2: + result := value * 10 + _ = result + default: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) + _ = result + } +} +-- @switch_single/switch_single.go -- +@@ -13 +13,2 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) ++ newVar := value * 10 ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable_all_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable_all_resolve.txt new file mode 100644 index 00000000000..5728af4a5bb --- /dev/null +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable_all_resolve.txt @@ -0,0 +1,223 @@ +This test checks the behavior of the 'replace all occurrences of expression' code action, with resolve support. +See extract_expressions.txt for the same test without resolve support. + +-- capabilities.json -- +{ + "textDocument": { + "codeAction": { + "dataSupport": true, + "resolveSupport": { + "properties": ["edit"] + } + } + } +} +-- flags -- +-ignore_extra_diags + +-- basic_lit.go -- +package extract_all + +func _() { + var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) + var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +} +-- @basic_lit/basic_lit.go -- +@@ -4,2 +4,3 @@ +- var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +- var _ = 1 + 2 + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) ++ const newConst = 1 + 2 ++ var _ = newConst + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) ++ var _ = newConst + 3 //@codeaction("1 + 2", "refactor.extract.constant-all", edit=basic_lit) +-- nested_scope.go -- +package extract_all + +func _() { + newConst1 := 0 + if true { + x := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) + } + if true { + newConst := 0 + if false { + y := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) + } + } + z := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +} +-- @nested_scope/nested_scope.go -- +@@ -5 +5 @@ ++ const newConst2 = 1 + 2 + 3 +@@ -6 +7 @@ +- x := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ x := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +@@ -11 +12 @@ +- y := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ y := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +@@ -14 +15 @@ +- z := 1 + 2 + 3 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) ++ z := newConst2 //@codeaction("1 + 2 + 3", "refactor.extract.constant-all", edit=nested_scope) +-- function_call.go -- +package extract_all + +import "fmt" + +func _() { + result := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) + if result != "" { + anotherResult := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) + _ = anotherResult + } +} +-- @replace_func_call/function_call.go -- +@@ -6 +6,2 @@ +- result := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) ++ newVar := fmt.Sprintf("%d", 42) ++ result := newVar //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) +@@ -8 +9 @@ +- anotherResult := fmt.Sprintf("%d", 42) //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) ++ anotherResult := newVar //@codeaction(`fmt.Sprintf("%d", 42)`, "refactor.extract.variable-all", edit=replace_func_call) +-- composite_literals.go -- +package extract_all + +func _() { + data := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) + processData(data) + moreData := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) + processData(moreData) +} + +func processData(d []int) {} +-- @composite/composite_literals.go -- +@@ -4 +4,2 @@ +- data := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) ++ newVar := []int{1, 2, 3} ++ data := newVar //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) +@@ -6 +7 @@ +- moreData := []int{1, 2, 3} //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) ++ moreData := newVar //@codeaction("[]int{1, 2, 3}", "refactor.extract.variable-all", edit=composite) +-- selector.go -- +package extract_all + +type MyStruct struct { + Value int +} + +func _() { + s := MyStruct{Value: 10} + v := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) + if v > 0 { + w := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) + _ = w + } +} +-- @sel/selector.go -- +@@ -9 +9,2 @@ +- v := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) ++ newVar := s.Value ++ v := newVar //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) +@@ -11 +12 @@ +- w := s.Value //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) ++ w := newVar //@codeaction("s.Value", "refactor.extract.variable-all", edit=sel) +-- index.go -- +package extract_all + +func _() { + arr := []int{1, 2, 3} + val := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) + val2 := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +} +-- @index/index.go -- +@@ -5,2 +5,3 @@ +- val := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +- val2 := arr[0] //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) ++ newVar := arr[0] ++ val := newVar //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) ++ val2 := newVar //@codeaction("arr[0]", "refactor.extract.variable-all", edit=index) +-- slice_expr.go -- +package extract_all + +func _() { + data := []int{1, 2, 3, 4, 5} + part := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) + anotherPart := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +} +-- @slice/slice_expr.go -- +@@ -5,2 +5,3 @@ +- part := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +- anotherPart := data[1:3] //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) ++ newVar := data[1:3] ++ part := newVar //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) ++ anotherPart := newVar //@codeaction("data[1:3]", "refactor.extract.variable-all", edit=slice) +-- nested_func.go -- +package extract_all + +func outer() { + inner := func() { + val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) + _ = val + } + inner() + val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) + _ = val +} +-- @nested/nested_func.go -- +@@ -4 +4 @@ ++ const newConst = 100 + 200 +@@ -5 +6 @@ +- val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) ++ val := newConst //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) +@@ -9 +10 @@ +- val := 100 + 200 //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) ++ val := newConst //@codeaction("100 + 200", "refactor.extract.constant-all", edit=nested) +-- switch.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + case 2: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + default: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) + _ = result + } +} +-- @switch/switch.go -- +@@ -5 +5 @@ ++ newVar := value * 10 +@@ -7 +8 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +@@ -10 +11 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +@@ -13 +14 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable-all", edit=switch) +-- switch_single.go -- +package extract_all + +func _() { + value := 2 + switch value { + case 1: + result := value * 10 + _ = result + case 2: + result := value * 10 + _ = result + default: + result := value * 10 //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) + _ = result + } +} +-- @switch_single/switch_single.go -- +@@ -13 +13,2 @@ +- result := value * 10 //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) ++ newVar := value * 10 ++ result := newVar //@codeaction("value * 10", "refactor.extract.variable", edit=switch_single) diff --git a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt index 2bf1803a7d8..203b6d1eadc 100644 --- a/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt +++ b/gopls/internal/test/marker/testdata/codeaction/extract_variable_resolve.txt @@ -26,13 +26,13 @@ func _() { -- @basic_lit1/basic_lit.go -- @@ -4 +4,2 @@ - var _ = 1 + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -+ const k = 1 -+ var _ = k + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) ++ const newConst = 1 ++ var _ = newConst + 2 //@codeaction("1", "refactor.extract.constant", edit=basic_lit1) -- @basic_lit2/basic_lit.go -- @@ -5 +5,2 @@ - var _ = 3 + 4 //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -+ const k = 3 + 4 -+ var _ = k //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) ++ const newConst = 3 + 4 ++ var _ = newConst //@codeaction("3 + 4", "refactor.extract.constant", edit=basic_lit2) -- func_call.go -- package extract @@ -47,13 +47,13 @@ func _() { -- @func_call1/func_call.go -- @@ -6 +6,2 @@ - x0 := append([]int{}, 1) //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -+ x := append([]int{}, 1) -+ x0 := x //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) ++ newVar := append([]int{}, 1) ++ x0 := newVar //@codeaction("append([]int{}, 1)", "refactor.extract.variable", edit=func_call1) -- @func_call2/func_call.go -- @@ -8 +8,2 @@ - b, err := strconv.Atoi(str) //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -+ x, x1 := strconv.Atoi(str) -+ b, err := x, x1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) ++ newVar, newVar1 := strconv.Atoi(str) ++ b, err := newVar, newVar1 //@codeaction("strconv.Atoi(str)", "refactor.extract.variable", edit=func_call2) -- scope.go -- package extract @@ -72,10 +72,10 @@ func _() { -- @scope1/scope.go -- @@ -8 +8,2 @@ - y := ast.CompositeLit{} //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -+ x := ast.CompositeLit{} -+ y := x //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) ++ newVar := ast.CompositeLit{} ++ y := newVar //@codeaction("ast.CompositeLit{}", "refactor.extract.variable", edit=scope1) -- @scope2/scope.go -- @@ -11 +11,2 @@ - x := !false //@codeaction("!false", "refactor.extract.constant", edit=scope2) -+ const k = !false -+ x := k //@codeaction("!false", "refactor.extract.constant", edit=scope2) ++ const newConst = !false ++ x := newConst //@codeaction("!false", "refactor.extract.constant", edit=scope2) diff --git a/gopls/internal/util/astutil/util.go b/gopls/internal/util/astutil/util.go index ac7515d1daf..5a9f993449a 100644 --- a/gopls/internal/util/astutil/util.go +++ b/gopls/internal/util/astutil/util.go @@ -7,6 +7,7 @@ package astutil import ( "go/ast" "go/token" + "reflect" "golang.org/x/tools/internal/typeparams" ) @@ -69,3 +70,86 @@ L: // unpack receiver type func NodeContains(n ast.Node, pos token.Pos) bool { return n.Pos() <= pos && pos <= n.End() } + +// Equal recursively compares two nodes for structural equality, +// ignoring fields of type [token.Pos] and [ast.Object]. +// The operands x and y may be nil. A nil slice is not equal to an empty slice. +// The provided identical function reports whether two identifiers should be considered identical. +func Equal(x, y ast.Node, identical func(x, y *ast.Ident) bool) bool { + if x == nil || y == nil { + return x == y + } + return equal(reflect.ValueOf(x), reflect.ValueOf(y), identical) +} + +func equal(x, y reflect.Value, identical func(x, y *ast.Ident) bool) bool { + // Ensure types are the same + if x.Type() != y.Type() { + return false + } + switch x.Kind() { + case reflect.Pointer: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + switch t := x.Interface().(type) { + // Skip fields of types potentially involved in cycles. + case *ast.Object, *ast.Scope, *ast.CommentGroup: + return true + case *ast.Ident: + return identical(t, y.Interface().(*ast.Ident)) + default: + return equal(x.Elem(), y.Elem(), identical) + } + + case reflect.Interface: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + return equal(x.Elem(), y.Elem(), identical) + + case reflect.Struct: + for i := range x.NumField() { + xf := x.Field(i) + yf := y.Field(i) + // Skip position fields. + if xpos, ok := xf.Interface().(token.Pos); ok { + ypos := yf.Interface().(token.Pos) + // Numeric value of a Pos is not significant but its "zeroness" is, + // because it is often significant, e.g. CallExpr.Variadic(Ellipsis), ChanType.Arrow. + if xpos.IsValid() != ypos.IsValid() { + return false + } + } else if !equal(xf, yf, identical) { + return false + } + } + return true + + case reflect.Slice: + if x.IsNil() || y.IsNil() { + return x.IsNil() == y.IsNil() + } + if x.Len() != y.Len() { + return false + } + for i := range x.Len() { + if !equal(x.Index(i), y.Index(i), identical) { + return false + } + } + return true + + case reflect.String: + return x.String() == y.String() + + case reflect.Bool: + return x.Bool() == y.Bool() + + case reflect.Int: + return x.Int() == y.Int() + + default: + panic(x) + } +} diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index fe67b0fa27a..58615232ff9 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -65,90 +65,6 @@ func TypeErrorEndPos(fset *token.FileSet, src []byte, start token.Pos) token.Pos return end } -// StmtToInsertVarBefore returns the ast.Stmt before which we can -// safely insert a new var declaration, or nil if the path denotes a -// node outside any statement. -// -// Basic Example: -// -// z := 1 -// y := z + x -// -// If x is undeclared, then this function would return `y := z + x`, so that we -// can insert `x := ` on the line before `y := z + x`. -// -// If stmt example: -// -// if z == 1 { -// } else if z == y {} -// -// If y is undeclared, then this function would return `if z == 1 {`, because we cannot -// insert a statement between an if and an else if statement. As a result, we need to find -// the top of the if chain to insert `y := ` before. -func StmtToInsertVarBefore(path []ast.Node) ast.Stmt { - enclosingIndex := -1 - for i, p := range path { - if _, ok := p.(ast.Stmt); ok { - enclosingIndex = i - break - } - } - if enclosingIndex == -1 { - return nil // no enclosing statement: outside function - } - enclosingStmt := path[enclosingIndex] - switch enclosingStmt.(type) { - case *ast.IfStmt: - // The enclosingStmt is inside of the if declaration, - // We need to check if we are in an else-if stmt and - // get the base if statement. - // TODO(adonovan): for non-constants, it may be preferable - // to add the decl as the Init field of the innermost - // enclosing ast.IfStmt. - return baseIfStmt(path, enclosingIndex) - case *ast.CaseClause: - // Get the enclosing switch stmt if the enclosingStmt is - // inside of the case statement. - for i := enclosingIndex + 1; i < len(path); i++ { - if node, ok := path[i].(*ast.SwitchStmt); ok { - return node - } else if node, ok := path[i].(*ast.TypeSwitchStmt); ok { - return node - } - } - } - if len(path) <= enclosingIndex+1 { - return enclosingStmt.(ast.Stmt) - } - // Check if the enclosing statement is inside another node. - switch expr := path[enclosingIndex+1].(type) { - case *ast.IfStmt: - // Get the base if statement. - return baseIfStmt(path, enclosingIndex+1) - case *ast.ForStmt: - if expr.Init == enclosingStmt || expr.Post == enclosingStmt { - return expr - } - case *ast.SwitchStmt, *ast.TypeSwitchStmt: - return expr.(ast.Stmt) - } - return enclosingStmt.(ast.Stmt) -} - -// baseIfStmt walks up the if/else-if chain until we get to -// the top of the current if chain. -func baseIfStmt(path []ast.Node, index int) ast.Stmt { - stmt := path[index] - for i := index + 1; i < len(path); i++ { - if node, ok := path[i].(*ast.IfStmt); ok && node.Else == stmt { - stmt = node - continue - } - break - } - return stmt.(ast.Stmt) -} - // WalkASTWithParent walks the AST rooted at n. The semantics are // similar to ast.Inspect except it does not call f(nil). func WalkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) {