diff --git a/_generated/generics.go b/_generated/generics.go index 6beb306d..3f402f21 100644 --- a/_generated/generics.go +++ b/_generated/generics.go @@ -82,3 +82,11 @@ type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct { A A B B } + +// GenericTest4 has the msgp.RTFor constraint as a sub-constraint. +type GenericTest4[T any, P interface { + *T + msgp.RTFor[T] +}] struct { + Totals [60]T +} diff --git a/parse/getast.go b/parse/getast.go index ec9e1531..cdc91699 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -229,9 +229,8 @@ func formatTypeParams(params *ast.FieldList) string { var paramStrs []string for _, field := range params.List { - str := stringify(field.Type) // Convert underscores to _RTn where n is the number of the parameter - convert := strings.HasPrefix(str, "msgp.RTFor[") + convert := isrtfor(field.Type) // Each field can have multiple names (e.g., T, U constraint) for _, name := range field.Names { @@ -247,6 +246,31 @@ func formatTypeParams(params *ast.FieldList) string { return "[" + strings.Join(paramStrs, ", ") + "]" } +// isrtfor returns whether the provided expression is a msgp.RTFor[T] pattern. +func isrtfor(t ast.Expr) bool { return strings.HasPrefix(stringify(t), "msgp.RTFor[") } + +// findRTForInInterface recursively searches for msgp.RTFor[T] patterns within interface types +func findRTForInInterface(iface *ast.InterfaceType) []string { + var rtfors []string + if iface.Methods == nil { + return rtfors + } + + for _, method := range iface.Methods.List { + // Check if this is an embedded interface/type + if len(method.Names) == 0 { + if isrtfor(method.Type) { + rtfors = append(rtfors, stringify(method.Type)) + } + // Recursively check nested interfaces + if nestedIface, ok := method.Type.(*ast.InterfaceType); ok { + rtfors = append(rtfors, findRTForInInterface(nestedIface)...) + } + } + } + return rtfors +} + // formatTypeParams converts an AST FieldList to a string representation. // For 'Foo[T any, P msgp.RTFor[T]]' will return {"T": "P"}. func getMspTypeParams(params *ast.FieldList) map[string]string { @@ -256,16 +280,31 @@ func getMspTypeParams(params *ast.FieldList) map[string]string { paramStrs := make(map[string]string) for _, field := range params.List { - str := stringify(field.Type) - if !strings.HasPrefix(str, "msgp.RTFor[") { + + // Handle simple msgp.RTFor[T] constraints + if isrtfor(field.Type) { + t := strings.TrimSuffix(strings.TrimPrefix(stringify(field.Type), "msgp.RTFor["), "]") + for _, name := range field.Names { + paramStrs[t] = name.Name + "(&%s)" + paramStrs["*"+t] = name.Name + "(%s)" + paramStrs[name.Name] = "%s" + infof("found generic type %s, with roundtrippper %s\n", t, name.Name) + } continue } - for _, name := range field.Names { - t := strings.TrimSuffix(strings.TrimPrefix(str, "msgp.RTFor["), "]") - paramStrs[t] = name.Name + "(&%s)" - paramStrs["*"+t] = name.Name + "(%s)" - paramStrs[name.Name] = "%s" - infof("found generic type %s, with roundtrippper %s\n", t, name.Name) + + // Handle complex interface constraints that embed msgp.RTFor[T] + if iface, ok := field.Type.(*ast.InterfaceType); ok { + rtfors := findRTForInInterface(iface) + for _, rtfor := range rtfors { + t := strings.TrimSuffix(strings.TrimPrefix(rtfor, "msgp.RTFor["), "]") + for _, name := range field.Names { + paramStrs[t] = name.Name + "(&%s)" + paramStrs["*"+t] = name.Name + "(%s)" + paramStrs[name.Name] = "%s" + infof("found generic type %s, with roundtrippper %s (in complex interface)\n", t, name.Name) + } + } } }