Skip to content

Commit

Permalink
replace_all
Browse files Browse the repository at this point in the history
  • Loading branch information
xzbdmw committed Oct 31, 2024
1 parent 386503d commit 00f4d8a
Show file tree
Hide file tree
Showing 9 changed files with 572 additions and 33 deletions.
10 changes: 10 additions & 0 deletions gopls/internal/golang/codeaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ var codeActionProducers = [...]codeActionProducer{
{kind: settings.RefactorExtractMethod, fn: refactorExtractMethod},
{kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile},
{kind: settings.RefactorExtractVariable, fn: refactorExtractVariable},
{kind: settings.RefactorReplaceAllOccursOfExpr, fn: refactorReplaceAllOccursOfExpr},
{kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true},
{kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote},
{kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true},
Expand Down Expand Up @@ -458,6 +459,15 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error
return nil
}

// refactorReplaceAllOccursOfExpr produces "Replace all occcurrances of expr" code action.
// See [replaceAllOccursOfExpr] for command implementation.
func refactorReplaceAllOccursOfExpr(ctx context.Context, req *codeActionsRequest) error {
if _, ok, _ := allOccurs(req.start, req.end, req.pgf.File); ok {
req.addApplyFixAction(fmt.Sprintf("Replace all occcurrances of expression"), fixReplaceAllOccursOfExpr, req.loc)
}
return nil
}

// refactorExtractToNewFile produces "Extract declarations to new file" code actions.
// See [server.commandHandler.ExtractToNewFile] for command implementation.
func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error {
Expand Down
343 changes: 340 additions & 3 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
// TODO: stricter rules for selectorExpr.
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
case *ast.CallExpr:
tup, ok := info.TypeOf(expr).(*types.Tuple)
if !ok {
// If the call expression only has one return value, we can treat it the
// same as our standard extract variable case.
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
break
}
idx := 0
for i := 0; i < tup.Len(); i++ {
// Generate a unique variable for each return value.
var lhsName string
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx)
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", idx)
lhsNames = append(lhsNames, lhsName)
}
default:
Expand Down Expand Up @@ -105,6 +105,343 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
}, nil
}

func replaceAllOccursOfExpr(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) {
tokFile := fset.File(file.Pos())
exprs, _, err := allOccurs(start, end, file)
if err != nil {
return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err)
}

scopes := make([][]*types.Scope, len(exprs))
for i, e := range exprs {
path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End())
scopes[i] = CollectScopes(info, path, e.Pos())
}

// Find the deepest common scope among all expressions.
commonScope, err := findDeepestCommonScope(scopes)
if err != nil {
return nil, nil, fmt.Errorf("extractVariable: %v", err)
}

var innerScopes []*types.Scope
for _, scope := range scopes {
for _, s := range scope {
if s != nil {
innerScopes = append(innerScopes, s)
break
}
}
}
if len(innerScopes) != len(exprs) {
return nil, nil, fmt.Errorf("extractVariable: nil scope")
}
// So the largest scope's name won't conflict.
innerScopes = append(innerScopes, commonScope)

// Create new AST node for extracted code.
var lhsNames []string
switch expr := exprs[0].(type) {
// TODO: stricter rules for selectorExpr.
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
case *ast.CallExpr:
tup, ok := info.TypeOf(expr).(*types.Tuple)
if !ok {
// If the call expression only has one return value, we can treat it the
// same as our standard extract variable case.
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
lhsNames = append(lhsNames, lhsName)
break
}
idx := 0
for i := 0; i < tup.Len(); i++ {
// Generate a unique variable for each return value.
var lhsName string
lhsName, idx = generateAvailableIdentifierForAllScopes(innerScopes, "newVar", idx)
lhsNames = append(lhsNames, lhsName)
}
default:
return nil, nil, fmt.Errorf("cannot extract %T", expr)
}

