diff --git a/gopls/internal/golang/completion/completion.go b/gopls/internal/golang/completion/completion.go index cf398693113..60d5cd635e0 100644 --- a/gopls/internal/golang/completion/completion.go +++ b/gopls/internal/golang/completion/completion.go @@ -134,6 +134,33 @@ func (i *CompletionItem) Snippet() string { return i.InsertText } +// addConversion wraps the existing completionItem in a conversion expression. +// Only affects the receiver's InsertText and snippet fields, not the Label. +// An empty conv argument has no effect. +func (i *CompletionItem) addConversion(c *completer, conv conversionEdits) error { + if conv.prefix != "" { + // If we are in a selector, add an edit to place prefix before selector. + if sel := enclosingSelector(c.path, c.pos); sel != nil { + edits, err := c.editText(sel.Pos(), sel.Pos(), conv.prefix) + if err != nil { + return err + } + i.AdditionalTextEdits = append(i.AdditionalTextEdits, edits...) + } else { + // If there is no selector, just stick the prefix at the start. + i.InsertText = conv.prefix + i.InsertText + i.snippet.PrependText(conv.prefix) + } + } + + if conv.suffix != "" { + i.InsertText += conv.suffix + i.snippet.WriteText(conv.suffix) + } + + return nil +} + // Scoring constants are used for weighting the relevance of different candidates. const ( // stdScore is the base score for all completion items. @@ -2164,6 +2191,25 @@ type candidateInference struct { // convertibleTo is a type our candidate type must be convertible to. convertibleTo types.Type + // needsExactType is true if the candidate type must be exactly the type of + // the objType, e.g. an interface rather than it's implementors. + // + // This is necessary when objType is derived using reverse type inference: + // any different (but assignable) type may lead to different type inference, + // which may no longer be valid. + // + // For example, consider the following scenario: + // + // func f[T any](x T) []T { return []T{x} } + // + // var s []any = f(_) + // + // Reverse type inference would infer that the type at _ must be 'any', but + // that does not mean that any object in the lexical scope is valid: the type of + // the object must be *exactly* any, otherwise type inference will cause the + // slice assignment to fail. + needsExactType bool + // typeName holds information about the expected type name at // position, if any. typeName typeNameInference @@ -2259,36 +2305,13 @@ Nodes: break Nodes } case *ast.AssignStmt: - // Only rank completions if you are on the right side of the token. - if c.pos > node.TokPos { - i := exprAtPos(c.pos, node.Rhs) - if i >= len(node.Lhs) { - i = len(node.Lhs) - 1 - } - if tv, ok := c.pkg.TypesInfo().Types[node.Lhs[i]]; ok { - inf.objType = tv.Type - } - - // If we have a single expression on the RHS, record the LHS - // assignees so we can favor multi-return function calls with - // matching result values. - if len(node.Rhs) <= 1 { - for _, lhs := range node.Lhs { - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(lhs)) - } - } else { - // Otherwise, record our single assignee, even if its type is - // not available. We use this info to downrank functions - // with the wrong number of result values. - inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(node.Lhs[i])) - } - } - return inf + return expectedAssignStmtCandidate(c, node, inf) case *ast.ValueSpec: - if node.Type != nil && c.pos > node.Type.End() { - inf.objType = c.pkg.TypesInfo().TypeOf(node.Type) - } - return inf + return expectedValueSpecCandidate(c, node, inf) + case *ast.ReturnStmt: + return expectedReturnStmtCandidate(c, node, inf) + case *ast.SendStmt: + return expectedSendStmtCandidate(c, node, inf) case *ast.CallExpr: // Only consider CallExpr args if position falls between parens. if node.Lparen < c.pos && c.pos <= node.Rparen { @@ -2299,22 +2322,25 @@ Nodes: break Nodes } - sig, _ := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature) - - if sig != nil && sig.TypeParams().Len() > 0 { - // If we are completing a generic func call, re-check the call expression. - // This allows type param inference to work in cases like: - // - // func foo[T any](T) {} - // foo[int](<>) // <- get "int" completions instead of "T" - // - // TODO: remove this after https://go.dev/issue/52503 - info := &types.Info{Types: make(map[ast.Expr]types.TypeAndValue)} - types.CheckExpr(c.pkg.FileSet(), c.pkg.Types(), node.Fun.Pos(), node.Fun, info) - sig, _ = info.Types[node.Fun].Type.(*types.Signature) - } + if sig, ok := c.pkg.TypesInfo().Types[node.Fun].Type.(*types.Signature); ok { + // Out of bounds arguments get no inference completion. + if !sig.Variadic() && exprAtPos(c.pos, node.Args) >= sig.Params().Len() { + return inf + } + + targs := c.getTypeArgs(node) + expectedResults := reverseInferExpectedCallExprResults(c, i) + substs := makeTypeParamSubstitions(sig, targs, expectedResults) + inst := instantiateSignature(sig, substs) + if inst != nil { + // TODO(jacobz): If partial signature instantiation becomes possible, + // make needsExactType only true if necessary. + // Currently ambigious cases resolve to a correct, but occassionally, + // superfluous conversion expression wrapping the completion. + inf.needsExactType = true + sig = inst + } - if sig != nil { inf = c.expectedCallParamType(inf, node, sig) } @@ -2344,17 +2370,6 @@ Nodes: return inf } - case *ast.ReturnStmt: - if c.enclosingFunc != nil { - sig := c.enclosingFunc.sig - // Find signature result that corresponds to our return statement. - if resultIdx := exprAtPos(c.pos, node.Results); resultIdx < len(node.Results) { - if resultIdx < sig.Results().Len() { - inf.objType = sig.Results().At(resultIdx).Type() - } - } - } - return inf case *ast.CaseClause: if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { if tv, ok := c.pkg.TypesInfo().Types[swtch.Tag]; ok { @@ -2398,6 +2413,7 @@ Nodes: inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true + inf = c.reverseInferExpectedTypeParam(i+1, 0, inf) } } } @@ -2405,20 +2421,12 @@ Nodes: case *ast.IndexListExpr: if node.Lbrack < c.pos && c.pos <= node.Rbrack { if tv, ok := c.pkg.TypesInfo().Types[node.X]; ok { - if ct := expectedConstraint(tv.Type, exprAtPos(c.pos, node.Indices)); ct != nil { + typeParamIdx := exprAtPos(c.pos, node.Indices) + if ct := expectedConstraint(tv.Type, typeParamIdx); ct != nil { inf.objType = ct inf.typeName.wantTypeName = true inf.typeName.isTypeParam = true - } - } - } - return inf - case *ast.SendStmt: - // Make sure we are on right side of arrow (e.g. "foo <- <>"). - if c.pos > node.Arrow+1 { - if tv, ok := c.pkg.TypesInfo().Types[node.Chan]; ok { - if ch, ok := tv.Type.Underlying().(*types.Chan); ok { - inf.objType = ch.Elem() + inf = c.reverseInferExpectedTypeParam(i+1, typeParamIdx, inf) } } } @@ -2457,6 +2465,269 @@ Nodes: return inf } +// reverseInferExpectedCallExprResults takes the index of a call expression within the completion +// path and uses its surroundings to infer the expected result tuple of the call's signature. +// Returns the signature result tuple as a slice, or nil if reverse type inference fails. +// +// # For example +// +// func generic[T any, U any](a T, b U) (T, U) { ... } +// +// var x TypeA +// var y TypeB +// x, y := generic(, ) +// +// reverseInferExpectedCallExprResults can determine that the expected result type of the function is (TypeA, TypeB) +func reverseInferExpectedCallExprResults(c *completer, callNodeIdx int) []types.Type { + callNode := c.path[callNodeIdx].(*ast.CallExpr) + if callNode == nil { + return nil + } + + if len(c.path) <= callNodeIdx+1 { + return nil + } + + var expectedResults []types.Type + var inf candidateInference + + // Check the parents of the call node to extract the expected result types of the call signature. + // Currently reverse inferences are only supported with the the following parent expressions, + // however this list isn't exhaustive. + switch node := c.path[callNodeIdx+1].(type) { + case *ast.KeyValueExpr: + c.enclosingCompositeLiteral = enclosingCompositeLiteral(c.path[callNodeIdx:], callNode.Pos(), c.pkg.TypesInfo()) + if !c.wantStructFieldCompletions() { + expectedResults = append(expectedResults, c.expectedCompositeLiteralType()) + } + // undo side effect + c.enclosingCompositeLiteral = nil + case *ast.AssignStmt: + inf := expectedAssignStmtCandidate(c, node, inf) + if len(inf.assignees) > 0 { + expectedResults = make([]types.Type, len(inf.assignees)) + copy(expectedResults, inf.assignees) + } else if inf.objType != nil { + expectedResults = append(expectedResults, inf.objType) + } + case *ast.ValueSpec: + if resultType := expectedValueSpecCandidate(c, node, inf).objType; resultType != nil { + expectedResults = append(expectedResults, resultType) + } + case *ast.SendStmt: + if resultType := expectedSendStmtCandidate(c, node, inf).objType; resultType != nil { + expectedResults = append(expectedResults, resultType) + } + case *ast.ReturnStmt: + if c.enclosingFunc == nil { + return nil + } + + // As a special case for reverse call inference in + // + // return foo() + // + // Pull the result type from the enclosing function + if exprAtPos(c.pos, node.Results) == 0 { + if callSig := c.pkg.TypesInfo().Types[callNode.Fun].Type.(*types.Signature); callSig != nil { + enclosingResults := c.enclosingFunc.sig.Results() + if callSig.Results().Len() == enclosingResults.Len() { + expectedResults = make([]types.Type, enclosingResults.Len()) + for i := range enclosingResults.Len() { + expectedResults[i] = enclosingResults.At(i).Type() + } + break + } + } + } + + if resultType := expectedReturnStmtCandidate(c, node, inf).objType; resultType != nil { + expectedResults = append(expectedResults, resultType) + } + } + return expectedResults +} + +// expectedSendStmtCandidate returns information about the expected candidate +// for a SendStmt at the query position. +func expectedSendStmtCandidate(c *completer, node *ast.SendStmt, inf candidateInference) candidateInference { + // Make sure we are on right side of arrow (e.g. "foo <- <>"). + if c.pos > node.Arrow+1 { + if tv, ok := c.pkg.TypesInfo().Types[node.Chan]; ok { + if ch, ok := tv.Type.Underlying().(*types.Chan); ok { + inf.objType = ch.Elem() + } + } + } + return inf +} + +// expectedValueSpecCandidate returns information about the expected candidate +// for a ValueSpec at the query position. +func expectedValueSpecCandidate(c *completer, node *ast.ValueSpec, inf candidateInference) candidateInference { + if node.Type != nil && c.pos > node.Type.End() { + inf.objType = c.pkg.TypesInfo().TypeOf(node.Type) + } + return inf +} + +// expectedAssignStmtCandidate returns information about the expected candidate +// for a AssignStmt at the query position. +func expectedAssignStmtCandidate(c *completer, node *ast.AssignStmt, inf candidateInference) candidateInference { + // Only rank completions if you are on the right side of the token. + if c.pos > node.TokPos { + i := exprAtPos(c.pos, node.Rhs) + if i >= len(node.Lhs) { + i = len(node.Lhs) - 1 + } + if tv, ok := c.pkg.TypesInfo().Types[node.Lhs[i]]; ok { + inf.objType = tv.Type + } + + // If we have a single expression on the RHS, record the LHS + // assignees so we can favor multi-return function calls with + // matching result values. + if len(node.Rhs) <= 1 { + for _, lhs := range node.Lhs { + inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(lhs)) + } + } else { + // Otherwise, record our single assignee, even if its type is + // not available. We use this info to downrank functions + // with the wrong number of result values. + inf.assignees = append(inf.assignees, c.pkg.TypesInfo().TypeOf(node.Lhs[i])) + } + } + return inf +} + +// expectedReturnStmtCandidate returns information about the expected candidate +// for a ReturnStmt at the query position. +func expectedReturnStmtCandidate(c *completer, node *ast.ReturnStmt, inf candidateInference) candidateInference { + if c.enclosingFunc != nil { + sig := c.enclosingFunc.sig + // Find signature result that corresponds to our return statement. + if resultIdx := exprAtPos(c.pos, node.Results); resultIdx < len(node.Results) { + if resultIdx < sig.Results().Len() { + inf.objType = sig.Results().At(resultIdx).Type() + } + } + } + return inf +} + +// Returns the number of type arguments in a callExpr +func (c *completer) getTypeArgs(callExpr *ast.CallExpr) []types.Type { + var targs []types.Type + switch fun := callExpr.Fun.(type) { + case *ast.IndexListExpr: + for i := range fun.Indices { + if typ, ok := c.pkg.TypesInfo().Types[fun.Indices[i]]; ok && typeIsValid(typ.Type) { + targs = append(targs, typ.Type) + } + } + case *ast.IndexExpr: + if typ, ok := c.pkg.TypesInfo().Types[fun.Index]; ok && typeIsValid(typ.Type) { + targs = []types.Type{typ.Type} + } + } + return targs +} + +// makeTypeParamSubstitions takes a generic signature, a list of passed type arguments, and the expected concrete return types +// inferred from the signature's call site. If possible, it returns a list of types that could be used as the type arguments +// to the signature. If not possible, it returns nil. +// +// Does not panic if any of the arguments are nil. +func makeTypeParamSubstitions(sig *types.Signature, typeArgs []types.Type, expectedResults []types.Type) []types.Type { + if len(expectedResults) == 0 || sig == nil || sig.TypeParams().Len() == 0 || + sig.Results().Len() != len(expectedResults) { + return nil + } + + tparams := make([]*types.TypeParam, sig.TypeParams().Len()) + for i := range sig.TypeParams().Len() { + tparams[i] = sig.TypeParams().At(i) + } + + for i := len(typeArgs); i < sig.TypeParams().Len(); i++ { + typeArgs = append(typeArgs, nil) + } + + u := newUnifier(tparams, typeArgs) + for i, assignee := range expectedResults { + // Unify does not check the constraints of the type parameters. + // Checks must be applied after. + if !u.unify(sig.Results().At(i).Type(), assignee, unifyModeExact) { + return nil + } + } + + substs := make([]types.Type, sig.TypeParams().Len()) + for i := 0; i < sig.TypeParams().Len(); i++ { + if sub := u.handles[sig.TypeParams().At(i)]; sub != nil && *sub != nil { + // Ensure the inferred subst is assignable to the type parameter's constraint. + if !assignableTo(*sub, sig.TypeParams().At(i).Constraint()) { + return nil + } + substs[i] = *sub + } + } + return substs +} + +// reverseInferExpectedTypeParam gives a type param candidateInference based on the surroundings of it's call site. +// If successful, the inf parameter is returned with only it's objType field updated. +// +// callNodeIdx is the index within the completion path of the type parameter's parent call expression. +// typeParamIdx is the index of the type parameter at the completion pos. +func (c *completer) reverseInferExpectedTypeParam(callNodeIdx int, typeParamIdx int, inf candidateInference) candidateInference { + if len(c.path) <= callNodeIdx { + return inf + } + + callNode, ok := c.path[callNodeIdx].(*ast.CallExpr) + if !ok { + return inf + } + + // Infer the type parameters in a function call based on it's context + sig := c.pkg.TypesInfo().Types[callNode.Fun].Type.(*types.Signature) + expectedResults := reverseInferExpectedCallExprResults(c, callNodeIdx) + if typeParamIdx < 0 || typeParamIdx >= sig.TypeParams().Len() { + return inf + } + substs := makeTypeParamSubstitions(sig, nil, expectedResults) + if substs == nil || substs[typeParamIdx] == nil { + return inf + } + + inf.objType = substs[typeParamIdx] + return inf +} + +// Instantiates a signature with a set of type parameters. +// Wrapper around types.Instantiate but bad arguments won't cause a panic. +func instantiateSignature(sig *types.Signature, substs []types.Type) *types.Signature { + if substs == nil || sig == nil || len(substs) != sig.TypeParams().Len() { + return nil + } + + for i := range substs { + if substs[i] == nil { + substs[i] = sig.TypeParams().At(i) + } + } + + if inst, err := types.Instantiate(nil, sig, substs, true); err == nil { + if inst, ok := inst.(*types.Signature); ok { + return inst + } + } + + return nil +} + func (c *completer) expectedCallParamType(inf candidateInference, node *ast.CallExpr, sig *types.Signature) candidateInference { numParams := sig.Params().Len() if numParams == 0 { @@ -2972,6 +3243,14 @@ func (ci *candidateInference) candTypeMatches(cand *candidate) bool { cand.mods = append(cand.mods, takeDotDotDot) } + // Candidate matches, but isn't exactly identical to the expected type. + // Apply a conversion to allow it to match. + if ci.needsExactType && !types.Identical(candType, expType) { + cand.convertTo = expType + // Ranks barely lower if it needs a conversion, even though it's perfectly valid. + cand.score *= 0.95 + } + // Lower candidate score for untyped conversions. This avoids // ranking untyped constants above candidates with an exact type // match. Don't lower score of builtin constants, e.g. "true". @@ -3161,6 +3440,9 @@ func (c *completer) matchingTypeName(cand *candidate) bool { return false } + wantExactTypeParam := c.inference.typeName.isTypeParam && + c.inference.typeName.wantTypeName && c.inference.needsExactType + typeMatches := func(candType types.Type) bool { // Take into account any type name modifier prefixes. candType = c.inference.applyTypeNameModifiers(candType) @@ -3179,6 +3461,13 @@ func (c *completer) matchingTypeName(cand *candidate) bool { } } + // Suggest the exact type when performing reverse type inference. + // x = Foo[<>]() + // Where x is an interface kind, only suggest the interface type rather than its implementors + if wantExactTypeParam && types.Identical(candType, c.inference.objType) { + return true + } + if c.inference.typeName.wantComparable && !types.Comparable(candType) { return false } diff --git a/gopls/internal/golang/completion/format.go b/gopls/internal/golang/completion/format.go index c2b955ca7e9..1d83b44ae92 100644 --- a/gopls/internal/golang/completion/format.go +++ b/gopls/internal/golang/completion/format.go @@ -196,24 +196,9 @@ Suffixes: } if cand.convertTo != nil { - typeName := types.TypeString(cand.convertTo, c.qf) - - switch t := cand.convertTo.(type) { - // We need extra parens when casting to these types. For example, - // we need "(*int)(foo)", not "*int(foo)". - case *types.Pointer, *types.Signature: - typeName = "(" + typeName + ")" - case *types.Basic: - // If the types are incompatible (as determined by typeMatches), then we - // must need a conversion here. However, if the target type is untyped, - // don't suggest converting to e.g. "untyped float" (golang/go#62141). - if t.Info()&types.IsUntyped != 0 { - typeName = types.TypeString(types.Default(cand.convertTo), c.qf) - } - } - - prefix = typeName + "(" + prefix - suffix = ")" + conv := c.formatConversion(cand.convertTo) + prefix = conv.prefix + prefix + suffix = conv.suffix } if prefix != "" { @@ -288,6 +273,38 @@ Suffixes: return item, nil } +// conversionEdits represents the string edits needed to make a type conversion +// of an expression. +type conversionEdits struct { + prefix, suffix string +} + +// formatConversion returns the edits needed to make a type conversion +// expression, including parentheses if necessary. +// +// Returns empty conversionEdits if convertTo is nil. +func (c *completer) formatConversion(convertTo types.Type) conversionEdits { + if convertTo == nil { + return conversionEdits{} + } + + typeName := types.TypeString(convertTo, c.qf) + switch t := convertTo.(type) { + // We need extra parens when casting to these types. For example, + // we need "(*int)(foo)", not "*int(foo)". + case *types.Pointer, *types.Signature: + typeName = "(" + typeName + ")" + case *types.Basic: + // If the types are incompatible (as determined by typeMatches), then we + // must need a conversion here. However, if the target type is untyped, + // don't suggest converting to e.g. "untyped float" (golang/go#62141). + if t.Info()&types.IsUntyped != 0 { + typeName = types.TypeString(types.Default(convertTo), c.qf) + } + } + return conversionEdits{prefix: typeName + "(", suffix: ")"} +} + // importEdits produces the text edits necessary to add the given import to the current file. func (c *completer) importEdits(imp *importInfo) ([]protocol.TextEdit, error) { if imp == nil { diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 7427d559e94..21b791d4c97 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -73,15 +73,21 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im cand.addressable = true } - if !c.matchingCandidate(&cand) || cand.convertTo != nil { + // Only suggest a literal conversion if the exact type is known. + if !c.matchingCandidate(&cand) || (cand.convertTo != nil && !c.inference.needsExactType) { return } var ( - qf = c.qf - sel = enclosingSelector(c.path, c.pos) + qf = c.qf + sel = enclosingSelector(c.path, c.pos) + conversion conversionEdits ) + if cand.convertTo != nil { + conversion = c.formatConversion(cand.convertTo) + } + // Don't qualify the type name if we are in a selector expression // since the package name is already present. if sel != nil { @@ -129,13 +135,18 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch t := literalType.Underlying().(type) { case *types.Struct, *types.Array, *types.Slice, *types.Map: - c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item := c.compositeLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) case *types.Signature: // Add a literal completion for a signature type that implements // an interface. For example, offer "http.HandlerFunc()" when // expected type is "http.Handler". if expType != nil && types.IsInterface(expType) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } case *types.Basic: // Add a literal completion for basic types that implement our @@ -143,7 +154,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im // implements http.FileSystem), or are identical to our expected // type (i.e. yielding a type conversion such as "float64()"). if expType != nil && (types.IsInterface(expType) || types.Identical(expType, literalType)) { - c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits) + if item, ok := c.basicLiteral(t, snip.Clone(), typeName, float64(score), addlEdits); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } } } @@ -155,11 +169,15 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im switch literalType.Underlying().(type) { case *types.Slice: // The second argument to "make()" for slices is required, so default to "0". - c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "0", float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) case *types.Map, *types.Chan: // Maps and channels don't require the second argument, so omit // to keep things simple for now. - c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item := c.makeCall(snip.Clone(), typeName, "", float64(score), addlEdits) + item.addConversion(c, conversion) + c.items = append(c.items, item) } } @@ -167,7 +185,10 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im if score := c.matcher.Score("func"); !cand.hasMod(reference) && score > 0 && (expType == nil || !types.IsInterface(expType)) { switch t := literalType.Underlying().(type) { case *types.Signature: - c.functionLiteral(ctx, t, float64(score)) + if item, ok := c.functionLiteral(ctx, t, float64(score)); ok { + item.addConversion(c, conversion) + c.items = append(c.items, item) + } } } } @@ -178,9 +199,9 @@ func (c *completer) literal(ctx context.Context, literalType types.Type, imp *im // correct type, so scale down highScore. const literalCandidateScore = highScore / 2 -// functionLiteral adds a function literal completion item for the -// given signature. -func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) { +// functionLiteral returns a function literal completion item for the +// given signature, if applicable. +func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, matchScore float64) (CompletionItem, bool) { snip := &snippet.Builder{} snip.WriteText("func(") @@ -216,7 +237,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } name = abbreviateTypeName(typeName) } @@ -284,7 +305,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if sig.Variadic() && i == sig.Params().Len()-1 { typeStr = strings.Replace(typeStr, "[]", "...", 1) @@ -342,7 +363,7 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m if ctx.Err() == nil { event.Error(ctx, "formatting var type", err) } - return + return CompletionItem{}, false } if tp, ok := types.Unalias(r.Type()).(*types.TypeParam); ok && !c.typeParamInScope(tp) { snip.WritePlaceholder(func(snip *snippet.Builder) { @@ -360,12 +381,12 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m snip.WriteFinalTabstop() snip.WriteText("}") - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: "func(...) {}", Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, snippet: snip, - }) + }, true } // conventionalAcronyms contains conventional acronyms for type names @@ -430,9 +451,9 @@ func abbreviateTypeName(s string) string { return b.String() } -// compositeLiteral adds a composite literal completion item for the given typeName. +// compositeLiteral returns a composite literal completion item for the given typeName. // T is an (unnamed, unaliased) struct, array, slice, or map type. -func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) CompletionItem { snip.WriteText("{") // Don't put the tab stop inside the composite literal curlies "{}" // for structs that have no accessible fields. @@ -443,22 +464,24 @@ func (c *completer) compositeLiteral(T types.Type, snip *snippet.Builder, typeNa nonSnippet := typeName + "{}" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Score: matchScore * literalCandidateScore, Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } -// basicLiteral adds a literal completion item for the given basic +// basicLiteral returns a literal completion item for the given basic // type name typeName. -func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) { +// +// If T is untyped, this function returns false. +func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName string, matchScore float64, edits []protocol.TextEdit) (CompletionItem, bool) { // Never give type conversions like "untyped int()". if isUntyped(T) { - return + return CompletionItem{}, false } snip.WriteText("(") @@ -467,7 +490,7 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s nonSnippet := typeName + "()" - c.items = append(c.items, CompletionItem{ + return CompletionItem{ Label: nonSnippet, InsertText: nonSnippet, Detail: T.String(), @@ -475,11 +498,11 @@ func (c *completer) basicLiteral(T types.Type, snip *snippet.Builder, typeName s Kind: protocol.VariableCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + }, true } -// makeCall adds a completion item for a "make()" call given a specific type. -func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) { +// makeCall returns a completion item for a "make()" call given a specific type. +func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg string, matchScore float64, edits []protocol.TextEdit) CompletionItem { // Keep it simple and don't add any placeholders for optional "make()" arguments. snip.PrependText("make(") @@ -501,14 +524,15 @@ func (c *completer) makeCall(snip *snippet.Builder, typeName string, secondArg s } nonSnippet.WriteByte(')') - c.items = append(c.items, CompletionItem{ - Label: nonSnippet.String(), - InsertText: nonSnippet.String(), - Score: matchScore * literalCandidateScore, + return CompletionItem{ + Label: nonSnippet.String(), + InsertText: nonSnippet.String(), + // make() should be just below other literal completions + Score: matchScore * literalCandidateScore * 0.99, Kind: protocol.FunctionCompletion, AdditionalTextEdits: edits, snippet: snip, - }) + } } // Create a snippet for a type name where type params become placeholders. diff --git a/gopls/internal/golang/completion/unify.go b/gopls/internal/golang/completion/unify.go new file mode 100644 index 00000000000..8f4a1d3cbe0 --- /dev/null +++ b/gopls/internal/golang/completion/unify.go @@ -0,0 +1,710 @@ +// Below was copied from go/types/unify.go on September 24, 2024, +// and combined with snippets from other files as well. +// It is copied to implement unification for code completion inferences, +// in lieu of an official type unification API. +// +// TODO: When such an API is available, the code below should deleted. +// +// Due to complexity of extracting private types from the go/types package, +// the unifier does not fully implement interface unification. +// +// The code has been modified to compile without introducing any key functionality changes. +// + +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements type unification. +// +// Type unification attempts to make two types x and y structurally +// equivalent by determining the types for a given list of (bound) +// type parameters which may occur within x and y. If x and y are +// structurally different (say []T vs chan T), or conflicting +// types are determined for type parameters, unification fails. +// If unification succeeds, as a side-effect, the types of the +// bound type parameters may be determined. +// +// Unification typically requires multiple calls u.unify(x, y) to +// a given unifier u, with various combinations of types x and y. +// In each call, additional type parameter types may be determined +// as a side effect and recorded in u. +// If a call fails (returns false), unification fails. +// +// In the unification context, structural equivalence of two types +// ignores the difference between a defined type and its underlying +// type if one type is a defined type and the other one is not. +// It also ignores the difference between an (external, unbound) +// type parameter and its core type. +// If two types are not structurally equivalent, they cannot be Go +// identical types. On the other hand, if they are structurally +// equivalent, they may be Go identical or at least assignable, or +// they may be in the type set of a constraint. +// Whether they indeed are identical or assignable is determined +// upon instantiation and function argument passing. + +package completion + +import ( + "fmt" + "go/types" + "strings" +) + +const ( + // Upper limit for recursion depth. Used to catch infinite recursions + // due to implementation issues (e.g., see issues go.dev/issue/48619, go.dev/issue/48656). + unificationDepthLimit = 50 + + // Whether to panic when unificationDepthLimit is reached. + // If disabled, a recursion depth overflow results in a (quiet) + // unification failure. + panicAtUnificationDepthLimit = true + + // If enableCoreTypeUnification is set, unification will consider + // the core types, if any, of non-local (unbound) type parameters. + enableCoreTypeUnification = true +) + +// A unifier maintains a list of type parameters and +// corresponding types inferred for each type parameter. +// A unifier is created by calling newUnifier. +type unifier struct { + // handles maps each type parameter to its inferred type through + // an indirection *Type called (inferred type) "handle". + // Initially, each type parameter has its own, separate handle, + // with a nil (i.e., not yet inferred) type. + // After a type parameter P is unified with a type parameter Q, + // P and Q share the same handle (and thus type). This ensures + // that inferring the type for a given type parameter P will + // automatically infer the same type for all other parameters + // unified (joined) with P. + handles map[*types.TypeParam]*types.Type + depth int // recursion depth during unification +} + +// newUnifier returns a new unifier initialized with the given type parameter +// and corresponding type argument lists. The type argument list may be shorter +// than the type parameter list, and it may contain nil types. Matching type +// parameters and arguments must have the same index. +func newUnifier(tparams []*types.TypeParam, targs []types.Type) *unifier { + handles := make(map[*types.TypeParam]*types.Type, len(tparams)) + // Allocate all handles up-front: in a correct program, all type parameters + // must be resolved and thus eventually will get a handle. + // Also, sharing of handles caused by unified type parameters is rare and + // so it's ok to not optimize for that case (and delay handle allocation). + for i, x := range tparams { + var t types.Type + if i < len(targs) { + t = targs[i] + } + handles[x] = &t + } + return &unifier{handles, 0} +} + +// unifyMode controls the behavior of the unifier. +type unifyMode uint + +const ( + // If unifyModeAssign is set, we are unifying types involved in an assignment: + // they may match inexactly at the top, but element types must match + // exactly. + unifyModeAssign unifyMode = 1 << iota + + // If unifyModeExact is set, types unify if they are identical (or can be + // made identical with suitable arguments for type parameters). + // Otherwise, a named type and a type literal unify if their + // underlying types unify, channel directions are ignored, and + // if there is an interface, the other type must implement the + // interface. + unifyModeExact +) + +// This function was copied from go/types/unify.go +// +// unify attempts to unify x and y and reports whether it succeeded. +// As a side-effect, types may be inferred for type parameters. +// The mode parameter controls how types are compared. +func (u *unifier) unify(x, y types.Type, mode unifyMode) bool { + return u.nify(x, y, mode) +} + +type typeParamsById []*types.TypeParam + +// join unifies the given type parameters x and y. +// If both type parameters already have a type associated with them +// and they are not joined, join fails and returns false. +func (u *unifier) join(x, y *types.TypeParam) bool { + switch hx, hy := u.handles[x], u.handles[y]; { + case hx == hy: + // Both type parameters already share the same handle. Nothing to do. + case *hx != nil && *hy != nil: + // Both type parameters have (possibly different) inferred types. Cannot join. + return false + case *hx != nil: + // Only type parameter x has an inferred type. Use handle of x. + u.setHandle(y, hx) + // This case is treated like the default case. + // case *hy != nil: + // // Only type parameter y has an inferred type. Use handle of y. + // u.setHandle(x, hy) + default: + // Neither type parameter has an inferred type. Use handle of y. + u.setHandle(x, hy) + } + return true +} + +// asBoundTypeParam returns x.(*types.TypeParam) if x is a type parameter recorded with u. +// Otherwise, the result is nil. +func (u *unifier) asBoundTypeParam(x types.Type) *types.TypeParam { + if x, _ := types.Unalias(x).(*types.TypeParam); x != nil { + if _, found := u.handles[x]; found { + return x + } + } + return nil +} + +// setHandle sets the handle for type parameter x +// (and all its joined type parameters) to h. +func (u *unifier) setHandle(x *types.TypeParam, h *types.Type) { + hx := u.handles[x] + for y, hy := range u.handles { + if hy == hx { + u.handles[y] = h + } + } +} + +// at returns the (possibly nil) type for type parameter x. +func (u *unifier) at(x *types.TypeParam) types.Type { + return *u.handles[x] +} + +// set sets the type t for type parameter x; +// t must not be nil. +func (u *unifier) set(x *types.TypeParam, t types.Type) { + *u.handles[x] = t +} + +// unknowns returns the number of type parameters for which no type has been set yet. +func (u *unifier) unknowns() int { + n := 0 + for _, h := range u.handles { + if *h == nil { + n++ + } + } + return n +} + +// inferred returns the list of inferred types for the given type parameter list. +// The result is never nil and has the same length as tparams; result types that +// could not be inferred are nil. Corresponding type parameters and result types +// have identical indices. +func (u *unifier) inferred(tparams []*types.TypeParam) []types.Type { + list := make([]types.Type, len(tparams)) + for i, x := range tparams { + list[i] = u.at(x) + } + return list +} + +// asInterface returns the underlying type of x as an interface if +// it is a non-type parameter interface. Otherwise it returns nil. +func asInterface(x types.Type) (i *types.Interface) { + if _, ok := types.Unalias(x).(*types.TypeParam); !ok { + i, _ = x.Underlying().(*types.Interface) + } + return i +} + +func isTypeParam(t types.Type) bool { + _, ok := types.Unalias(t).(*types.TypeParam) + return ok +} + +func asNamed(t types.Type) *types.Named { + n, _ := types.Unalias(t).(*types.Named) + return n +} + +func isTypeLit(t types.Type) bool { + switch types.Unalias(t).(type) { + case *types.Named, *types.TypeParam: + return false + } + return true +} + +// identicalOrigin reports whether x and y originated in the same declaration. +func identicalOrigin(x, y *types.Named) bool { + // TODO(gri) is this correct? + return x.Origin().Obj() == y.Origin().Obj() +} + +func match(x, y types.Type) types.Type { + // Common case: we don't have channels. + if types.Identical(x, y) { + return x + } + + // We may have channels that differ in direction only. + if x, _ := x.(*types.Chan); x != nil { + if y, _ := y.(*types.Chan); y != nil && types.Identical(x.Elem(), y.Elem()) { + // We have channels that differ in direction only. + // If there's an unrestricted channel, select the restricted one. + switch { + case x.Dir() == types.SendRecv: + return y + case y.Dir() == types.SendRecv: + return x + } + } + } + + // types are different + return nil +} + +func coreType(t types.Type) types.Type { + t = types.Unalias(t) + tpar, _ := t.(*types.TypeParam) + if tpar == nil { + return t.Underlying() + } + + return nil +} + +func sameId(obj *types.Var, pkg *types.Package, name string, foldCase bool) bool { + // If we don't care about capitalization, we also ignore packages. + if foldCase && strings.EqualFold(obj.Name(), name) { + return true + } + // spec: + // "Two identifiers are different if they are spelled differently, + // or if they appear in different packages and are not exported. + // Otherwise, they are the same." + if obj.Name() != name { + return false + } + // obj.Name == name + if obj.Exported() { + return true + } + // not exported, so packages must be the same + if obj.Pkg() != nil && pkg != nil { + return obj.Pkg() == pkg + } + return obj.Pkg().Path() == pkg.Path() +} + +// nify implements the core unification algorithm which is an +// adapted version of Checker.identical. For changes to that +// code the corresponding changes should be made here. +// Must not be called directly from outside the unifier. +func (u *unifier) nify(x, y types.Type, mode unifyMode) (result bool) { + u.depth++ + defer func() { + u.depth-- + }() + + // nothing to do if x == y + if x == y || types.Unalias(x) == types.Unalias(y) { + return true + } + + // Stop gap for cases where unification fails. + if u.depth > unificationDepthLimit { + if panicAtUnificationDepthLimit { + panic("unification reached recursion depth limit") + } + return false + } + + // Unification is symmetric, so we can swap the operands. + // Ensure that if we have at least one + // - defined type, make sure one is in y + // - type parameter recorded with u, make sure one is in x + if asNamed(x) != nil || u.asBoundTypeParam(y) != nil { + x, y = y, x + } + + // Unification will fail if we match a defined type against a type literal. + // If we are matching types in an assignment, at the top-level, types with + // the same type structure are permitted as long as at least one of them + // is not a defined type. To accommodate for that possibility, we continue + // unification with the underlying type of a defined type if the other type + // is a type literal. This is controlled by the exact unification mode. + // We also continue if the other type is a basic type because basic types + // are valid underlying types and may appear as core types of type constraints. + // If we exclude them, inferred defined types for type parameters may not + // match against the core types of their constraints (even though they might + // correctly match against some of the types in the constraint's type set). + // Finally, if unification (incorrectly) succeeds by matching the underlying + // type of a defined type against a basic type (because we include basic types + // as type literals here), and if that leads to an incorrectly inferred type, + // we will fail at function instantiation or argument assignment time. + // + // If we have at least one defined type, there is one in y. + if ny := asNamed(y); mode&unifyModeExact == 0 && ny != nil && isTypeLit(x) { + y = ny.Underlying() + // Per the spec, a defined type cannot have an underlying type + // that is a type parameter. + // x and y may be identical now + if x == y || types.Unalias(x) == types.Unalias(y) { + return true + } + } + + // Cases where at least one of x or y is a type parameter recorded with u. + // If we have at least one type parameter, there is one in x. + // If we have exactly one type parameter, because it is in x, + // isTypeLit(x) is false and y was not changed above. In other + // words, if y was a defined type, it is still a defined type + // (relevant for the logic below). + switch px, py := u.asBoundTypeParam(x), u.asBoundTypeParam(y); { + case px != nil && py != nil: + // both x and y are type parameters + if u.join(px, py) { + return true + } + // both x and y have an inferred type - they must match + return u.nify(u.at(px), u.at(py), mode) + + case px != nil: + // x is a type parameter, y is not + if x := u.at(px); x != nil { + // x has an inferred type which must match y + if u.nify(x, y, mode) { + // We have a match, possibly through underlying types. + xi := asInterface(x) + yi := asInterface(y) + xn := asNamed(x) != nil + yn := asNamed(y) != nil + // If we have two interfaces, what to do depends on + // whether they are named and their method sets. + if xi != nil && yi != nil { + // Both types are interfaces. + // If both types are defined types, they must be identical + // because unification doesn't know which type has the "right" name. + if xn && yn { + return types.Identical(x, y) + } + return false + // Below is the original code for reference + + // In all other cases, the method sets must match. + // The types unified so we know that corresponding methods + // match and we can simply compare the number of methods. + // TODO(gri) We may be able to relax this rule and select + // the more general interface. But if one of them is a defined + // type, it's not clear how to choose and whether we introduce + // an order dependency or not. Requiring the same method set + // is conservative. + // if len(xi.typeSet().methods) != len(yi.typeSet().methods) { + // return false + // } + } else if xi != nil || yi != nil { + // One but not both of them are interfaces. + // In this case, either x or y could be viable matches for the corresponding + // type parameter, which means choosing either introduces an order dependence. + // Therefore, we must fail unification (go.dev/issue/60933). + return false + } + // If we have inexact unification and one of x or y is a defined type, select the + // defined type. This ensures that in a series of types, all matching against the + // same type parameter, we infer a defined type if there is one, independent of + // order. Type inference or assignment may fail, which is ok. + // Selecting a defined type, if any, ensures that we don't lose the type name; + // and since we have inexact unification, a value of equally named or matching + // undefined type remains assignable (go.dev/issue/43056). + // + // Similarly, if we have inexact unification and there are no defined types but + // channel types, select a directed channel, if any. This ensures that in a series + // of unnamed types, all matching against the same type parameter, we infer the + // directed channel if there is one, independent of order. + // Selecting a directional channel, if any, ensures that a value of another + // inexactly unifying channel type remains assignable (go.dev/issue/62157). + // + // If we have multiple defined channel types, they are either identical or we + // have assignment conflicts, so we can ignore directionality in this case. + // + // If we have defined and literal channel types, a defined type wins to avoid + // order dependencies. + if mode&unifyModeExact == 0 { + switch { + case xn: + // x is a defined type: nothing to do. + case yn: + // x is not a defined type and y is a defined type: select y. + u.set(px, y) + default: + // Neither x nor y are defined types. + if yc, _ := y.Underlying().(*types.Chan); yc != nil && yc.Dir() != types.SendRecv { + // y is a directed channel type: select y. + u.set(px, y) + } + } + } + return true + } + return false + } + // otherwise, infer type from y + u.set(px, y) + return true + } + + // If u.EnableInterfaceInference is set and we don't require exact unification, + // if both types are interfaces, one interface must have a subset of the + // methods of the other and corresponding method signatures must unify. + // If only one type is an interface, all its methods must be present in the + // other type and corresponding method signatures must unify. + + // Unless we have exact unification, neither x nor y are interfaces now. + // Except for unbound type parameters (see below), x and y must be structurally + // equivalent to unify. + + // If we get here and x or y is a type parameter, they are unbound + // (not recorded with the unifier). + // Ensure that if we have at least one type parameter, it is in x + // (the earlier swap checks for _recorded_ type parameters only). + // This ensures that the switch switches on the type parameter. + // + // TODO(gri) Factor out type parameter handling from the switch. + if isTypeParam(y) { + x, y = y, x + } + + // Type elements (array, slice, etc. elements) use emode for unification. + // Element types must match exactly if the types are used in an assignment. + emode := mode + if mode&unifyModeAssign != 0 { + emode |= unifyModeExact + } + + // Continue with unaliased types but don't lose original alias names, if any (go.dev/issue/67628). + xorig, x := x, types.Unalias(x) + yorig, y := y, types.Unalias(y) + + switch x := x.(type) { + case *types.Basic: + // Basic types are singletons except for the rune and byte + // aliases, thus we cannot solely rely on the x == y check + // above. See also comment in TypeName.IsAlias. + if y, ok := y.(*types.Basic); ok { + return x.Kind() == y.Kind() + } + + case *types.Array: + // Two array types unify if they have the same array length + // and their element types unify. + if y, ok := y.(*types.Array); ok { + // If one or both array lengths are unknown (< 0) due to some error, + // assume they are the same to avoid spurious follow-on errors. + return (x.Len() < 0 || y.Len() < 0 || x.Len() == y.Len()) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Slice: + // Two slice types unify if their element types unify. + if y, ok := y.(*types.Slice); ok { + return u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Struct: + // Two struct types unify if they have the same sequence of fields, + // and if corresponding fields have the same names, their (field) types unify, + // and they have identical tags. Two embedded fields are considered to have the same + // name. Lower-case field names from different packages are always different. + if y, ok := y.(*types.Struct); ok { + if x.NumFields() == y.NumFields() { + for i := range x.NumFields() { + f := x.Field(i) + g := y.Field(i) + if f.Embedded() != g.Embedded() || + x.Tag(i) != y.Tag(i) || + !sameId(f, g.Pkg(), g.Name(), false) || + !u.nify(f.Type(), g.Type(), emode) { + return false + } + } + return true + } + } + + case *types.Pointer: + // Two pointer types unify if their base types unify. + if y, ok := y.(*types.Pointer); ok { + return u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Tuple: + // Two tuples types unify if they have the same number of elements + // and the types of corresponding elements unify. + if y, ok := y.(*types.Tuple); ok { + if x.Len() == y.Len() { + if x != nil { + for i := range x.Len() { + v := x.At(i) + w := y.At(i) + if !u.nify(v.Type(), w.Type(), mode) { + return false + } + } + } + return true + } + } + + case *types.Signature: + // Two function types unify if they have the same number of parameters + // and result values, corresponding parameter and result types unify, + // and either both functions are variadic or neither is. + // Parameter and result names are not required to match. + // TODO(gri) handle type parameters or document why we can ignore them. + if y, ok := y.(*types.Signature); ok { + return x.Variadic() == y.Variadic() && + u.nify(x.Params(), y.Params(), emode) && + u.nify(x.Results(), y.Results(), emode) + } + + case *types.Interface: + return false + // Below is the original code + + // Two interface types unify if they have the same set of methods with + // the same names, and corresponding function types unify. + // Lower-case method names from different packages are always different. + // The order of the methods is irrelevant. + // xset := x.typeSet() + // yset := y.typeSet() + // if xset.comparable != yset.comparable { + // return false + // } + // if !xset.terms.equal(yset.terms) { + // return false + // } + // a := xset.methods + // b := yset.methods + // if len(a) == len(b) { + // // Interface types are the only types where cycles can occur + // // that are not "terminated" via named types; and such cycles + // // can only be created via method parameter types that are + // // anonymous interfaces (directly or indirectly) embedding + // // the current interface. Example: + // // + // // type T interface { + // // m() interface{T} + // // } + // // + // // If two such (differently named) interfaces are compared, + // // endless recursion occurs if the cycle is not detected. + // // + // // If x and y were compared before, they must be equal + // // (if they were not, the recursion would have stopped); + // // search the ifacePair stack for the same pair. + // // + // // This is a quadratic algorithm, but in practice these stacks + // // are extremely short (bounded by the nesting depth of interface + // // type declarations that recur via parameter types, an extremely + // // rare occurrence). An alternative implementation might use a + // // "visited" map, but that is probably less efficient overall. + // q := &ifacePair{x, y, p} + // for p != nil { + // if p.identical(q) { + // return true // same pair was compared before + // } + // p = p.prev + // } + // if debug { + // assertSortedMethods(a) + // assertSortedMethods(b) + // } + // for i, f := range a { + // g := b[i] + // if f.Id() != g.Id() || !u.nify(f.typ, g.typ, exact, q) { + // return false + // } + // } + // return true + // } + + case *types.Map: + // Two map types unify if their key and value types unify. + if y, ok := y.(*types.Map); ok { + return u.nify(x.Key(), y.Key(), emode) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Chan: + // Two channel types unify if their value types unify + // and if they have the same direction. + // The channel direction is ignored for inexact unification. + if y, ok := y.(*types.Chan); ok { + return (mode&unifyModeExact == 0 || x.Dir() == y.Dir()) && u.nify(x.Elem(), y.Elem(), emode) + } + + case *types.Named: + // Two named types unify if their type names originate in the same type declaration. + // If they are instantiated, their type argument lists must unify. + if y := asNamed(y); y != nil { + // Check type arguments before origins so they unify + // even if the origins don't match; for better error + // messages (see go.dev/issue/53692). + xargs := x.TypeArgs() + yargs := y.TypeArgs() + if xargs.Len() != yargs.Len() { + return false + } + for i := range xargs.Len() { + xarg := xargs.At(i) + yarg := yargs.At(i) + if !u.nify(xarg, yarg, mode) { + return false + } + } + return identicalOrigin(x, y) + } + + case *types.TypeParam: + // By definition, a valid type argument must be in the type set of + // the respective type constraint. Therefore, the type argument's + // underlying type must be in the set of underlying types of that + // constraint. If there is a single such underlying type, it's the + // constraint's core type. It must match the type argument's under- + // lying type, irrespective of whether the actual type argument, + // which may be a defined type, is actually in the type set (that + // will be determined at instantiation time). + // Thus, if we have the core type of an unbound type parameter, + // we know the structure of the possible types satisfying such + // parameters. Use that core type for further unification + // (see go.dev/issue/50755 for a test case). + if enableCoreTypeUnification { + // Because the core type is always an underlying type, + // unification will take care of matching against a + // defined or literal type automatically. + // If y is also an unbound type parameter, we will end + // up here again with x and y swapped, so we don't + // need to take care of that case separately. + if cx := coreType(x); cx != nil { + // If y is a defined type, it may not match against cx which + // is an underlying type (incl. int, string, etc.). Use assign + // mode here so that the unifier automatically takes under(y) + // if necessary. + return u.nify(cx, yorig, unifyModeAssign) + } + } + // x != y and there's nothing to do + + case nil: + // avoid a crash in case of nil type + + default: + panic(fmt.Sprintf("u.nify(%s, %s, %d)", xorig, yorig, mode)) + } + + return false +} diff --git a/gopls/internal/test/integration/completion/completion_test.go b/gopls/internal/test/integration/completion/completion_test.go index c96e569f1ad..19ee5d3ef68 100644 --- a/gopls/internal/test/integration/completion/completion_test.go +++ b/gopls/internal/test/integration/completion/completion_test.go @@ -970,6 +970,277 @@ use ./missing/ }) } +const reverseInferenceSrcPrelude = ` +-- go.mod -- +module mod.com + +go 1.18 +-- a.go -- +package a + +type InterfaceA interface { + implA() +} + +type InterfaceB interface { + implB() +} + + +type TypeA struct{} + +func (TypeA) implA() {} + +type TypeX string + +func (TypeX) implB() {} + +type TypeB struct{} + +func (TypeB) implB() {} + +type TypeC struct{} // should have no impact + +type Wrap[T any] struct { + inner *T +} + +func NewWrap[T any](x T) Wrap[T] { + return Wrap[T]{inner: &x} +} + +func DoubleWrap[T any, U any](t T, u U) (Wrap[T], Wrap[U]) { + return Wrap[T]{inner: &t}, Wrap[U]{inner: &u} +} + +func IntWrap[T int32 | int64](x T) Wrap[T] { + return Wrap[T]{inner: &x} +} + +var ia InterfaceA +var ib InterfaceB + +var avar TypeA +var bvar TypeB + +var i int +var i32 int32 +var i64 int64 +` + +func TestReverseInferCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var _ Wrap[int64] = IntWrap() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"i64", "i", "i32", "int64()"} + + // only check the prefix due to formatting differences with escaped characters + wantText := []string{"i64", "int64(i", "int64(i32", "int64("} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + + if insertText, ok := item.TextEdit.Value.(protocol.InsertReplaceEdit); ok { + if diff := cmp.Diff(wantText[i], insertText.NewText[:len(wantText[i])]); diff != "" { + t.Errorf("Completion: unexpected insertText mismatch (checks prefix only) (-want +got):\n%s", diff) + } + } + } + }) +} + +func TestInterfaceReverseInferCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + wb = NewWrap() // wb is of type Wrap[InterfaceB] + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"ib", "bvar", "wb.inner", "TypeB{}", "TypeX()", "nil"} + + // only check the prefix due to formatting differences with escaped characters + wantText := []string{"ib", "InterfaceB(", "*wb.inner", "InterfaceB(", "InterfaceB(", "nil"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + + if insertText, ok := item.TextEdit.Value.(protocol.InsertReplaceEdit); ok { + if diff := cmp.Diff(wantText[i], insertText.NewText[:len(wantText[i])]); diff != "" { + t.Errorf("Completion: unexpected insertText mismatch (checks prefix only) (-want +got):\n%s", diff) + } + } + } + }) +} + +func TestInvalidReverseInferenceDefaultsToConstraintCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + // This is ambiguous, so default to the constraint rather the inference. + wa = IntWrap() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\(()\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + + wantLabel := []string{"i32", "i64", "nil"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestInterfaceReverseInferTypeParamCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + wb = NewWrap[]() + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `NewWrap\[()\]\(\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + want := []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestInvalidReverseInferenceTypeParamDefaultsToConstraintCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + // This is ambiguous, so default to the constraint rather the inference. + wb = IntWrap[]() + } + ` + + Run(t, src, func(t *testing.T, env *Env) { + compl := env.RegexpSearch("a.go", `IntWrap\[()\]\(\)`) + + env.OpenFile("a.go") + result := env.Completion(compl) + want := []string{"int32", "int64"} + for i, item := range result.Items[:len(want)] { + if diff := cmp.Diff(want[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestReverseInferDoubleTypeParamCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func main() { + var wa Wrap[InterfaceA] + var wb Wrap[InterfaceB] + + wa, wb = DoubleWrap[]() + // _ is necessary to trick the parser into an index list expression + wa, wb = DoubleWrap[InterfaceA, _]() + } + ` + Run(t, src, func(t *testing.T, env *Env) { + env.OpenFile("a.go") + + compl := env.RegexpSearch("a.go", `DoubleWrap\[()\]\(\)`) + result := env.Completion(compl) + + wantLabel := []string{"InterfaceA", "TypeA", "InterfaceB", "TypeB", "TypeC"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + + compl = env.RegexpSearch("a.go", `DoubleWrap\[InterfaceA, (_)\]\(\)`) + result = env.Completion(compl) + + wantLabel = []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + +func TestDoubleParamReturnCompletion(t *testing.T) { + src := reverseInferenceSrcPrelude + ` + func concrete() (Wrap[InterfaceA], Wrap[InterfaceB]) { + return DoubleWrap[]() + } + + func concrete2() (Wrap[InterfaceA], Wrap[InterfaceB]) { + return DoubleWrap[InterfaceA, _]() + } + + func main() {} + ` + + Run(t, src, func(t *testing.T, env *Env) { + env.OpenFile("a.go") + + compl := env.RegexpSearch("a.go", `DoubleWrap\[()\]\(\)`) + result := env.Completion(compl) + + wantLabel := []string{"InterfaceA", "TypeA", "InterfaceB", "TypeB", "TypeC"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + + compl = env.RegexpSearch("a.go", `DoubleWrap\[InterfaceA, (_)\]\(\)`) + result = env.Completion(compl) + + wantLabel = []string{"InterfaceB", "TypeB", "TypeX", "InterfaceA", "TypeA"} + + for i, item := range result.Items[:len(wantLabel)] { + if diff := cmp.Diff(wantLabel[i], item.Label); diff != "" { + t.Errorf("Completion: unexpected label mismatch (-want +got):\n%s", diff) + } + } + }) +} + func TestBuiltinCompletion(t *testing.T) { const files = ` -- go.mod --