Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions _generated/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ type GenericTestTwo[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] struct {
GP2 map[string]*GenericTest2[B, BP, string] `msg:",allownil"`
}

type GenericTest3List[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] []GenericTest3[A, B, AP, BP]
type GenericTest3Map[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] map[string]GenericTest3[A, B, AP, BP]

type GenericTest3Array[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] [5]GenericTest3[A, B, AP, BP]

type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct {
A A
B B
Expand Down
80 changes: 80 additions & 0 deletions _generated/generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,83 @@ func TestGenericsEncode(t *testing.T) {
t.Errorf("\n got=%#v\nwant=%#v", got, x)
}
}

// Test for generic alias types (slices, maps, arrays) - these verify the type parameter fix
func TestGenericTest3List(t *testing.T) {
// Test GenericTest3List
list := GenericTest3List[Fixed, Fixed, *Fixed, *Fixed]{
{A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
{A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
}

// Test marshaling
data, err := list.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3List.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var list2 GenericTest3List[Fixed, Fixed, *Fixed, *Fixed]
_, err = list2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3List.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if len(list2) != 2 || list2[0].A.A != 1.5 || list2[0].B.A != 2.5 {
t.Errorf("GenericTest3List round-trip failed")
}
}

func TestGenericTest3Map(t *testing.T) {
// Test GenericTest3Map
testMap := GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed]{
"key1": {A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
"key2": {A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
}

// Test marshaling
data, err := testMap.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3Map.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var testMap2 GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed]
_, err = testMap2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3Map.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if len(testMap2) != 2 || testMap2["key1"].A.A != 1.5 || testMap2["key1"].B.A != 2.5 {
t.Errorf("GenericTest3Map round-trip failed")
}
}

func TestGenericTest3Array(t *testing.T) {
// Test GenericTest3Array
testArray := GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed]{
{A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
{A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
{A: Fixed{A: 5.5}, B: Fixed{A: 6.5}},
}

// Test marshaling
data, err := testArray.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3Array.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var testArray2 GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed]
_, err = testArray2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3Array.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if testArray2[0].A.A != 1.5 || testArray2[0].B.A != 2.5 || testArray2[2].A.A != 5.5 {
t.Errorf("GenericTest3Array round-trip failed")
}
}
5 changes: 3 additions & 2 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,15 @@ func (d *decodeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}

if b.Convert {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
lowered := b.ToBase() + "(" + vname + ")"
d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered)
} else {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
d.p.printf("\nerr = %s.DecodeMsg(dc)", vname)
Expand Down
57 changes: 44 additions & 13 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,42 @@ type GenericTypeParams struct {
isPtr bool
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }

// typeNameWithParams returns the type name with generic parameters appended if they exist
// stripTypeParams removes type parameters from a type name for lookup purposes
// e.g. "MyType[T, U]" becomes "MyType", "*SomeType[A]" becomes "*SomeType"
func stripTypeParams(typeName string) string {
if idx := strings.Index(typeName, "["); idx != -1 {
return typeName[:idx]
}
return typeName
}

func (c *common) typeNameWithParams(baseName string) string {
if c.typeParams.TypeParams != "" && !strings.Contains(baseName, "[") {
// Check if baseName is a single identifier without dots (likely a type parameter)
if !strings.Contains(baseName, ".") && len(baseName) <= 2 && len(baseName) > 0 {
// This looks like a simple type parameter, don't add type parameters
return baseName
}
return baseName + c.typeParams.TypeParams
}
return baseName
}

// baseTypeName returns the type name without generic parameters (for use in method receivers)
func (c *common) baseTypeName() string {
return c.alias
}
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) SetIsAllowNil(bool) {}
func (c *common) SetTypeParams(tp GenericTypeParams) { c.typeParams = tp }
func (c *common) TypeParams() GenericTypeParams { return c.typeParams }
func (c *common) BaseTypeName() string { return c.baseTypeName() }
func (c *common) AlwaysPtr(set *bool) bool {
if c != nil && set != nil {
c.ptrRcv = *set
Expand Down Expand Up @@ -245,6 +273,9 @@ type Elem interface {
// TypeParams returns the generic type parameters for this element
TypeParams() GenericTypeParams

// BaseTypeName returns the type name without generic parameters
BaseTypeName() string

hidden()
}

Expand Down Expand Up @@ -283,10 +314,10 @@ ridx:

func (a *Array) TypeName() string {
if a.alias != "" {
return a.alias
return a.typeNameWithParams(a.alias)
}
a.Alias(fmt.Sprintf("[%s]%s", a.Size, a.Els.TypeName()))
return a.alias
return a.typeNameWithParams(a.alias)
}

func (a *Array) Copy() Elem {
Expand Down Expand Up @@ -335,14 +366,14 @@ ridx:

func (m *Map) TypeName() string {
if m.alias != "" {
return m.alias
return m.typeNameWithParams(m.alias)
}
keyType := "string"
if m.Key != nil {
keyType = m.Key.TypeName()
}
m.Alias("map[" + keyType + "]" + m.Value.TypeName())
return m.alias
return m.typeNameWithParams(m.alias)
}

func (m *Map) Copy() Elem {
Expand Down Expand Up @@ -403,10 +434,10 @@ func (s *Slice) SetVarname(a string) {

func (s *Slice) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.Alias("[]" + s.Els.TypeName())
return s.alias
return s.typeNameWithParams(s.alias)
}

func (s *Slice) Copy() Elem {
Expand Down Expand Up @@ -480,10 +511,10 @@ func (s *Ptr) SetVarname(a string) {

func (s *Ptr) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.Alias("*" + s.Value.TypeName())
return s.alias
return s.typeNameWithParams(s.alias)
}

func (s *Ptr) Copy() Elem {
Expand Down Expand Up @@ -673,10 +704,10 @@ func (s *BaseElem) SetVarname(a string) {
// type name for the base element.
func (s *BaseElem) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.common.Alias(s.BaseType())
return s.alias
return s.typeNameWithParams(s.alias)
}

// ToBase, used if Convert==true, is used as tmp = {{ToBase}}({{Varname}})
Expand Down Expand Up @@ -715,7 +746,7 @@ func (s *BaseElem) BaseName() string {
func (s *BaseElem) BaseType() string {
switch s.Value {
case IDENT:
return s.TypeName()
return s.alias

// exceptions to the naming/capitalization
// rule:
Expand Down
9 changes: 8 additions & 1 deletion gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,14 @@ func (e *encodeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {

// Strip type parameters from dst for lookup in ToPointerMap
lookupKey := dst
Comment thread
klauspost marked this conversation as resolved.
Outdated
if idx := strings.Index(dst, "["); idx != -1 {
lookupKey = dst[:idx]
}

if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
e.p.printf("\nerr = %s.EncodeMsg(en)", vname)
Expand Down
2 changes: 1 addition & 1 deletion gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (m *marshalGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
echeck = true
Expand Down
9 changes: 8 additions & 1 deletion gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,14 @@ func (s *sizeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {

// Strip type parameters from dst for lookup in ToPointerMap
lookupKey := dst
Comment thread
klauspost marked this conversation as resolved.
Outdated
if idx := strings.Index(dst, "["); idx != -1 {
lookupKey = dst[:idx]
}

if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
s.addConstant(basesizeExpr(b.Value, vname, b.BaseName()))
Expand Down
4 changes: 2 additions & 2 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func next(t traversal, e Elem) {

// possibly-immutable method receiver
func imutMethodReceiver(p Elem) string {
typeName := p.TypeName()
typeName := p.BaseTypeName()
typeParams := p.TypeParams()

switch e := p.(type) {
Expand Down Expand Up @@ -300,7 +300,7 @@ func imutMethodReceiver(p Elem) string {
// so that its method receiver
// is of the write type.
func methodReceiver(p Elem) string {
typeName := p.TypeName()
typeName := p.BaseTypeName()
typeParams := p.TypeParams()

switch p.(type) {
Expand Down
3 changes: 1 addition & 2 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,9 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
lowered = fmt.Sprintf(remap, lowered)
}

u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered)
case Time:
if u.ctx.asUTC {
Expand Down
Loading