Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions _generated/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct {
A A
B B
}

// GenericTest4 has the msgp.RTFor constraint as a sub-constraint.
type GenericTest4[T any, P interface {
*T
msgp.RTFor[T]
}] struct {
Totals [60]T
}
53 changes: 46 additions & 7 deletions parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,29 @@ func formatTypeParams(params *ast.FieldList) string {
return "[" + strings.Join(paramStrs, ", ") + "]"
}

Comment thread
klauspost marked this conversation as resolved.
// findRTForInInterface recursively searches for msgp.RTFor[T] patterns within interface types
func findRTForInInterface(iface *ast.InterfaceType) []string {
var rtfors []string
if iface.Methods == nil {
return rtfors
}

for _, method := range iface.Methods.List {
// Check if this is an embedded interface/type
if len(method.Names) == 0 {
typeStr := stringify(method.Type)
if strings.HasPrefix(typeStr, "msgp.RTFor[") {
rtfors = append(rtfors, typeStr)
}
// Recursively check nested interfaces
if nestedIface, ok := method.Type.(*ast.InterfaceType); ok {
rtfors = append(rtfors, findRTForInInterface(nestedIface)...)
}
}
}
return rtfors
}

// formatTypeParams converts an AST FieldList to a string representation.
// For 'Foo[T any, P msgp.RTFor[T]]' will return {"T": "P"}.
func getMspTypeParams(params *ast.FieldList) map[string]string {
Expand All @@ -257,15 +280,31 @@ func getMspTypeParams(params *ast.FieldList) map[string]string {
paramStrs := make(map[string]string)
for _, field := range params.List {
str := stringify(field.Type)
if !strings.HasPrefix(str, "msgp.RTFor[") {

// Handle simple msgp.RTFor[T] constraints
if strings.HasPrefix(str, "msgp.RTFor[") {
for _, name := range field.Names {
t := strings.TrimSuffix(strings.TrimPrefix(str, "msgp.RTFor["), "]")
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s\n", t, name.Name)
}
continue
}
for _, name := range field.Names {
t := strings.TrimSuffix(strings.TrimPrefix(str, "msgp.RTFor["), "]")
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s\n", t, name.Name)

// Handle complex interface constraints that embed msgp.RTFor[T]
if iface, ok := field.Type.(*ast.InterfaceType); ok {
rtfors := findRTForInInterface(iface)
for _, rtfor := range rtfors {
for _, name := range field.Names {
t := strings.TrimSuffix(strings.TrimPrefix(rtfor, "msgp.RTFor["), "]")
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s (in complex interface)\n", t, name.Name)
}
}
}
}

Expand Down
Loading