var validPath []ast.Node
if commonScope != innerScopes[0] {
// This means the first expr within function body is not the largest scope,
// we need to find the scope immediately follow the common
// scope where we will insert the statement before.
child := innerScopes[0]
for p := child; p != nil; p = p.Parent() {
if p == commonScope {
break
}
child = p
}
validPath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End())
} else {
// Just insert before the first expr.
validPath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End())
}
//
// TODO: There is a bug here: for a variable declared in a labeled
// switch/for statement it returns the for/switch statement itself
// which produces the below code which is a compiler error e.g.
// label:
// switch r1 := r() { ... break label ... }
// On extracting "r()" to a variable
// label:
// x := r()
// switch r1 := x { ... break label ... } // compiler error
//
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(validPath)
if insertBeforeStmt == nil {
return nil, nil, fmt.Errorf("cannot find location to insert extraction")
}
indent, err := calculateIndentation(src, tokFile, insertBeforeStmt)
if err != nil {
return nil, nil, err
}
newLineIndent := "\n" + indent

lhs := strings.Join(lhsNames, ", ")
assignStmt := &ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(lhs)},
Tok: token.DEFINE,
Rhs: []ast.Expr{exprs[0]},
}
var buf bytes.Buffer
if err := format.Node(&buf, fset, assignStmt); err != nil {
return nil, nil, err
}
assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
var textEdits []analysis.TextEdit
textEdits = append(textEdits, analysis.TextEdit{
Pos: insertBeforeStmt.Pos(),
End: insertBeforeStmt.Pos(),
NewText: []byte(assignment),
})
for _, e := range exprs {
textEdits = append(textEdits, analysis.TextEdit{
Pos: e.Pos(),
End: e.End(),
NewText: []byte(lhs),
})
}
return fset, &analysis.SuggestedFix{
TextEdits: textEdits,
}, nil
}

// findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
// Each scope chain represents the scopes of an expression from innermost to outermost.
// If no common scope is found, it returns an error.
func findDeepestCommonScope(scopeChains [][]*types.Scope) (*types.Scope, error) {
if len(scopeChains) == 0 {
return nil, fmt.Errorf("no scopes provided")
}
// Get the first scope chain as the reference.
referenceChain := scopeChains[0]

// Iterate from innermost to outermost scope.
for i := 0; i < len(referenceChain); i++ {
candidateScope := referenceChain[i]
if candidateScope == nil {
continue
}
isCommon := true
// See if other exprs' chains all have candidateScope as a common ancestor.
for _, chain := range scopeChains[1:] {
found := false
for j := 0; j < len(chain); j++ {
if chain[j] == candidateScope {
found = true
break
}
}
if !found {
isCommon = false
break
}
}
if isCommon {
return candidateScope, nil
}
}
return nil, fmt.Errorf("no common scope found")
}

// allOccurs finds all occurrences of an expression identical to the one
// specified by the start and end positions within the same function.
// It returns at least one ast.Expr.
func allOccurs(start, end token.Pos, file *ast.File) ([]ast.Expr, bool, error) {
if start == end {
return nil, false, fmt.Errorf("start and end are equal")
}
path, _ := astutil.PathEnclosingInterval(file, start, end)
if len(path) == 0 {
return nil, false, fmt.Errorf("no path enclosing interval")
}
for _, n := range path {
if _, ok := n.(*ast.ImportSpec); ok {
return nil, false, fmt.Errorf("cannot extract variable in an import block")
}
}
node := path[0]
if start != node.Pos() || end != node.End() {
return nil, false, fmt.Errorf("range does not map to an AST node")
}
expr, ok := node.(ast.Expr)
if !ok {
return nil, false, fmt.Errorf("node is not an expression")
}

var exprs []ast.Expr
exprs = append(exprs, expr)
if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok {
ast.Inspect(funcDecl, func(n ast.Node) bool {
if e, ok := n.(ast.Expr); ok && e != expr {
if exprIdentical(e, expr) {
exprs = append(exprs, e)
}
}
return true
})
}
sort.Slice(exprs, func(i, j int) bool {
return exprs[i].Pos() < exprs[j].Pos()
})

switch expr.(type) {
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
return exprs, true, nil
}
return nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
}

// generateAvailableIdentifierForAllScopes adjusts the new identifier name
// until there are no collisions in any of the provided scopes.
func generateAvailableIdentifierForAllScopes(scopes []*types.Scope, prefix string, idx int) (string, int) {
name := prefix
for {
collision := false
for _, scope := range scopes {
if scope.Lookup(name) != nil {
collision = true
break
}
}
if !collision {
return name, idx
}
idx++
name = fmt.Sprintf("%s%d", prefix, idx)
}
}

