diff --git a/_generated/generics.go b/_generated/generics.go index 3f402f21..fa480d39 100644 --- a/_generated/generics.go +++ b/_generated/generics.go @@ -78,6 +78,11 @@ type GenericTestTwo[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] struct { GP2 map[string]*GenericTest2[B, BP, string] `msg:",allownil"` } +type GenericTest3List[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] []GenericTest3[A, B, AP, BP] +type GenericTest3Map[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] map[string]GenericTest3[A, B, AP, BP] + +type GenericTest3Array[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] [5]GenericTest3[A, B, AP, BP] + type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct { A A B B diff --git a/_generated/generics_test.go b/_generated/generics_test.go index a669859e..6f8396fd 100644 --- a/_generated/generics_test.go +++ b/_generated/generics_test.go @@ -63,3 +63,83 @@ func TestGenericsEncode(t *testing.T) { t.Errorf("\n got=%#v\nwant=%#v", got, x) } } + +// Test for generic alias types (slices, maps, arrays) - these verify the type parameter fix +func TestGenericTest3List(t *testing.T) { + // Test GenericTest3List + list := GenericTest3List[Fixed, Fixed, *Fixed, *Fixed]{ + {A: Fixed{A: 1.5}, B: Fixed{A: 2.5}}, + {A: Fixed{A: 3.5}, B: Fixed{A: 4.5}}, + } + + // Test marshaling + data, err := list.MarshalMsg(nil) + if err != nil { + t.Fatalf("GenericTest3List.MarshalMsg failed: %v", err) + } + + // Test unmarshaling + var list2 GenericTest3List[Fixed, Fixed, *Fixed, *Fixed] + _, err = list2.UnmarshalMsg(data) + if err != nil { + t.Fatalf("GenericTest3List.UnmarshalMsg failed: %v", err) + } + + // Verify round-trip + if len(list2) != 2 || list2[0].A.A != 1.5 || list2[0].B.A != 2.5 { + t.Errorf("GenericTest3List round-trip failed") + } +} + +func TestGenericTest3Map(t *testing.T) { + // Test GenericTest3Map + testMap := GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed]{ + "key1": {A: Fixed{A: 1.5}, B: Fixed{A: 2.5}}, + "key2": {A: Fixed{A: 3.5}, B: Fixed{A: 4.5}}, + } + + // Test marshaling + data, err := testMap.MarshalMsg(nil) + if err != nil { + t.Fatalf("GenericTest3Map.MarshalMsg failed: %v", err) + } + + // Test unmarshaling + var testMap2 GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed] + _, err = testMap2.UnmarshalMsg(data) + if err != nil { + t.Fatalf("GenericTest3Map.UnmarshalMsg failed: %v", err) + } + + // Verify round-trip + if len(testMap2) != 2 || testMap2["key1"].A.A != 1.5 || testMap2["key1"].B.A != 2.5 { + t.Errorf("GenericTest3Map round-trip failed") + } +} + +func TestGenericTest3Array(t *testing.T) { + // Test GenericTest3Array + testArray := GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed]{ + {A: Fixed{A: 1.5}, B: Fixed{A: 2.5}}, + {A: Fixed{A: 3.5}, B: Fixed{A: 4.5}}, + {A: Fixed{A: 5.5}, B: Fixed{A: 6.5}}, + } + + // Test marshaling + data, err := testArray.MarshalMsg(nil) + if err != nil { + t.Fatalf("GenericTest3Array.MarshalMsg failed: %v", err) + } + + // Test unmarshaling + var testArray2 GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed] + _, err = testArray2.UnmarshalMsg(data) + if err != nil { + t.Fatalf("GenericTest3Array.UnmarshalMsg failed: %v", err) + } + + // Verify round-trip + if testArray2[0].A.A != 1.5 || testArray2[0].B.A != 2.5 || testArray2[2].A.A != 5.5 { + t.Errorf("GenericTest3Array round-trip failed") + } +} diff --git a/gen/decode.go b/gen/decode.go index 07352c6e..5197b854 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -225,14 +225,15 @@ func (d *decodeGen) gBase(b *BaseElem) { if b.typeParams.isPtr { dst = "*" + dst } + if b.Convert { - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" { vname = fmt.Sprintf(remap, vname) } lowered := b.ToBase() + "(" + vname + ")" d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered) } else { - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" { vname = fmt.Sprintf(remap, vname) } d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) diff --git a/gen/elem.go b/gen/elem.go index bcf3c0e6..dd757273 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -160,14 +160,42 @@ type GenericTypeParams struct { isPtr bool } -func (c *common) SetVarname(s string) { c.vname = s } -func (c *common) Varname() string { return c.vname } +func (c *common) SetVarname(s string) { c.vname = s } +func (c *common) Varname() string { return c.vname } + +// typeNameWithParams returns the type name with generic parameters appended if they exist +// stripTypeParams removes type parameters from a type name for lookup purposes +// e.g. "MyType[T, U]" becomes "MyType", "*SomeType[A]" becomes "*SomeType" +func stripTypeParams(typeName string) string { + if idx := strings.Index(typeName, "["); idx != -1 { + return typeName[:idx] + } + return typeName +} + +func (c *common) typeNameWithParams(baseName string) string { + if c.typeParams.TypeParams != "" && !strings.Contains(baseName, "[") { + // Check if baseName is a single identifier without dots (likely a type parameter) + if !strings.Contains(baseName, ".") && len(baseName) <= 2 && len(baseName) > 0 { + // This looks like a simple type parameter, don't add type parameters + return baseName + } + return baseName + c.typeParams.TypeParams + } + return baseName +} + +// baseTypeName returns the type name without generic parameters (for use in method receivers) +func (c *common) baseTypeName() string { + return c.alias +} func (c *common) Alias(typ string) { c.alias = typ } func (c *common) hidden() {} func (c *common) AllowNil() bool { return false } func (c *common) SetIsAllowNil(bool) {} func (c *common) SetTypeParams(tp GenericTypeParams) { c.typeParams = tp } func (c *common) TypeParams() GenericTypeParams { return c.typeParams } +func (c *common) BaseTypeName() string { return c.baseTypeName() } func (c *common) AlwaysPtr(set *bool) bool { if c != nil && set != nil { c.ptrRcv = *set @@ -245,6 +273,9 @@ type Elem interface { // TypeParams returns the generic type parameters for this element TypeParams() GenericTypeParams + // BaseTypeName returns the type name without generic parameters + BaseTypeName() string + hidden() } @@ -283,10 +314,10 @@ ridx: func (a *Array) TypeName() string { if a.alias != "" { - return a.alias + return a.typeNameWithParams(a.alias) } a.Alias(fmt.Sprintf("[%s]%s", a.Size, a.Els.TypeName())) - return a.alias + return a.typeNameWithParams(a.alias) } func (a *Array) Copy() Elem { @@ -335,14 +366,14 @@ ridx: func (m *Map) TypeName() string { if m.alias != "" { - return m.alias + return m.typeNameWithParams(m.alias) } keyType := "string" if m.Key != nil { keyType = m.Key.TypeName() } m.Alias("map[" + keyType + "]" + m.Value.TypeName()) - return m.alias + return m.typeNameWithParams(m.alias) } func (m *Map) Copy() Elem { @@ -403,10 +434,10 @@ func (s *Slice) SetVarname(a string) { func (s *Slice) TypeName() string { if s.alias != "" { - return s.alias + return s.typeNameWithParams(s.alias) } s.Alias("[]" + s.Els.TypeName()) - return s.alias + return s.typeNameWithParams(s.alias) } func (s *Slice) Copy() Elem { @@ -480,10 +511,10 @@ func (s *Ptr) SetVarname(a string) { func (s *Ptr) TypeName() string { if s.alias != "" { - return s.alias + return s.typeNameWithParams(s.alias) } s.Alias("*" + s.Value.TypeName()) - return s.alias + return s.typeNameWithParams(s.alias) } func (s *Ptr) Copy() Elem { @@ -673,10 +704,10 @@ func (s *BaseElem) SetVarname(a string) { // type name for the base element. func (s *BaseElem) TypeName() string { if s.alias != "" { - return s.alias + return s.typeNameWithParams(s.alias) } s.common.Alias(s.BaseType()) - return s.alias + return s.typeNameWithParams(s.alias) } // ToBase, used if Convert==true, is used as tmp = {{ToBase}}({{Varname}}) @@ -715,7 +746,7 @@ func (s *BaseElem) BaseName() string { func (s *BaseElem) BaseType() string { switch s.Value { case IDENT: - return s.TypeName() + return s.alias // exceptions to the naming/capitalization // rule: diff --git a/gen/encode.go b/gen/encode.go index 2b8bd7e8..44f5ca57 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -344,7 +344,14 @@ func (e *encodeGen) gBase(b *BaseElem) { if b.typeParams.isPtr { dst = "*" + dst } - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + + // Strip type parameters from dst for lookup in ToPointerMap + lookupKey := stripTypeParams(dst) + if idx := strings.Index(dst, "["); idx != -1 { + lookupKey = dst[:idx] + } + + if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" { vname = fmt.Sprintf(remap, vname) } e.p.printf("\nerr = %s.EncodeMsg(en)", vname) diff --git a/gen/marshal.go b/gen/marshal.go index e68464ad..acd3c018 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -358,7 +358,7 @@ func (m *marshalGen) gBase(b *BaseElem) { if b.typeParams.isPtr { dst = "*" + dst } - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" { vname = fmt.Sprintf(remap, vname) } echeck = true diff --git a/gen/size.go b/gen/size.go index 53504d8d..3efab37d 100644 --- a/gen/size.go +++ b/gen/size.go @@ -247,7 +247,14 @@ func (s *sizeGen) gBase(b *BaseElem) { if b.typeParams.isPtr { dst = "*" + dst } - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + + // Strip type parameters from dst for lookup in ToPointerMap + lookupKey := stripTypeParams(dst) + if idx := strings.Index(dst, "["); idx != -1 { + lookupKey = dst[:idx] + } + + if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" { vname = fmt.Sprintf(remap, vname) } s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) diff --git a/gen/spec.go b/gen/spec.go index 4ff4fbee..0e02036d 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -268,7 +268,7 @@ func next(t traversal, e Elem) { // possibly-immutable method receiver func imutMethodReceiver(p Elem) string { - typeName := p.TypeName() + typeName := p.BaseTypeName() typeParams := p.TypeParams() switch e := p.(type) { @@ -300,7 +300,7 @@ func imutMethodReceiver(p Elem) string { // so that its method receiver // is of the write type. func methodReceiver(p Elem) string { - typeName := p.TypeName() + typeName := p.BaseTypeName() typeParams := p.TypeParams() switch p.(type) { diff --git a/gen/unmarshal.go b/gen/unmarshal.go index f1d6bfce..f972d359 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -247,10 +247,9 @@ func (u *unmarshalGen) gBase(b *BaseElem) { if b.typeParams.isPtr { dst = "*" + dst } - if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" { lowered = fmt.Sprintf(remap, lowered) } - u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) case Time: if u.ctx.asUTC {