diff --git a/internal/analysisinternal/analysis.go b/internal/analysisinternal/analysis.go index 098298c0ead..a194f533902 100644 --- a/internal/analysisinternal/analysis.go +++ b/internal/analysisinternal/analysis.go @@ -84,9 +84,95 @@ func TypeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Ty default: return ast.NewIdent(t.Name()) } + case *types.Pointer: + x := TypeExpr(fset, f, pkg, t.Elem()) + if x == nil { + return nil + } + return &ast.UnaryExpr{ + Op: token.MUL, + X: x, + } + case *types.Array: + elt := TypeExpr(fset, f, pkg, t.Elem()) + if elt == nil { + return nil + } + return &ast.ArrayType{ + Len: &ast.BasicLit{ + Kind: token.INT, + Value: fmt.Sprintf("%d", t.Len()), + }, + Elt: elt, + } + case *types.Slice: + elt := TypeExpr(fset, f, pkg, t.Elem()) + if elt == nil { + return nil + } + return &ast.ArrayType{ + Elt: elt, + } + case *types.Map: + key := TypeExpr(fset, f, pkg, t.Key()) + value := TypeExpr(fset, f, pkg, t.Elem()) + if key == nil || value == nil { + return nil + } + return &ast.MapType{ + Key: key, + Value: value, + } + case *types.Chan: + dir := ast.ChanDir(t.Dir()) + if t.Dir() == types.SendRecv { + dir = ast.SEND | ast.RECV + } + value := TypeExpr(fset, f, pkg, t.Elem()) + if value == nil { + return nil + } + return &ast.ChanType{ + Dir: dir, + Value: value, + } + case *types.Signature: + var params []*ast.Field + for i := 0; i < t.Params().Len(); i++ { + p := TypeExpr(fset, f, pkg, t.Params().At(i).Type()) + if p == nil { + return nil + } + params = append(params, &ast.Field{ + Type: p, + Names: []*ast.Ident{ + { + Name: t.Params().At(i).Name(), + }, + }, + }) + } + var returns []*ast.Field + for i := 0; i < t.Results().Len(); i++ { + r := TypeExpr(fset, f, pkg, t.Results().At(i).Type()) + if r == nil { + return nil + } + returns = append(returns, &ast.Field{ + Type: r, + }) + } + return &ast.FuncType{ + Params: &ast.FieldList{ + List: params, + }, + Results: &ast.FieldList{ + List: returns, + }, + } case *types.Named: if t.Obj().Pkg() == nil { - return nil + return ast.NewIdent(t.Obj().Name()) } if t.Obj().Pkg() == pkg { return ast.NewIdent(t.Obj().Name()) @@ -109,11 +195,6 @@ func TypeExpr(fset *token.FileSet, f *ast.File, pkg *types.Package, typ types.Ty X: ast.NewIdent(pkgName), Sel: ast.NewIdent(t.Obj().Name()), } - case *types.Pointer: - return &ast.UnaryExpr{ - Op: token.MUL, - X: TypeExpr(fset, f, pkg, t.Elem()), - } default: return nil // TODO: anonymous structs, but who does that } diff --git a/internal/lsp/cmd/test/cmdtest.go b/internal/lsp/cmd/test/cmdtest.go index 01f3951810f..cb911cb42fd 100644 --- a/internal/lsp/cmd/test/cmdtest.go +++ b/internal/lsp/cmd/test/cmdtest.go @@ -133,6 +133,10 @@ func (r *runner) RankCompletion(t *testing.T, src span.Span, test tests.Completi //TODO: add command line completions tests when it works } +func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span) { + //TODO: function extraction not supported on command line +} + func (r *runner) runGoplsCmd(t testing.TB, args ...string) (string, string) { rStdout, wStdout, err := os.Pipe() if err != nil { diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 618b0f900bc..cdeb263ece2 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -361,22 +361,34 @@ func extractionFixes(ctx context.Context, snapshot source.Snapshot, ph source.Pa if err != nil { return nil, nil } + var actions []protocol.CodeAction edits, err := source.ExtractVariable(ctx, snapshot, fh, rng) if err != nil { return nil, err } - if len(edits) == 0 { - return nil, nil - } - return []protocol.CodeAction{ - { + if len(edits) > 0 { + actions = append(actions, protocol.CodeAction{ Title: "Extract to variable", Kind: protocol.RefactorExtract, Edit: protocol.WorkspaceEdit{ DocumentChanges: documentChanges(fh, edits), }, - }, - }, nil + }) + } + edits, err = source.ExtractFunction(ctx, snapshot, fh, rng) + if err != nil { + return nil, err + } + if len(edits) > 0 { + actions = append(actions, protocol.CodeAction{ + Title: "Extract to function", + Kind: protocol.RefactorExtract, + Edit: protocol.WorkspaceEdit{ + DocumentChanges: documentChanges(fh, edits), + }, + }) + } + return actions, nil } func documentChanges(fh source.FileHandle, edits []protocol.TextEdit) []protocol.TextDocumentEdit { diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 620c936ab82..535bed36164 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -433,6 +433,52 @@ func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string) } } +func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span) { + uri := start.URI() + _, err := r.server.session.ViewOf(uri) + if err != nil { + t.Fatal(err) + } + m, err := r.data.Mapper(uri) + if err != nil { + t.Fatal(err) + } + spn := span.New(start.URI(), start.Start(), end.End()) + rng, err := m.Range(spn) + if err != nil { + t.Fatal(err) + } + actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.URIFromSpanURI(uri), + }, + Range: rng, + Context: protocol.CodeActionContext{ + Only: []protocol.CodeActionKind{"refactor.extract"}, + }, + }) + if err != nil { + t.Fatal(err) + } + // Hack: We assume that we only get one code action per range. + // TODO(rstambler): Support multiple code actions per test. + if len(actions) == 0 || len(actions) > 1 { + t.Fatalf("unexpected number of code actions, want 1, got %v", len(actions)) + } + res, err := applyWorkspaceEdits(r, actions[0].Edit) + if err != nil { + t.Fatal(err) + } + for u, got := range res { + want := string(r.data.Golden("functionextraction_"+tests.SpanName(spn), u.Filename(), func() ([]byte, error) { + return []byte(got), nil + })) + if want != got { + t.Errorf("function extraction failed for %s:\n%s", u.Filename(), tests.Diff(want, got)) + } + } +} + func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { sm, err := r.data.Mapper(d.Src.URI()) if err != nil { diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go index fb53e61727e..30bb4bd2648 100644 --- a/internal/lsp/source/extract.go +++ b/internal/lsp/source/extract.go @@ -10,8 +10,11 @@ import ( "fmt" "go/ast" "go/format" + "go/parser" "go/token" "go/types" + "strings" + "unicode" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/internal/analysisinternal" @@ -36,6 +39,9 @@ func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, prot if err != nil { return nil, err } + if rng.Start == rng.End { + return nil, nil + } path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) if len(path) == 0 { return nil, nil @@ -53,25 +59,17 @@ func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, prot if rng.Start != node.Pos() || rng.End != node.End() { return nil, nil } - - // Adjust new variable name until no collisons in scope. - scopes := collectScopes(pkg, path, node.Pos()) - name := "x0" - idx := 0 - for !isValidName(name, scopes) { - idx++ - name = fmt.Sprintf("x%d", idx) - } + name := generateAvailableIdentifier(node.Pos(), pkg, path, file) var assignment string expr, ok := node.(ast.Expr) if !ok { return nil, nil } - // Create new AST node for extracted code + // Create new AST node for extracted code. switch expr.(type) { case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, - *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr + *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr. assignStmt := &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent(name)}, Tok: token.DEFINE, @@ -93,7 +91,7 @@ func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, prot return nil, nil } - // Convert token.Pos to protcol.Position + // Convert token.Pos to protocol.Position. rng = span.NewRange(fset, insertBeforeStmt.Pos(), insertBeforeStmt.End()) spn, err = rng.Span() if err != nil { @@ -107,17 +105,12 @@ func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, prot Start: beforeStmtStart, End: beforeStmtStart, } - - // Calculate indentation for insertion - line := tok.Line(insertBeforeStmt.Pos()) - lineOffset := tok.Offset(tok.LineStart(line)) - stmtOffset := tok.Offset(insertBeforeStmt.Pos()) - indent := content[lineOffset:stmtOffset] // space between these is indentation. + indent := calculateIndentation(content, tok, insertBeforeStmt) return []protocol.TextEdit{ { Range: stmtBeforeRng, - NewText: assignment + "\n" + string(indent), + NewText: assignment + "\n" + indent, }, { Range: protoRng, @@ -126,6 +119,17 @@ func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, prot }, nil } +// Calculate indentation for insertion. +// When inserting lines of code, we must ensure that the lines have consistent +// formatting (i.e. the proper indentation). To do so, we observe the indentation on the +// line of code on which the insertion occurs. +func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string { + line := tok.Line(insertBeforeStmt.Pos()) + lineOffset := tok.Offset(tok.LineStart(line)) + stmtOffset := tok.Offset(insertBeforeStmt.Pos()) + return string(content[lineOffset:stmtOffset]) +} + // Check for variable collision in scope. func isValidName(name string, scopes []*types.Scope) bool { for _, scope := range scopes { @@ -138,3 +142,503 @@ func isValidName(name string, scopes []*types.Scope) bool { } return true } + +// ExtractFunction refactors the selected block of code into a new function. It also +// replaces the selected block of code with a call to the extracted function. First, we +// manually adjust the selection range. We remove trailing and leading whitespace +// characters to ensure the range is precisely bounded by AST nodes. Next, we +// determine the variables that will be the paramters and return values of the +// extracted function. Lastly, we construct the call of the function and insert +// this call as well as the extracted function into their proper locations. +func ExtractFunction(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) { + pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle) + if err != nil { + return nil, fmt.Errorf("ExtractFunction: %v", err) + } + file, _, m, _, err := pgh.Cached() + if err != nil { + return nil, err + } + spn, err := m.RangeSpan(protoRng) + if err != nil { + return nil, err + } + rng, err := spn.Range(m.Converter) + if err != nil { + return nil, err + } + if rng.Start == rng.End { + return nil, nil + } + content, err := fh.Read() + if err != nil { + return nil, err + } + fset := snapshot.View().Session().Cache().FileSet() + tok := fset.File(file.Pos()) + if tok == nil { + return nil, fmt.Errorf("ExtractFunction: no token.File for %s", fh.URI()) + } + rng = adjustRangeForWhitespace(content, tok, rng) + path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) + if len(path) == 0 { + return nil, nil + } + // Node that encloses selection must be a statement. + // TODO: Support function extraction for an expression. + if _, ok := path[0].(ast.Stmt); !ok { + return nil, nil + } + info := pkg.GetTypesInfo() + if info == nil { + return nil, fmt.Errorf("nil TypesInfo") + } + fileScope := info.Scopes[file] + if fileScope == nil { + return nil, nil + } + pkgScope := fileScope.Parent() + if pkgScope == nil { + return nil, nil + } + // Find function enclosing the selection. + var outer *ast.FuncDecl + for _, p := range path { + if p, ok := p.(*ast.FuncDecl); ok { + outer = p + break + } + } + if outer == nil { + return nil, nil + } + // At the moment, we don't extract selections containing return statements, + // as they are more complex and need to be adjusted to maintain correctness. + // TODO: Support extracting and rewriting code with return statements. + var containsReturn bool + ast.Inspect(outer, func(n ast.Node) bool { + if n == nil { + return true + } + if rng.Start <= n.Pos() && n.End() <= rng.End { + if _, ok := n.(*ast.ReturnStmt); ok { + containsReturn = true + return false + } + } + return n.Pos() <= rng.End + }) + if containsReturn { + return nil, nil + } + // Find the nodes at the start and end of the selection. + var start, end ast.Node + ast.Inspect(outer, func(n ast.Node) bool { + if n == nil { + return true + } + if n.Pos() == rng.Start && n.End() <= rng.End { + start = n + } + if n.End() == rng.End && n.Pos() >= rng.Start { + end = n + } + return n.Pos() <= rng.End + }) + if start == nil || end == nil { + return nil, nil + } + + // Now that we have determined the correct range for the selection block, + // we must determine the signature of the extracted function. We will then replace + // the block with an assignment statement that calls the extracted function with + // the appropriate parameters and return values. + free, vars, assigned := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0]) + + var ( + params, returns []ast.Expr // used when calling the extracted function + paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function + uninitialized []types.Object // vars we will need to initialize before the call + ) + + // Avoid duplicates while traversing vars and uninitialzed. + seenVars := make(map[types.Object]ast.Expr) + seenUninitialized := make(map[types.Object]struct{}) + + // Each identifier in the selected block must become (1) a parameter to the + // extracted function, (2) a return value of the extracted function, or (3) a local + // variable in the extracted function. Determine the outcome(s) for each variable + // based on whether it is free, altered within the selected block, and used outside + // of the selected block. + for _, obj := range vars { + if _, ok := seenVars[obj]; ok { + continue + } + typ := analysisinternal.TypeExpr(fset, file, pkg.GetTypes(), obj.Type()) + if typ == nil { + return nil, fmt.Errorf("nil AST expression for type: %v", obj.Name()) + } + seenVars[obj] = typ + identifier := ast.NewIdent(obj.Name()) + // An identifier must meet two conditions to become a return value of the + // extracted function. (1) it must be used at least once after the + // selection (isUsed), and (2) its value must be initialized or reassigned + // within the selection (isAssigned). + isUsed := objUsed(obj, info, rng.End, obj.Parent().End()) + _, isAssigned := assigned[obj] + _, isFree := free[obj] + if isUsed && isAssigned { + returnTypes = append(returnTypes, &ast.Field{Type: typ}) + returns = append(returns, identifier) + if !isFree { + uninitialized = append(uninitialized, obj) + } + } + // All free variables are parameters of and passed as arguments to the + // extracted function. + if isFree { + params = append(params, identifier) + paramTypes = append(paramTypes, &ast.Field{ + Names: []*ast.Ident{identifier}, + Type: typ, + }) + } + } + + // Our preference is to replace the selected block with an "x, y, z := fn()" style + // assignment statement. We can use this style when none of the variables in the + // extracted function's return statement have already be initialized outside of the + // selected block. However, for example, if z is already defined elsewhere, we + // replace the selected block with: + // + // var x int + // var y string + // x, y, z = fn() + // + var initializations string + if len(uninitialized) > 0 && len(uninitialized) != len(returns) { + var declarations []ast.Stmt + for _, obj := range uninitialized { + if _, ok := seenUninitialized[obj]; ok { + continue + } + seenUninitialized[obj] = struct{}{} + valSpec := &ast.ValueSpec{ + Names: []*ast.Ident{ast.NewIdent(obj.Name())}, + Type: seenVars[obj], + } + genDecl := &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{valSpec}, + } + declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) + } + var declBuf bytes.Buffer + if err = format.Node(&declBuf, fset, declarations); err != nil { + return nil, err + } + indent := calculateIndentation(content, tok, start) + // Add proper indentation to each declaration. Also add formatting to + // the line following the last initialization to ensure that subsequent + // edits begin at the proper location. + initializations = strings.ReplaceAll(declBuf.String(), "\n", "\n"+indent) + + "\n" + indent + } + + name := generateAvailableIdentifier(start.Pos(), pkg, path, file) + var replace ast.Node + if len(returns) > 0 { + // If none of the variables on the left-hand side of the function call have + // been initialized before the selection, we can use := instead of =. + assignTok := token.ASSIGN + if len(uninitialized) == len(returns) { + assignTok = token.DEFINE + } + callExpr := &ast.CallExpr{ + Fun: ast.NewIdent(name), + Args: params, + } + replace = &ast.AssignStmt{ + Lhs: returns, + Tok: assignTok, + Rhs: []ast.Expr{callExpr}, + } + } else { + replace = &ast.CallExpr{ + Fun: ast.NewIdent(name), + Args: params, + } + } + + startOffset := tok.Offset(rng.Start) + endOffset := tok.Offset(rng.End) + selection := content[startOffset:endOffset] + // Put selection in constructed file to parse and produce block statement. We can + // then use the block statement to traverse and edit extracted function without + // altering the original file. + text := "package main\nfunc _() { " + string(selection) + " }" + extract, err := parser.ParseFile(fset, "", text, 0) + if err != nil { + return nil, err + } + if len(extract.Decls) == 0 { + return nil, fmt.Errorf("parsed file does not contain any declarations") + } + decl, ok := extract.Decls[0].(*ast.FuncDecl) + if !ok { + return nil, fmt.Errorf("parsed file does not contain expected function declaration") + } + // Add return statement to the end of the new function. + if len(returns) > 0 { + decl.Body.List = append(decl.Body.List, + &ast.ReturnStmt{Results: returns}, + ) + } + funcDecl := &ast.FuncDecl{ + Name: ast.NewIdent(name), + Type: &ast.FuncType{ + Params: &ast.FieldList{List: paramTypes}, + Results: &ast.FieldList{List: returnTypes}, + }, + Body: decl.Body, + } + + var replaceBuf, newFuncBuf bytes.Buffer + if err := format.Node(&replaceBuf, fset, replace); err != nil { + return nil, err + } + if err := format.Node(&newFuncBuf, fset, funcDecl); err != nil { + return nil, err + } + + outerStart := tok.Offset(outer.Pos()) + outerEnd := tok.Offset(outer.End()) + // We're going to replace the whole enclosing function, + // so preserve the text before and after the selected block. + before := content[outerStart:startOffset] + after := content[endOffset:outerEnd] + var fullReplacement strings.Builder + fullReplacement.Write(before) + fullReplacement.WriteString(initializations) // add any initializations, if needed + fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function + fullReplacement.Write(after) + fullReplacement.WriteString("\n\n") // add newlines after the enclosing function + fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function + + // Convert enclosing function's span.Range to protocol.Range. + rng = span.NewRange(fset, outer.Pos(), outer.End()) + spn, err = rng.Span() + if err != nil { + return nil, nil + } + startFunc, err := m.Position(spn.Start()) + if err != nil { + return nil, nil + } + endFunc, err := m.Position(spn.End()) + if err != nil { + return nil, nil + } + funcLoc := protocol.Range{ + Start: startFunc, + End: endFunc, + } + return []protocol.TextEdit{ + { + Range: funcLoc, + NewText: fullReplacement.String(), + }, + }, nil +} + +// collectFreeVars maps each identifier in the given range to whether it is "free." +// Given a range, a variable in that range is defined as "free" if it is declared +// outside of the range and neither at the file scope nor package scope. These free +// variables will be used as arguments in the extracted function. It also returns a +// list of identifiers that may need to be returned by the extracted function. +// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. +func collectFreeVars(info *types.Info, file *ast.File, fileScope *types.Scope, + pkgScope *types.Scope, rng span.Range, node ast.Node) (map[types.Object]struct{}, []types.Object, map[types.Object]struct{}) { + // id returns non-nil if n denotes an object that is referenced by the span + // and defined either within the span or in the lexical environment. The bool + // return value acts as an indicator for where it was defined. + id := func(n *ast.Ident) (types.Object, bool) { + obj := info.Uses[n] + if obj == nil { + return info.Defs[n], false + } + if _, ok := obj.(*types.PkgName); ok { + return nil, false // imported package + } + if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) { + return nil, false // not defined in this file + } + scope := obj.Parent() + if scope == nil { + return nil, false // e.g. interface method, struct field + } + if scope == fileScope || scope == pkgScope { + return nil, false // defined at file or package scope + } + if rng.Start <= obj.Pos() && obj.Pos() <= rng.End { + return obj, false // defined within selection => not free + } + return obj, true + } + // sel returns non-nil if n denotes a selection o.x.y that is referenced by the + // span and defined either within the span or in the lexical environment. The bool + // return value acts as an indicator for where it was defined. + var sel func(n *ast.SelectorExpr) (types.Object, bool) + sel = func(n *ast.SelectorExpr) (types.Object, bool) { + switch x := astutil.Unparen(n.X).(type) { + case *ast.SelectorExpr: + return sel(x) + case *ast.Ident: + return id(x) + } + return nil, false + } + free := make(map[types.Object]struct{}) + var vars []types.Object + ast.Inspect(node, func(n ast.Node) bool { + if n == nil { + return true + } + if rng.Start <= n.Pos() && n.End() <= rng.End { + var obj types.Object + var isFree, prune bool + switch n := n.(type) { + case *ast.Ident: + obj, isFree = id(n) + case *ast.SelectorExpr: + obj, isFree = sel(n) + prune = true + } + if obj != nil && obj.Name() != "_" { + if isFree { + free[obj] = struct{}{} + } + vars = append(vars, obj) + if prune { + return false + } + } + } + return n.Pos() <= rng.End + }) + + // Find identifiers that are initialized or whose values are altered at some + // point in the selected block. For example, in a selected block from lines 2-4, + // variables x, y, and z are included in assigned. However, in a selected block + // from lines 3-4, only variables y and z are included in assigned. + // + // 1: var a int + // 2: var x int + // 3: y := 3 + // 4: z := x + a + // + assigned := make(map[types.Object]struct{}) + ast.Inspect(node, func(n ast.Node) bool { + if n == nil { + return true + } + if n.Pos() < rng.Start || n.End() > rng.End { + return n.Pos() <= rng.End + } + switch n := n.(type) { + case *ast.AssignStmt: + for _, assignment := range n.Lhs { + if assignment, ok := assignment.(*ast.Ident); ok { + obj, _ := id(assignment) + if obj == nil { + continue + } + assigned[obj] = struct{}{} + } + } + return false + case *ast.DeclStmt: + gen, ok := n.Decl.(*ast.GenDecl) + if !ok { + return true + } + for _, spec := range gen.Specs { + vSpecs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for _, vSpec := range vSpecs.Names { + obj, _ := id(vSpec) + if obj == nil { + continue + } + assigned[obj] = struct{}{} + } + } + return false + } + return true + }) + return free, vars, assigned +} + +// Adjust new function name until no collisons in scope. Possible collisions include +// other function and variable names. +func generateAvailableIdentifier(pos token.Pos, pkg Package, path []ast.Node, file *ast.File) string { + scopes := collectScopes(pkg, path, pos) + var idx int + name := "x0" + for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) { + idx++ + name = fmt.Sprintf("x%d", idx) + } + return name +} + +// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or +// trailing whitespace characters from selection. In the following example, each line +// of the if statement is indented once. There are also two extra spaces after the +// closing bracket before the line break. +// +// \tif (true) { +// \t _ = 1 +// \t} \n +// +// By default, a valid range begins at 'if' and ends at the first whitespace character +// after the '}'. But, users are likely to highlight full lines rather than adjusting +// their cursors for whitespace. To support this use case, we must manually adjust the +// ranges to match the correct AST node. In this particular example, we would adjust +// rng.Start forward by one byte, and rng.End backwards by two bytes. +func adjustRangeForWhitespace(content []byte, tok *token.File, rng span.Range) span.Range { + offset := tok.Offset(rng.Start) + for offset < len(content) { + if !unicode.IsSpace(rune(content[offset])) { + break + } + // Move forwards one byte to find a non-whitespace character. + offset += 1 + } + rng.Start = tok.Pos(offset) + + offset = tok.Offset(rng.End) + for offset-1 >= 0 { + if !unicode.IsSpace(rune(content[offset-1])) { + break + } + // Move backwards one byte to find a non-whitespace character. + offset -= 1 + } + rng.End = tok.Pos(offset) + return rng +} + +// objUsed checks if the object is used after the selection but within +// the scope of the enclosing function. +func objUsed(obj types.Object, info *types.Info, endSel token.Pos, endScope token.Pos) bool { + for id, ob := range info.Uses { + if obj == ob && endSel < id.Pos() && id.End() <= endScope { + return true + } + } + return false +} diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index d6af8a12c1f..9a414e5903a 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -476,6 +476,8 @@ func (r *runner) Import(t *testing.T, spn span.Span) { func (r *runner) SuggestedFix(t *testing.T, spn span.Span, actionKinds []string) {} +func (r *runner) FunctionExtraction(t *testing.T, start span.Span, end span.Span) {} + func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { _, srcRng, err := spanToRange(r.data, d.Src) if err != nil { diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go new file mode 100644 index 00000000000..fc46f96883f --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go @@ -0,0 +1,10 @@ +package extract + +func _() { + a := 1 + a = 5 //@mark(s1, "a") + a = a + 2 //@mark(e1, "2") + //@extractfunc(s1, e1) + b := a * 2 + var _ = 3 + 4 +} diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden new file mode 100644 index 00000000000..73afc2ddead --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_args_returns.go.golden @@ -0,0 +1,17 @@ +-- functionextraction_extract_args_returns_5_2 -- +package extract + +func _() { + a := 1 + a = x0(a) //@mark(e1, "2") + //@extractfunc(s1, e1) + b := a * 2 + var _ = 3 + 4 +} + +func x0(a int) int { + a = 5 + a = a + 2 + return a +} + diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go new file mode 100644 index 00000000000..32cbcf10c39 --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go @@ -0,0 +1,7 @@ +package extract + +func _() { + a := 1 //@mark(s0, "a") + var _ = 3 + 4 //@mark(e0, "4") + //@extractfunc(s0, e0) +} diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go.golden new file mode 100644 index 00000000000..5a8fb438546 --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_basic.go.golden @@ -0,0 +1,13 @@ +-- functionextraction_extract_basic_4_2 -- +package extract + +func _() { + x0() //@mark(e0, "4") + //@extractfunc(s0, e0) +} + +func x0() { + a := 1 + var _ = 3 + 4 +} + diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go new file mode 100644 index 00000000000..ee264ad5e50 --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go @@ -0,0 +1,10 @@ +package extract + +func _() { + x0 := 1 + a := x0 //@extractfunc("a", "x0") +} + +func x1() int { + return 1 +} diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go.golden new file mode 100644 index 00000000000..1b3a5c3cf16 --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_scope.go.golden @@ -0,0 +1,16 @@ +-- functionextraction_extract_scope_5_2 -- +package extract + +func _() { + x0 := 1 + x2(x0) //@extractfunc("a", "x0") +} + +func x2(x0 int) { + a := x0 +} + +func x1() int { + return 1 +} + diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go new file mode 100644 index 00000000000..1e33e13c8a6 --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go @@ -0,0 +1,9 @@ +package extract + +func _() { + var a []int + a = append(a, 2) //@mark(s4, "a") + b := 4 //@mark(e4, "4") + //@extractfunc(s4, e4) + a = append(a, b) +} diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go.golden new file mode 100644 index 00000000000..6392ceeedbb --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_initialization.go.golden @@ -0,0 +1,17 @@ +-- functionextraction_extract_smart_initialization_5_2 -- +package extract + +func _() { + var a []int + var b int + a, b = x0(a) //@mark(e4, "4") + //@extractfunc(s4, e4) + a = append(a, b) +} + +func x0(a []int) ([]int, int) { + a = append(a, 2) + b := 4 + return a, b +} + diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go new file mode 100644 index 00000000000..5f0d28f7ffa --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go @@ -0,0 +1,11 @@ +package extract + +func _() { + var b []int + var a int + a = 2 //@mark(s2, "a") + b = []int{} + b = append(b, a) //@mark(e2, ")") + b[0] = 1 + //@extractfunc(s2, e2) +} diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go.golden new file mode 100644 index 00000000000..e94d4f29c8e --- /dev/null +++ b/internal/lsp/testdata/lsp/primarymod/extract/extract_function/extract_smart_return.go.golden @@ -0,0 +1,18 @@ +-- functionextraction_extract_smart_return_6_2 -- +package extract + +func _() { + var b []int + var a int + b = x0(a, b) //@mark(e2, ")") + b[0] = 1 + //@extractfunc(s2, e2) +} + +func x0(a int, b []int) []int { + a = 2 + b = []int{} + b = append(b, a) + return b +} + diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_basic_lit.go similarity index 100% rename from internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go rename to internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_basic_lit.go diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_basic_lit.go.golden similarity index 100% rename from internal/lsp/testdata/lsp/primarymod/extract/extract_basic_lit.go.golden rename to internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_basic_lit.go.golden diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go b/internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_scope.go similarity index 100% rename from internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go rename to internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_scope.go diff --git a/internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go.golden b/internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_scope.go.golden similarity index 100% rename from internal/lsp/testdata/lsp/primarymod/extract/extract_scope.go.golden rename to internal/lsp/testdata/lsp/primarymod/extract/extract_variable/extract_scope.go.golden diff --git a/internal/lsp/testdata/lsp/primarymod/fillstruct/fill_struct_package.go.golden b/internal/lsp/testdata/lsp/primarymod/fillstruct/fill_struct_package.go.golden index 852e48aff9c..a72afbded3d 100644 --- a/internal/lsp/testdata/lsp/primarymod/fillstruct/fill_struct_package.go.golden +++ b/internal/lsp/testdata/lsp/primarymod/fillstruct/fill_struct_package.go.golden @@ -27,8 +27,10 @@ func unexported() { a := data.A{} //@suggestedfix("}", "refactor.rewrite") _ = h2.Client{ Transport: nil, - Jar: nil, - Timeout: 0, + CheckRedirect: func(req *h2.Request, via []*h2.Request) error { + }, + Jar: nil, + Timeout: 0, } //@suggestedfix("}", "refactor.rewrite") } diff --git a/internal/lsp/testdata/lsp/summary.txt.golden b/internal/lsp/testdata/lsp/summary.txt.golden index 31f7f4311b2..8fdb792bd8f 100644 --- a/internal/lsp/testdata/lsp/summary.txt.golden +++ b/internal/lsp/testdata/lsp/summary.txt.golden @@ -12,6 +12,7 @@ FoldingRangesCount = 2 FormatCount = 6 ImportCount = 8 SuggestedFixCount = 18 +FunctionExtractionCount = 5 DefinitionsCount = 53 TypeDefinitionsCount = 2 HighlightsCount = 69 diff --git a/internal/lsp/testdata/missingdep/summary.txt.golden b/internal/lsp/testdata/missingdep/summary.txt.golden index 0c7b9bf3b3d..d8443720cce 100644 --- a/internal/lsp/testdata/missingdep/summary.txt.golden +++ b/internal/lsp/testdata/missingdep/summary.txt.golden @@ -12,6 +12,7 @@ FoldingRangesCount = 0 FormatCount = 0 ImportCount = 0 SuggestedFixCount = 1 +FunctionExtractionCount = 0 DefinitionsCount = 0 TypeDefinitionsCount = 0 HighlightsCount = 0 diff --git a/internal/lsp/testdata/missingtwodep/summary.txt.golden b/internal/lsp/testdata/missingtwodep/summary.txt.golden index 96ac4750a8a..ce246c03408 100644 --- a/internal/lsp/testdata/missingtwodep/summary.txt.golden +++ b/internal/lsp/testdata/missingtwodep/summary.txt.golden @@ -12,6 +12,7 @@ FoldingRangesCount = 0 FormatCount = 0 ImportCount = 0 SuggestedFixCount = 3 +FunctionExtractionCount = 0 DefinitionsCount = 0 TypeDefinitionsCount = 0 HighlightsCount = 0 diff --git a/internal/lsp/testdata/unused/summary.txt.golden b/internal/lsp/testdata/unused/summary.txt.golden index 5c4f74a660b..3f09a08ae20 100644 --- a/internal/lsp/testdata/unused/summary.txt.golden +++ b/internal/lsp/testdata/unused/summary.txt.golden @@ -12,6 +12,7 @@ FoldingRangesCount = 0 FormatCount = 0 ImportCount = 0 SuggestedFixCount = 1 +FunctionExtractionCount = 0 DefinitionsCount = 0 TypeDefinitionsCount = 0 HighlightsCount = 0 diff --git a/internal/lsp/testdata/upgradedep/summary.txt.golden b/internal/lsp/testdata/upgradedep/summary.txt.golden index 79042cc6c8e..2719246aaca 100644 --- a/internal/lsp/testdata/upgradedep/summary.txt.golden +++ b/internal/lsp/testdata/upgradedep/summary.txt.golden @@ -12,6 +12,7 @@ FoldingRangesCount = 0 FormatCount = 0 ImportCount = 0 SuggestedFixCount = 0 +FunctionExtractionCount = 0 DefinitionsCount = 0 TypeDefinitionsCount = 0 HighlightsCount = 0 diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 88d07a38f2a..6fc9b5ee22e 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -57,6 +57,7 @@ type FoldingRanges []span.Span type Formats []span.Span type Imports []span.Span type SuggestedFixes map[span.Span][]string +type FunctionExtractions map[span.Span]span.Span type Definitions map[span.Span]Definition type Implementations map[span.Span][]span.Span type Highlights map[span.Span][]span.Span @@ -87,6 +88,7 @@ type Data struct { Formats Formats Imports Imports SuggestedFixes SuggestedFixes + FunctionExtractions FunctionExtractions Definitions Definitions Implementations Implementations Highlights Highlights @@ -128,6 +130,7 @@ type Tests interface { Format(*testing.T, span.Span) Import(*testing.T, span.Span) SuggestedFix(*testing.T, span.Span, []string) + FunctionExtraction(*testing.T, span.Span, span.Span) Definition(*testing.T, span.Span, Definition) Implementation(*testing.T, span.Span, []span.Span) Highlight(*testing.T, span.Span, []span.Span) @@ -288,6 +291,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) []*Data { Renames: make(Renames), PrepareRenames: make(PrepareRenames), SuggestedFixes: make(SuggestedFixes), + FunctionExtractions: make(FunctionExtractions), Symbols: make(Symbols), symbolsChildren: make(SymbolsChildren), symbolInformation: make(SymbolInformation), @@ -420,6 +424,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) []*Data { "signature": datum.collectSignatures, "link": datum.collectLinks, "suggestedfix": datum.collectSuggestedFixes, + "extractfunc": datum.collectFunctionExtractions, }); err != nil { t.Fatal(err) } @@ -611,6 +616,20 @@ func Run(t *testing.T, tests Tests, data *Data) { } }) + t.Run("FunctionExtraction", func(t *testing.T) { + t.Helper() + for start, end := range data.FunctionExtractions { + // Check if we should skip this spn if the -modfile flag is not available. + if shouldSkip(data, start.URI()) { + continue + } + t.Run(SpanName(start), func(t *testing.T) { + t.Helper() + tests.FunctionExtraction(t, start, end) + }) + } + }) + t.Run("Definition", func(t *testing.T) { t.Helper() for spn, d := range data.Definitions { @@ -801,6 +820,7 @@ func checkData(t *testing.T, data *Data) { fmt.Fprintf(buf, "FormatCount = %v\n", len(data.Formats)) fmt.Fprintf(buf, "ImportCount = %v\n", len(data.Imports)) fmt.Fprintf(buf, "SuggestedFixCount = %v\n", len(data.SuggestedFixes)) + fmt.Fprintf(buf, "FunctionExtractionCount = %v\n", len(data.FunctionExtractions)) fmt.Fprintf(buf, "DefinitionsCount = %v\n", definitionCount) fmt.Fprintf(buf, "TypeDefinitionsCount = %v\n", typeDefinitionCount) fmt.Fprintf(buf, "HighlightsCount = %v\n", len(data.Highlights)) @@ -1023,6 +1043,12 @@ func (data *Data) collectSuggestedFixes(spn span.Span, actionKind string) { data.SuggestedFixes[spn] = append(data.SuggestedFixes[spn], actionKind) } +func (data *Data) collectFunctionExtractions(start span.Span, end span.Span) { + if _, ok := data.FunctionExtractions[start]; !ok { + data.FunctionExtractions[start] = end + } +} + func (data *Data) collectDefinitions(src, target span.Span) { data.Definitions[src] = Definition{ Src: src,