Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions _generated/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct {
A A
B B
}

// GenericTest4 has the msgp.RTFor constraint as a sub-constraint.
type GenericTest4[T any, P interface {
*T
msgp.RTFor[T]
}] struct {
Totals [60]T
}
59 changes: 49 additions & 10 deletions parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,8 @@ func formatTypeParams(params *ast.FieldList) string {

var paramStrs []string
for _, field := range params.List {
str := stringify(field.Type)
// Convert underscores to _RTn where n is the number of the parameter
convert := strings.HasPrefix(str, "msgp.RTFor[")
convert := isrtfor(field.Type)

// Each field can have multiple names (e.g., T, U constraint)
for _, name := range field.Names {
Expand All @@ -247,6 +246,31 @@ func formatTypeParams(params *ast.FieldList) string {
return "[" + strings.Join(paramStrs, ", ") + "]"
}

Comment thread
klauspost marked this conversation as resolved.
// isrtfor returns whether the provided expression is a msgp.RTFor[T] pattern.
func isrtfor(t ast.Expr) bool { return strings.HasPrefix(stringify(t), "msgp.RTFor[") }

// findRTForInInterface recursively searches for msgp.RTFor[T] patterns within interface types
func findRTForInInterface(iface *ast.InterfaceType) []string {
var rtfors []string
if iface.Methods == nil {
return rtfors
}

for _, method := range iface.Methods.List {
// Check if this is an embedded interface/type
if len(method.Names) == 0 {
if isrtfor(method.Type) {
rtfors = append(rtfors, stringify(method.Type))
}
// Recursively check nested interfaces
if nestedIface, ok := method.Type.(*ast.InterfaceType); ok {
rtfors = append(rtfors, findRTForInInterface(nestedIface)...)
}
}
}
return rtfors
}

// formatTypeParams converts an AST FieldList to a string representation.
// For 'Foo[T any, P msgp.RTFor[T]]' will return {"T": "P"}.
func getMspTypeParams(params *ast.FieldList) map[string]string {
Expand All @@ -256,16 +280,31 @@ func getMspTypeParams(params *ast.FieldList) map[string]string {

paramStrs := make(map[string]string)
for _, field := range params.List {
str := stringify(field.Type)
if !strings.HasPrefix(str, "msgp.RTFor[") {

// Handle simple msgp.RTFor[T] constraints
if isrtfor(field.Type) {
t := strings.TrimSuffix(strings.TrimPrefix(stringify(field.Type), "msgp.RTFor["), "]")
for _, name := range field.Names {
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s\n", t, name.Name)
}
continue
}
for _, name := range field.Names {
t := strings.TrimSuffix(strings.TrimPrefix(str, "msgp.RTFor["), "]")
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s\n", t, name.Name)

// Handle complex interface constraints that embed msgp.RTFor[T]
if iface, ok := field.Type.(*ast.InterfaceType); ok {
rtfors := findRTForInInterface(iface)
for _, rtfor := range rtfors {
t := strings.TrimSuffix(strings.TrimPrefix(rtfor, "msgp.RTFor["), "]")
for _, name := range field.Names {
paramStrs[t] = name.Name + "(&%s)"
paramStrs["*"+t] = name.Name + "(%s)"
paramStrs[name.Name] = "%s"
infof("found generic type %s, with roundtrippper %s (in complex interface)\n", t, name.Name)
}
}
}
}

Expand Down
Loading