// exprIdentical recursively compares two ast.Expr nodes for structural equality,
// ignoring position fields.
func exprIdentical(x, y ast.Expr) bool {
if x == nil || y == nil {
return x == y
}
switch x := x.(type) {
case *ast.BasicLit:
y, ok := y.(*ast.BasicLit)
return ok && x.Kind == y.Kind && x.Value == y.Value
case *ast.CompositeLit:
y, ok := y.(*ast.CompositeLit)
if !ok || len(x.Elts) != len(y.Elts) || !exprIdentical(x.Type, y.Type) {
return false
}
for i := range x.Elts {
if !exprIdentical(x.Elts[i], y.Elts[i]) {
return false
}
}
return true
case *ast.ArrayType:
y, ok := y.(*ast.ArrayType)
return ok && exprIdentical(x.Len, y.Len) && exprIdentical(x.Elt, y.Elt)
case *ast.Ellipsis:
y, ok := y.(*ast.Ellipsis)
return ok && exprIdentical(x.Elt, y.Elt)
case *ast.FuncLit:
y, ok := y.(*ast.FuncLit)
return ok && exprIdentical(x.Type, y.Type)
case *ast.IndexExpr:
y, ok := y.(*ast.IndexExpr)
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Index, y.Index)
case *ast.IndexListExpr:
y, ok := y.(*ast.IndexListExpr)
if !ok || len(x.Indices) != len(y.Indices) || !exprIdentical(x.X, y.X) {
return false
}
for i := range x.Indices {
if !exprIdentical(x.Indices[i], y.Indices[i]) {
return false
}
}
return true
case *ast.SliceExpr:
y, ok := y.(*ast.SliceExpr)
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Low, y.Low) && exprIdentical(x.High, y.High) && exprIdentical(x.Max, y.Max) && x.Slice3 == y.Slice3
case *ast.TypeAssertExpr:
y, ok := y.(*ast.TypeAssertExpr)
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Type, y.Type)
case *ast.StarExpr:
y, ok := y.(*ast.StarExpr)
return ok && exprIdentical(x.X, y.X)
case *ast.KeyValueExpr:
y, ok := y.(*ast.KeyValueExpr)
return ok && exprIdentical(x.Key, y.Key) && exprIdentical(x.Value, y.Value)
case *ast.UnaryExpr:
y, ok := y.(*ast.UnaryExpr)
return ok && x.Op == y.Op && exprIdentical(x.X, y.X)
case *ast.MapType:
y, ok := y.(*ast.MapType)
return ok && exprIdentical(x.Value, y.Value) && exprIdentical(x.Key, y.Key)
case *ast.ChanType:
y, ok := y.(*ast.ChanType)
return ok && exprIdentical(x.Value, y.Value) && x.Dir == y.Dir
case *ast.BinaryExpr:
y, ok := y.(*ast.BinaryExpr)
return ok && x.Op == y.Op &&
exprIdentical(x.X, y.X) &&
exprIdentical(x.Y, y.Y)
case *ast.Ident:
y, ok := y.(*ast.Ident)
return ok && x.Name == y.Name
case *ast.ParenExpr:
y, ok := y.(*ast.ParenExpr)
return ok && exprIdentical(x.X, y.X)
case *ast.SelectorExpr:
y, ok := y.(*ast.SelectorExpr)
return ok &&
exprIdentical(x.X, y.X) &&
exprIdentical(x.Sel, y.Sel)
case *ast.CallExpr:
y, ok := y.(*ast.CallExpr)
if !ok || len(x.Args) != len(y.Args) {
return false
}
if !exprIdentical(x.Fun, y.Fun) {
return false
}
for i := range x.Args {
if !exprIdentical(x.Args[i], y.Args[i]) {
return false
}
}
return true
default:
// For unhandled expression types, consider them unequal.
return false
}
}

// canExtractVariable reports whether the code in the given range can be
// extracted to a variable.
func canExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
Expand Down
Loading

0 comments on commit 00f4d8a

Please sign in to comment.