diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index 3981bbae399..f8eaa2c1263 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -2303,7 +2303,10 @@ Nodes: break Nodes } case *ast.AssignStmt: - return expectedAssignStmtCandidate(c, node, inf) + objType, assignees := expectedAssignStmtTypes(c, node) + inf.objType = objType + inf.assignees = assignees + return inf case *ast.ValueSpec: inf.objType = expectedValueSpecType(c, node) return @@ -2506,12 +2509,11 @@ func inferExpectedResultTypes(c *completer, callNodeIdx int) []types.Type { expectedResults = append(expectedResults, expectedCompositeLiteralType(enclosingCompositeLiteral, c.pos)) } case *ast.AssignStmt: - inf := expectedAssignStmtCandidate(c, node, candidateInference{}) - if len(inf.assignees) > 0 { - expectedResults = make([]types.Type, len(inf.assignees)) - copy(expectedResults, inf.assignees) - } else if inf.objType != nil { - expectedResults = append(expectedResults, inf.objType) + objType, assignees := expectedAssignStmtTypes(c, node) + if len(assignees) > 0 { + return assignees + } else if objType != nil { + expectedResults = append(expectedResults, objType) } case *ast.ValueSpec: if resultType := expectedValueSpecType(c, node); resultType != nil { @@ -2576,9 +2578,8 @@ func expectedValueSpecType(c *completer, node *ast.ValueSpec) types.Type { return nil } -// expectedAssignStmtCandidate returns information about the expected candidate -// for a AssignStmt at the query position. -func expectedAssignStmtCandidate(c *completer, node *ast.AssignStmt, inf candidateInference) candidateInference { +// expectedAssignStmtTypes returns the inference objType and assignees for the assignment. +func expectedAssignStmtTypes(c *completer, node *ast.AssignStmt) (objType types.Type, assignees []types.Type) { // Only rank completions if you are on the right side of the token. if c.pos > node.TokPos { i := exprAtPos(c.pos, node.Rhs) @@ -2586,7 +2587,7 @@ func expectedAssignStmtCandidate(c *completer, node *ast.AssignStmt, inf candida i = len(node.Lhs) - 1 } if tv, ok := c.pkg.TypesInfo().Types[node.Lhs[i]]; ok { - inf.objType = tv.Type + objType = tv.Type } // If we have a single expression on the RHS, record the LHS @@ -2594,16 +2595,16 @@ func expectedAssignStmtCandidate(c *completer, node *ast.AssignStmt, inf candida // matching result values. if len(node.Rhs) <= 1 { for _, lhs := range node.Lhs { - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(lhs)) + assignees = append(assignees, c.pkg.TypesInfo().TypeOf(lhs)) } } else { // Otherwise, record our single assignee, even if its type is // not available. We use this info to downrank functions // with the wrong number of result values. - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(node.Lhs[i])) + assignees = append(assignees, c.pkg.TypesInfo().TypeOf(node.Lhs[i])) } } - return inf + return objType, assignees } // expectedReturnStmtType returns the expected type of a return statement.