Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions _generated/generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ type GenericTestTwo[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] struct {
GP2 map[string]*GenericTest2[B, BP, string] `msg:",allownil"`
}

type GenericTest3List[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] []GenericTest3[A, B, AP, BP]
type GenericTest3Map[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] map[string]GenericTest3[A, B, AP, BP]

type GenericTest3Array[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] [5]GenericTest3[A, B, AP, BP]

type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct {
A A
B B
Expand Down
80 changes: 80 additions & 0 deletions _generated/generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,83 @@ func TestGenericsEncode(t *testing.T) {
t.Errorf("\n got=%#v\nwant=%#v", got, x)
}
}

// Test for generic alias types (slices, maps, arrays) - these verify the type parameter fix
func TestGenericTest3List(t *testing.T) {
// Test GenericTest3List
list := GenericTest3List[Fixed, Fixed, *Fixed, *Fixed]{
{A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
{A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
}

// Test marshaling
data, err := list.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3List.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var list2 GenericTest3List[Fixed, Fixed, *Fixed, *Fixed]
_, err = list2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3List.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if len(list2) != 2 || list2[0].A.A != 1.5 || list2[0].B.A != 2.5 {
t.Errorf("GenericTest3List round-trip failed")
}
}

func TestGenericTest3Map(t *testing.T) {
// Test GenericTest3Map
testMap := GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed]{
"key1": {A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
"key2": {A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
}

// Test marshaling
data, err := testMap.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3Map.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var testMap2 GenericTest3Map[Fixed, Fixed, *Fixed, *Fixed]
_, err = testMap2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3Map.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if len(testMap2) != 2 || testMap2["key1"].A.A != 1.5 || testMap2["key1"].B.A != 2.5 {
t.Errorf("GenericTest3Map round-trip failed")
}
}

func TestGenericTest3Array(t *testing.T) {
// Test GenericTest3Array
testArray := GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed]{
{A: Fixed{A: 1.5}, B: Fixed{A: 2.5}},
{A: Fixed{A: 3.5}, B: Fixed{A: 4.5}},
{A: Fixed{A: 5.5}, B: Fixed{A: 6.5}},
}

// Test marshaling
data, err := testArray.MarshalMsg(nil)
if err != nil {
t.Fatalf("GenericTest3Array.MarshalMsg failed: %v", err)
}

// Test unmarshaling
var testArray2 GenericTest3Array[Fixed, Fixed, *Fixed, *Fixed]
_, err = testArray2.UnmarshalMsg(data)
if err != nil {
t.Fatalf("GenericTest3Array.UnmarshalMsg failed: %v", err)
}

// Verify round-trip
if testArray2[0].A.A != 1.5 || testArray2[0].B.A != 2.5 || testArray2[2].A.A != 5.5 {
t.Errorf("GenericTest3Array round-trip failed")
}
}
5 changes: 3 additions & 2 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,15 @@ func (d *decodeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}

if b.Convert {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
lowered := b.ToBase() + "(" + vname + ")"
d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered)
} else {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
d.p.printf("\nerr = %s.DecodeMsg(dc)", vname)
Expand Down
57 changes: 44 additions & 13 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,42 @@ type GenericTypeParams struct {
isPtr bool
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }

// typeNameWithParams returns the type name with generic parameters appended if they exist
// stripTypeParams removes type parameters from a type name for lookup purposes
// e.g. "MyType[T, U]" becomes "MyType", "*SomeType[A]" becomes "*SomeType"
func stripTypeParams(typeName string) string {
if idx := strings.Index(typeName, "["); idx != -1 {
return typeName[:idx]
}
return typeName
}

func (c *common) typeNameWithParams(baseName string) string {
if c.typeParams.TypeParams != "" && !strings.Contains(baseName, "[") {
// Check if baseName is a single identifier without dots (likely a type parameter)
if !strings.Contains(baseName, ".") && len(baseName) <= 2 && len(baseName) > 0 {
// This looks like a simple type parameter, don't add type parameters
return baseName
}
return baseName + c.typeParams.TypeParams
}
return baseName
}

// baseTypeName returns the type name without generic parameters (for use in method receivers)
func (c *common) baseTypeName() string {
return c.alias
}
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) SetIsAllowNil(bool) {}
func (c *common) SetTypeParams(tp GenericTypeParams) { c.typeParams = tp }
func (c *common) TypeParams() GenericTypeParams { return c.typeParams }
func (c *common) BaseTypeName() string { return c.baseTypeName() }
func (c *common) AlwaysPtr(set *bool) bool {
if c != nil && set != nil {
c.ptrRcv = *set
Expand Down Expand Up @@ -245,6 +273,9 @@ type Elem interface {
// TypeParams returns the generic type parameters for this element
TypeParams() GenericTypeParams

// BaseTypeName returns the type name without generic parameters
BaseTypeName() string

hidden()
}

Expand Down Expand Up @@ -283,10 +314,10 @@ ridx:

func (a *Array) TypeName() string {
if a.alias != "" {
return a.alias
return a.typeNameWithParams(a.alias)
}
a.Alias(fmt.Sprintf("[%s]%s", a.Size, a.Els.TypeName()))
return a.alias
return a.typeNameWithParams(a.alias)
}

func (a *Array) Copy() Elem {
Expand Down Expand Up @@ -335,14 +366,14 @@ ridx:

func (m *Map) TypeName() string {
if m.alias != "" {
return m.alias
return m.typeNameWithParams(m.alias)
}
keyType := "string"
if m.Key != nil {
keyType = m.Key.TypeName()
}
m.Alias("map[" + keyType + "]" + m.Value.TypeName())
return m.alias
return m.typeNameWithParams(m.alias)
}

func (m *Map) Copy() Elem {
Expand Down Expand Up @@ -403,10 +434,10 @@ func (s *Slice) SetVarname(a string) {

func (s *Slice) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.Alias("[]" + s.Els.TypeName())
return s.alias
return s.typeNameWithParams(s.alias)
}

func (s *Slice) Copy() Elem {
Expand Down Expand Up @@ -480,10 +511,10 @@ func (s *Ptr) SetVarname(a string) {

func (s *Ptr) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.Alias("*" + s.Value.TypeName())
return s.alias
return s.typeNameWithParams(s.alias)
}

func (s *Ptr) Copy() Elem {
Expand Down Expand Up @@ -673,10 +704,10 @@ func (s *BaseElem) SetVarname(a string) {
// type name for the base element.
func (s *BaseElem) TypeName() string {
if s.alias != "" {
return s.alias
return s.typeNameWithParams(s.alias)
}
s.common.Alias(s.BaseType())
return s.alias
return s.typeNameWithParams(s.alias)
}

// ToBase, used if Convert==true, is used as tmp = {{ToBase}}({{Varname}})
Expand Down Expand Up @@ -715,7 +746,7 @@ func (s *BaseElem) BaseName() string {
func (s *BaseElem) BaseType() string {
switch s.Value {
case IDENT:
return s.TypeName()
return s.alias

// exceptions to the naming/capitalization
// rule:
Expand Down
9 changes: 8 additions & 1 deletion gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,14 @@ func (e *encodeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {

// Strip type parameters from dst for lookup in ToPointerMap
lookupKey := stripTypeParams(dst)
if idx := strings.Index(dst, "["); idx != -1 {
lookupKey = dst[:idx]
}

if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
e.p.printf("\nerr = %s.EncodeMsg(en)", vname)
Expand Down
2 changes: 1 addition & 1 deletion gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (m *marshalGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
echeck = true
Expand Down
9 changes: 8 additions & 1 deletion gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,14 @@ func (s *sizeGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {

// Strip type parameters from dst for lookup in ToPointerMap
lookupKey := stripTypeParams(dst)
if idx := strings.Index(dst, "["); idx != -1 {
lookupKey = dst[:idx]
}

if remap := b.typeParams.ToPointerMap[lookupKey]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
s.addConstant(basesizeExpr(b.Value, vname, b.BaseName()))
Expand Down
4 changes: 2 additions & 2 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func next(t traversal, e Elem) {

// possibly-immutable method receiver
func imutMethodReceiver(p Elem) string {
typeName := p.TypeName()
typeName := p.BaseTypeName()
typeParams := p.TypeParams()

switch e := p.(type) {
Expand Down Expand Up @@ -300,7 +300,7 @@ func imutMethodReceiver(p Elem) string {
// so that its method receiver
// is of the write type.
func methodReceiver(p Elem) string {
typeName := p.TypeName()
typeName := p.BaseTypeName()
typeParams := p.TypeParams()

switch p.(type) {
Expand Down
3 changes: 1 addition & 2 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,9 @@ func (u *unmarshalGen) gBase(b *BaseElem) {
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
if remap := b.typeParams.ToPointerMap[stripTypeParams(dst)]; remap != "" {
lowered = fmt.Sprintf(remap, lowered)
}

u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered)
case Time:
if u.ctx.asUTC {
Expand Down
Loading