diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go index b798f2c888e8f3..33db615cfb27d6 100644 --- a/src/cmd/compile/internal/types2/api.go +++ b/src/cmd/compile/internal/types2/api.go @@ -169,6 +169,11 @@ type Config struct { // If DisableUnusedImportCheck is set, packages are not checked // for unused imports. DisableUnusedImportCheck bool + + // If EnableInterfaceInference is set, type inference uses + // shared methods for improved type inference involving + // interfaces. + EnableInterfaceInference bool } func srcimporter_setUsesCgo(conf *Config) { diff --git a/src/cmd/compile/internal/types2/check_test.go b/src/cmd/compile/internal/types2/check_test.go index b149ae3908907b..c1b4a20624a27c 100644 --- a/src/cmd/compile/internal/types2/check_test.go +++ b/src/cmd/compile/internal/types2/check_test.go @@ -126,6 +126,7 @@ func testFiles(t *testing.T, filenames []string, srcs [][]byte, colDelta uint, m flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(&conf.EnableInterfaceInference, "EnableInterfaceInference", false, "") if err := parseFlags(srcs[0], flags); err != nil { t.Fatal(err) } diff --git a/src/cmd/compile/internal/types2/infer.go b/src/cmd/compile/internal/types2/infer.go index 097e9c7ddb6722..9ea488761aaf43 100644 --- a/src/cmd/compile/internal/types2/infer.go +++ b/src/cmd/compile/internal/types2/infer.go @@ -96,7 +96,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, // Unify parameter and argument types for generic parameters with typed arguments // and collect the indices of generic parameters with untyped arguments. // Terminology: generic parameter = function parameter with a type-parameterized type - u := newUnifier(tparams, targs) + u := check.newUnifier(tparams, targs) errorf := func(kind string, tpar, targ Type, arg *operand) { // provide a better error message if we can diff --git a/src/cmd/compile/internal/types2/unify.go b/src/cmd/compile/internal/types2/unify.go index 997f3556640870..87c845f12cc0fb 100644 --- a/src/cmd/compile/internal/types2/unify.go +++ b/src/cmd/compile/internal/types2/unify.go @@ -67,6 +67,7 @@ const ( // corresponding types inferred for each type parameter. // A unifier is created by calling newUnifier. type unifier struct { + check *Checker // handles maps each type parameter to its inferred type through // an indirection *Type called (inferred type) "handle". // Initially, each type parameter has its own, separate handle, @@ -84,7 +85,7 @@ type unifier struct { // and corresponding type argument lists. The type argument list may be shorter // than the type parameter list, and it may contain nil types. Matching type // parameters and arguments must have the same index. -func newUnifier(tparams []*TypeParam, targs []Type) *unifier { +func (check *Checker) newUnifier(tparams []*TypeParam, targs []Type) *unifier { assert(len(tparams) >= len(targs)) handles := make(map[*TypeParam]*Type, len(tparams)) // Allocate all handles up-front: in a correct program, all type parameters @@ -98,7 +99,7 @@ func newUnifier(tparams []*TypeParam, targs []Type) *unifier { } handles[x] = &t } - return &unifier{handles, 0} + return &unifier{check, handles, 0} } // unify attempts to unify x and y and reports whether it succeeded. @@ -280,7 +281,9 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // the same type structure are permitted as long as at least one of them // is not a defined type. To accommodate for that possibility, we continue // unification with the underlying type of a defined type if the other type - // is a type literal. + // is a type literal. However, if the type literal is an interface and we + // set EnableInterfaceInference, we continue with the defined type because + // otherwise we may lose its methods. // We also continue if the other type is a basic type because basic types // are valid underlying types and may appear as core types of type constraints. // If we exclude them, inferred defined types for type parameters may not @@ -292,7 +295,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // we will fail at function instantiation or argument assignment time. // // If we have at least one defined type, there is one in y. - if ny, _ := y.(*Named); ny != nil && isTypeLit(x) { + if ny, _ := y.(*Named); ny != nil && isTypeLit(x) && !(u.check.conf.EnableInterfaceInference && IsInterface(x)) { if traceInference { u.tracef("%s ≡ under %s", x, ny) } @@ -356,6 +359,104 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { x, y = y, x } + // If EnableInterfaceInference is set and both types are interfaces, one + // interface must have a subset of the methods of the other and corresponding + // method signatures must unify. + // If only one type is an interface, all its methods must be present in the + // other type and corresponding method signatures must unify. + if u.check.conf.EnableInterfaceInference { + xi, _ := x.(*Interface) + yi, _ := y.(*Interface) + // If we have two interfaces, check the type terms for equivalence, + // and unify common methods if possible. + if xi != nil && yi != nil { + xset := xi.typeSet() + yset := yi.typeSet() + if xset.comparable != yset.comparable { + return false + } + // For now we require terms to be equal. + // We should be able to relax this as well, eventually. + if !xset.terms.equal(yset.terms) { + return false + } + // Interface types are the only types where cycles can occur + // that are not "terminated" via named types; and such cycles + // can only be created via method parameter types that are + // anonymous interfaces (directly or indirectly) embedding + // the current interface. Example: + // + // type T interface { + // m() interface{T} + // } + // + // If two such (differently named) interfaces are compared, + // endless recursion occurs if the cycle is not detected. + // + // If x and y were compared before, they must be equal + // (if they were not, the recursion would have stopped); + // search the ifacePair stack for the same pair. + // + // This is a quadratic algorithm, but in practice these stacks + // are extremely short (bounded by the nesting depth of interface + // type declarations that recur via parameter types, an extremely + // rare occurrence). An alternative implementation might use a + // "visited" map, but that is probably less efficient overall. + q := &ifacePair{xi, yi, p} + for p != nil { + if p.identical(q) { + return true // same pair was compared before + } + p = p.prev + } + // The method set of x must be a subset of the method set + // of y or vice versa, and the common methods must unify. + xmethods := xset.methods + ymethods := yset.methods + // The smaller method set must be the subset, if it exists. + if len(xmethods) > len(ymethods) { + xmethods, ymethods = ymethods, xmethods + } + // len(xmethods) <= len(ymethods) + // Collect the ymethods in a map for quick lookup. + ymap := make(map[string]*Func, len(ymethods)) + for _, ym := range ymethods { + ymap[ym.Id()] = ym + } + // All xmethods must exist in ymethods and corresponding signatures must unify. + for _, xm := range xmethods { + if ym := ymap[xm.Id()]; ym == nil || !u.nify(xm.typ, ym.typ, p) { + return false + } + } + return true + } + + // We don't have two interfaces. If we have one, make sure it's in xi. + if yi != nil { + xi = yi + y = x + } + + // If we have one interface, at a minimum each of the interface methods + // must be implemented and thus unify with a corresponding method from + // the non-interface type, otherwise unification fails. + if xi != nil { + // All xi methods must exist in y and corresponding signatures must unify. + xmethods := xi.typeSet().methods + for _, xm := range xmethods { + obj, _, _ := LookupFieldOrMethod(y, false, xm.pkg, xm.name) + if ym, _ := obj.(*Func); ym == nil || !u.nify(xm.typ, ym.typ, p) { + return false + } + } + return true + } + + // Neither x nor y are interface types. + // They must be structurally equivalent to unify. + } + switch x := x.(type) { case *Basic: // Basic types are singletons except for the rune and byte @@ -436,6 +537,8 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } case *Interface: + assert(!u.check.conf.EnableInterfaceInference) // handled before this switch + // Two interface types unify if they have the same set of methods with // the same names, and corresponding function types unify. // Lower-case method names from different packages are always different. @@ -508,10 +611,49 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } case *Named: - // Two named types unify if their type names originate + // Two defined types unify if their type names originate // in the same type declaration. If they are instantiated, // their type argument lists must unify. if y, ok := y.(*Named); ok { + sameOrig := indenticalOrigin(x, y) + if u.check.conf.EnableInterfaceInference { + xu := x.under() + yu := y.under() + xi, _ := xu.(*Interface) + yi, _ := yu.(*Interface) + // If one or both defined types are interfaces, use interface unification, + // unless they originated in the same type declaration. + if xi != nil && yi != nil { + // If both interfaces originate in the same declaration, + // their methods unify if the type parameters unify. + // Unify the type parameters rather than the methods in + // case the type parameters are not used in the methods + // (and to preserve existing behavior in this case). + if sameOrig { + xargs := x.TypeArgs().list() + yargs := y.TypeArgs().list() + assert(len(xargs) == len(yargs)) + for i, xarg := range xargs { + if !u.nify(xarg, yargs[i], p) { + return false + } + } + return true + } + return u.nify(xu, yu, p) + } + // We don't have two interfaces. If we have one, make sure it's in xi. + if yi != nil { + xi = yi + y = x + } + // If xi is an interface, use interface unification. + if xi != nil { + return u.nify(xi, y, p) + } + // In all other cases, the type arguments and origins must match. + } + // Check type arguments before origins so they unify // even if the origins don't match; for better error // messages (see go.dev/issue/53692). @@ -525,7 +667,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { return false } } - return indenticalOrigin(x, y) + return sameOrig } case *TypeParam: diff --git a/src/go/types/api.go b/src/go/types/api.go index 08430c9e7a5c85..a68e0ea16c1c30 100644 --- a/src/go/types/api.go +++ b/src/go/types/api.go @@ -170,6 +170,11 @@ type Config struct { // If DisableUnusedImportCheck is set, packages are not checked // for unused imports. DisableUnusedImportCheck bool + + // If _EnableInterfaceInference is set, type inference uses + // shared methods for improved type inference involving + // interfaces. + _EnableInterfaceInference bool } func srcimporter_setUsesCgo(conf *Config) { diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go index 73ac80235cce01..a0b9f54dbfc397 100644 --- a/src/go/types/check_test.go +++ b/src/go/types/check_test.go @@ -137,6 +137,7 @@ func testFiles(t *testing.T, filenames []string, srcs [][]byte, manual bool, opt flags := flag.NewFlagSet("", flag.PanicOnError) flags.StringVar(&conf.GoVersion, "lang", "", "") flags.BoolVar(&conf.FakeImportC, "fakeImportC", false, "") + flags.BoolVar(boolFieldAddr(&conf, "_EnableInterfaceInference"), "EnableInterfaceInference", false, "") if err := parseFlags(srcs[0], flags); err != nil { t.Fatal(err) } diff --git a/src/go/types/generate_test.go b/src/go/types/generate_test.go index 6a8343c615355c..90711d83285d9e 100644 --- a/src/go/types/generate_test.go +++ b/src/go/types/generate_test.go @@ -137,10 +137,13 @@ var filemap = map[string]action{ "typeterm_test.go": nil, "typeterm.go": nil, "under.go": nil, - "unify.go": fixSprintf, - "universe.go": fixGlobalTypVarDecl, - "util_test.go": fixTokenPos, - "validtype.go": nil, + "unify.go": func(f *ast.File) { + fixSprintf(f) + renameIdent(f, "EnableInterfaceInference", "_EnableInterfaceInference") + }, + "universe.go": fixGlobalTypVarDecl, + "util_test.go": fixTokenPos, + "validtype.go": nil, } // TODO(gri) We should be able to make these rewriters more configurable/composable. diff --git a/src/go/types/infer.go b/src/go/types/infer.go index ae1c2af1e424a9..1f09697d2e1c8f 100644 --- a/src/go/types/infer.go +++ b/src/go/types/infer.go @@ -98,7 +98,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, // Unify parameter and argument types for generic parameters with typed arguments // and collect the indices of generic parameters with untyped arguments. // Terminology: generic parameter = function parameter with a type-parameterized type - u := newUnifier(tparams, targs) + u := check.newUnifier(tparams, targs) errorf := func(kind string, tpar, targ Type, arg *operand) { // provide a better error message if we can diff --git a/src/go/types/unify.go b/src/go/types/unify.go index 484c7adeb3a227..757a0932d61e15 100644 --- a/src/go/types/unify.go +++ b/src/go/types/unify.go @@ -69,6 +69,7 @@ const ( // corresponding types inferred for each type parameter. // A unifier is created by calling newUnifier. type unifier struct { + check *Checker // handles maps each type parameter to its inferred type through // an indirection *Type called (inferred type) "handle". // Initially, each type parameter has its own, separate handle, @@ -86,7 +87,7 @@ type unifier struct { // and corresponding type argument lists. The type argument list may be shorter // than the type parameter list, and it may contain nil types. Matching type // parameters and arguments must have the same index. -func newUnifier(tparams []*TypeParam, targs []Type) *unifier { +func (check *Checker) newUnifier(tparams []*TypeParam, targs []Type) *unifier { assert(len(tparams) >= len(targs)) handles := make(map[*TypeParam]*Type, len(tparams)) // Allocate all handles up-front: in a correct program, all type parameters @@ -100,7 +101,7 @@ func newUnifier(tparams []*TypeParam, targs []Type) *unifier { } handles[x] = &t } - return &unifier{handles, 0} + return &unifier{check, handles, 0} } // unify attempts to unify x and y and reports whether it succeeded. @@ -282,7 +283,9 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // the same type structure are permitted as long as at least one of them // is not a defined type. To accommodate for that possibility, we continue // unification with the underlying type of a defined type if the other type - // is a type literal. + // is a type literal. However, if the type literal is an interface and we + // set EnableInterfaceInference, we continue with the defined type because + // otherwise we may lose its methods. // We also continue if the other type is a basic type because basic types // are valid underlying types and may appear as core types of type constraints. // If we exclude them, inferred defined types for type parameters may not @@ -294,7 +297,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { // we will fail at function instantiation or argument assignment time. // // If we have at least one defined type, there is one in y. - if ny, _ := y.(*Named); ny != nil && isTypeLit(x) { + if ny, _ := y.(*Named); ny != nil && isTypeLit(x) && !(u.check.conf._EnableInterfaceInference && IsInterface(x)) { if traceInference { u.tracef("%s ≡ under %s", x, ny) } @@ -358,6 +361,104 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { x, y = y, x } + // If EnableInterfaceInference is set and both types are interfaces, one + // interface must have a subset of the methods of the other and corresponding + // method signatures must unify. + // If only one type is an interface, all its methods must be present in the + // other type and corresponding method signatures must unify. + if u.check.conf._EnableInterfaceInference { + xi, _ := x.(*Interface) + yi, _ := y.(*Interface) + // If we have two interfaces, check the type terms for equivalence, + // and unify common methods if possible. + if xi != nil && yi != nil { + xset := xi.typeSet() + yset := yi.typeSet() + if xset.comparable != yset.comparable { + return false + } + // For now we require terms to be equal. + // We should be able to relax this as well, eventually. + if !xset.terms.equal(yset.terms) { + return false + } + // Interface types are the only types where cycles can occur + // that are not "terminated" via named types; and such cycles + // can only be created via method parameter types that are + // anonymous interfaces (directly or indirectly) embedding + // the current interface. Example: + // + // type T interface { + // m() interface{T} + // } + // + // If two such (differently named) interfaces are compared, + // endless recursion occurs if the cycle is not detected. + // + // If x and y were compared before, they must be equal + // (if they were not, the recursion would have stopped); + // search the ifacePair stack for the same pair. + // + // This is a quadratic algorithm, but in practice these stacks + // are extremely short (bounded by the nesting depth of interface + // type declarations that recur via parameter types, an extremely + // rare occurrence). An alternative implementation might use a + // "visited" map, but that is probably less efficient overall. + q := &ifacePair{xi, yi, p} + for p != nil { + if p.identical(q) { + return true // same pair was compared before + } + p = p.prev + } + // The method set of x must be a subset of the method set + // of y or vice versa, and the common methods must unify. + xmethods := xset.methods + ymethods := yset.methods + // The smaller method set must be the subset, if it exists. + if len(xmethods) > len(ymethods) { + xmethods, ymethods = ymethods, xmethods + } + // len(xmethods) <= len(ymethods) + // Collect the ymethods in a map for quick lookup. + ymap := make(map[string]*Func, len(ymethods)) + for _, ym := range ymethods { + ymap[ym.Id()] = ym + } + // All xmethods must exist in ymethods and corresponding signatures must unify. + for _, xm := range xmethods { + if ym := ymap[xm.Id()]; ym == nil || !u.nify(xm.typ, ym.typ, p) { + return false + } + } + return true + } + + // We don't have two interfaces. If we have one, make sure it's in xi. + if yi != nil { + xi = yi + y = x + } + + // If we have one interface, at a minimum each of the interface methods + // must be implemented and thus unify with a corresponding method from + // the non-interface type, otherwise unification fails. + if xi != nil { + // All xi methods must exist in y and corresponding signatures must unify. + xmethods := xi.typeSet().methods + for _, xm := range xmethods { + obj, _, _ := LookupFieldOrMethod(y, false, xm.pkg, xm.name) + if ym, _ := obj.(*Func); ym == nil || !u.nify(xm.typ, ym.typ, p) { + return false + } + } + return true + } + + // Neither x nor y are interface types. + // They must be structurally equivalent to unify. + } + switch x := x.(type) { case *Basic: // Basic types are singletons except for the rune and byte @@ -438,6 +539,8 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } case *Interface: + assert(!u.check.conf._EnableInterfaceInference) // handled before this switch + // Two interface types unify if they have the same set of methods with // the same names, and corresponding function types unify. // Lower-case method names from different packages are always different. @@ -510,10 +613,49 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { } case *Named: - // Two named types unify if their type names originate + // Two defined types unify if their type names originate // in the same type declaration. If they are instantiated, // their type argument lists must unify. if y, ok := y.(*Named); ok { + sameOrig := indenticalOrigin(x, y) + if u.check.conf._EnableInterfaceInference { + xu := x.under() + yu := y.under() + xi, _ := xu.(*Interface) + yi, _ := yu.(*Interface) + // If one or both defined types are interfaces, use interface unification, + // unless they originated in the same type declaration. + if xi != nil && yi != nil { + // If both interfaces originate in the same declaration, + // their methods unify if the type parameters unify. + // Unify the type parameters rather than the methods in + // case the type parameters are not used in the methods + // (and to preserve existing behavior in this case). + if sameOrig { + xargs := x.TypeArgs().list() + yargs := y.TypeArgs().list() + assert(len(xargs) == len(yargs)) + for i, xarg := range xargs { + if !u.nify(xarg, yargs[i], p) { + return false + } + } + return true + } + return u.nify(xu, yu, p) + } + // We don't have two interfaces. If we have one, make sure it's in xi. + if yi != nil { + xi = yi + y = x + } + // If xi is an interface, use interface unification. + if xi != nil { + return u.nify(xi, y, p) + } + // In all other cases, the type arguments and origins must match. + } + // Check type arguments before origins so they unify // even if the origins don't match; for better error // messages (see go.dev/issue/53692). @@ -527,7 +669,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) { return false } } - return indenticalOrigin(x, y) + return sameOrig } case *TypeParam: diff --git a/src/internal/types/testdata/fixedbugs/issue41176.go b/src/internal/types/testdata/fixedbugs/issue41176.go new file mode 100644 index 00000000000000..f863880ec52389 --- /dev/null +++ b/src/internal/types/testdata/fixedbugs/issue41176.go @@ -0,0 +1,23 @@ +// -EnableInterfaceInference + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +type S struct{} + +func (S) M() byte { + return 0 +} + +type I[T any] interface { + M() T +} + +func f[T any](x I[T]) {} + +func _() { + f(S{}) +} diff --git a/src/internal/types/testdata/fixedbugs/issue57192.go b/src/internal/types/testdata/fixedbugs/issue57192.go new file mode 100644 index 00000000000000..2b2fb59f0853f4 --- /dev/null +++ b/src/internal/types/testdata/fixedbugs/issue57192.go @@ -0,0 +1,24 @@ +// -EnableInterfaceInference + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +type I1[T any] interface { + m1(T) +} +type I2[T any] interface { + I1[T] + m2(T) +} + +var V1 I1[int] +var V2 I2[int] + +func g[T any](I1[T]) {} +func _() { + g(V1) + g(V2) +}