Skip to content

Commit

Permalink
jacob/inference: patchset 4 cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobzim-stl committed Nov 21, 2024
1 parent e0b62f9 commit 551bd57
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 34 deletions.
Binary file added gopls/gopls
Binary file not shown.
71 changes: 37 additions & 34 deletions gopls/internal/golang/completion/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -2065,8 +2065,7 @@ func enclosingFunction(path []ast.Node, info *types.Info) *funcInfo {
return nil
}

func (c *completer) expectedCompositeLiteralType() types.Type {
clInfo := c.enclosingCompositeLiteral
func expectedCompositeLiteralType(clInfo *compLitInfo, pos token.Pos) types.Type {
switch t := clInfo.clType.(type) {
case *types.Slice:
if clInfo.inKey {
Expand Down Expand Up @@ -2105,7 +2104,7 @@ func (c *completer) expectedCompositeLiteralType() types.Type {

// The order of the literal fields must match the order in the struct definition.
// Find the element that the position belongs to and suggest that field's type.
if i := exprAtPos(c.pos, clInfo.cl.Elts); i < t.NumFields() {
if i := exprAtPos(pos, clInfo.cl.Elts); i < t.NumFields() {
return t.Field(i).Type()
}
}
Expand Down Expand Up @@ -2280,7 +2279,7 @@ func expectedCandidate(ctx context.Context, c *completer) (inf candidateInferenc
inf.typeName = expectTypeName(c)

if c.enclosingCompositeLiteral != nil {
inf.objType = c.expectedCompositeLiteralType()
inf.objType = expectedCompositeLiteralType(c.enclosingCompositeLiteral, c.pos)
}

Nodes:
Expand Down Expand Up @@ -2314,7 +2313,10 @@ Nodes:
}
return inf
case *ast.SendStmt:
return expectedSendStmtCandidate(c, node, inf)
if typ := expectedSendStmtType(c, node); typ != nil {
inf.objType = typ
}
return inf
case *ast.CallExpr:
// Only consider CallExpr args if position falls between parens.
if node.Lparen < c.pos && c.pos <= node.Rparen {
Expand All @@ -2332,7 +2334,7 @@ Nodes:
}

targs := c.getTypeArgs(node)
res := inferExpectedCallExprResults(c, i)
res := inferExpectedResultTypes(c, i)
substs := reverseInferTypeArgs(sig, targs, res)
inst := instantiate(sig, substs)
if inst != nil {
Expand Down Expand Up @@ -2416,7 +2418,9 @@ Nodes:
inf.objType = ct
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
inf = c.inferExpectedTypeArgs(i+1, 0, inf)
if typ := c.inferExpectedTypeArg(i+1, 0); typ != nil {
inf.objType = typ
}
}
}
}
Expand All @@ -2429,7 +2433,9 @@ Nodes:
inf.objType = ct
inf.typeName.wantTypeName = true
inf.typeName.isTypeParam = true
inf = c.inferExpectedTypeArgs(i+1, typeParamIdx, inf)
if typ := c.inferExpectedTypeArg(i+1, typeParamIdx); typ != nil {
inf.objType = typ
}
}
}
}
Expand Down Expand Up @@ -2468,7 +2474,7 @@ Nodes:
return inf
}

// inferExpectedCallExprResults takes the index of a call expression within the completion
// inferExpectedResultTypes takes the index of a call expression within the completion
// path and uses its surroundings to infer the expected result tuple of the call's signature.
// Returns the signature result tuple as a slice, or nil if reverse type inference fails.
//
Expand All @@ -2480,30 +2486,27 @@ Nodes:
// var y TypeB
// x, y := generic(<cursor>, <cursor>)
//
// inferExpectedCallExprResults can determine that the expected result type of the function is (TypeA, TypeB)
func inferExpectedCallExprResults(c *completer, callNodeIdx int) []types.Type {
// inferExpectedResultTypes can determine that the expected result type of the function is (TypeA, TypeB)
func inferExpectedResultTypes(c *completer, callNodeIdx int) []types.Type {
callNode, _ := c.path[callNodeIdx].(*ast.CallExpr)

if len(c.path) <= callNodeIdx+1 {
return nil
}

var expectedResults []types.Type
var inf candidateInference

// Check the parents of the call node to extract the expected result types of the call signature.
// Currently reverse inferences are only supported with the the following parent expressions,
// however this list isn't exhaustive.
switch node := c.path[callNodeIdx+1].(type) {
case *ast.KeyValueExpr:
c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[callNodeIdx:], callNode.Pos(), c.pkg.TypesInfo())
if !wantStructFieldCompletions(c.enclosingCompositeLiteral) {
expectedResults = append(expectedResults, c.expectedCompositeLiteralType())
enclosingCompositeLiteral := enclosingCompositeLiteral(c.path[callNodeIdx:], callNode.Pos(), c.pkg.TypesInfo())
if !wantStructFieldCompletions(enclosingCompositeLiteral) {
expectedResults = append(expectedResults, expectedCompositeLiteralType(enclosingCompositeLiteral, c.pos))
}
// undo side effect
c.enclosingCompositeLiteral = nil
case *ast.AssignStmt:
inf := expectedAssignStmtCandidate(c, node, inf)
inf := expectedAssignStmtCandidate(c, node, candidateInference{})
if len(inf.assignees) > 0 {
expectedResults = make([]types.Type, len(inf.assignees))
copy(expectedResults, inf.assignees)
Expand All @@ -2515,7 +2518,7 @@ func inferExpectedCallExprResults(c *completer, callNodeIdx int) []types.Type {
expectedResults = append(expectedResults, resultType)
}
case *ast.SendStmt:
if resultType := expectedSendStmtCandidate(c, node, inf).objType; resultType != nil {
if resultType := expectedSendStmtType(c, node); resultType != nil {
expectedResults = append(expectedResults, resultType)
}
case *ast.ReturnStmt:
Expand Down Expand Up @@ -2550,18 +2553,18 @@ func inferExpectedCallExprResults(c *completer, callNodeIdx int) []types.Type {
return expectedResults
}

// expectedSendStmtCandidate returns information about the expected candidate
// for a SendStmt at the query position.
func expectedSendStmtCandidate(c *completer, node *ast.SendStmt, inf candidateInference) candidateInference {
// expectedSendStmtType return the expected type at the position.
// Returns nil if unknown.
func expectedSendStmtType(c *completer, node *ast.SendStmt) types.Type {
// Make sure we are on right side of arrow (e.g. "foo <- <>").
if c.pos > node.Arrow+1 {
if tv, ok := c.pkg.TypesInfo().Types[node.Chan]; ok {
if ch, ok := tv.Type.Underlying().(*types.Chan); ok {
inf.objType = ch.Elem()
return ch.Elem()
}
}
}
return inf
return nil
}

// expectedValueSpecType returns the expected type of a ValueSpec at the query
Expand Down Expand Up @@ -2603,7 +2606,8 @@ func expectedAssignStmtCandidate(c *completer, node *ast.AssignStmt, inf candida
return inf
}

// expectedReturnStmtType returns nil if enclosingSig is nil
// expectedReturnStmtType returns the expected type of a return statement.
// Returns nil if enclosingSig is nil.
func expectedReturnStmtType(enclosingSig *types.Signature, node *ast.ReturnStmt, pos token.Pos) types.Type {
if enclosingSig != nil {
if resultIdx := exprAtPos(pos, node.Results); resultIdx < len(node.Results) {
Expand Down Expand Up @@ -2672,34 +2676,33 @@ func reverseInferTypeArgs(sig *types.Signature, typeArgs []types.Type, expectedR
return substs
}

// inferExpectedTypeArgs gives a type param candidateInference based on the surroundings of it's call site.
// inferExpectedTypeArg gives a type param candidateInference based on the surroundings of it's call site.
// If successful, the inf parameter is returned with only it's objType field updated.
//
// callNodeIdx is the index within the completion path of the type parameter's parent call expression.
// typeParamIdx is the index of the type parameter at the completion pos.
func (c *completer) inferExpectedTypeArgs(callNodeIdx int, typeParamIdx int, inf candidateInference) candidateInference {
func (c *completer) inferExpectedTypeArg(callNodeIdx int, typeParamIdx int) types.Type {
if len(c.path) <= callNodeIdx {
return inf
return nil
}

callNode, ok := c.path[callNodeIdx].(*ast.CallExpr)
if !ok {
return inf
return nil
}

// Infer the type parameters in a function call based on it's context
sig := c.pkg.TypesInfo().Types[callNode.Fun].Type.(*types.Signature)
expectedResults := inferExpectedCallExprResults(c, callNodeIdx)
expectedResults := inferExpectedResultTypes(c, callNodeIdx)
if typeParamIdx < 0 || typeParamIdx >= sig.TypeParams().Len() {
return inf
return nil
}
substs := reverseInferTypeArgs(sig, nil, expectedResults)
if substs == nil || substs[typeParamIdx] == nil {
return inf
return nil
}

inf.objType = substs[typeParamIdx]
return inf
return substs[typeParamIdx]
}

// Instantiates a signature with a set of type parameters.
Expand Down

0 comments on commit 551bd57

Please sign in to comment.