Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type decodingStructType struct {
fieldIndicesByIntKey map[int64]int // Only populated if toArray is false
err error
toArray bool
toIndefArray bool
}

func getDecodingStructType(t reflect.Type) (*decodingStructType, error) {
Expand All @@ -111,9 +112,10 @@ func getDecodingStructType(t reflect.Type) (*decodingStructType, error) {
flds, structOptions := getFields(t)

toArray := hasToArrayOption(structOptions)
toIndefArray := hasToIndefArrayOption(structOptions)

if toArray {
return getDecodingStructToArrayType(t, flds)
if toArray || toIndefArray {
return getDecodingStructToArrayType(t, flds, toIndefArray)
}

fieldIndicesByName := make(map[string]int, len(flds))
Expand Down Expand Up @@ -163,7 +165,7 @@ func getDecodingStructType(t reflect.Type) (*decodingStructType, error) {
return structType, nil
}

func getDecodingStructToArrayType(t reflect.Type, flds fields) (*decodingStructType, error) {
func getDecodingStructToArrayType(t reflect.Type, flds fields, toIndefArray bool) (*decodingStructType, error) {
decFlds := make(decodingFields, len(flds))
for i, f := range flds {
// nameAsInt is set in getFields() except for fields with an unparsable tagged name.
Expand All @@ -185,8 +187,9 @@ func getDecodingStructToArrayType(t reflect.Type, flds fields) (*decodingStructT
}

structType := &decodingStructType{
fields: decFlds,
toArray: true,
fields: decFlds,
toArray: true,
toIndefArray: toIndefArray,
}
decodingStructTypeCache.Store(t, structType)
return structType, nil
Expand All @@ -199,6 +202,7 @@ type encodingStructType struct {
omitEmptyFieldsIdx []int // Only populated if toArray is false
err error
toArray bool
toIndefArray bool
}

func (st *encodingStructType) getFields(em *encMode) encodingFields {
Expand Down Expand Up @@ -258,6 +262,9 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {

flds, structOptions := getFields(t)

if hasToIndefArrayOption(structOptions) {
return getEncodingStructToIndefArrayType(t, flds)
}
if hasToArrayOption(structOptions) {
return getEncodingStructToArrayType(t, flds)
}
Expand Down Expand Up @@ -354,6 +361,27 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) {
return structType, nil
}

func getEncodingStructToIndefArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
encFlds := make(encodingFields, len(flds))
for i, f := range flds {
encFlds[i] = &encodingField{field: *f}
encFlds[i].ef, encFlds[i].ief, encFlds[i].izf = getEncodeFunc(f.typ)
if encFlds[i].ef == nil {
structType := &encodingStructType{err: &UnsupportedTypeError{t}}
encodingStructTypeCache.Store(t, structType)
return nil, structType.err
}
}

structType := &encodingStructType{
fields: encFlds,
toArray: true,
toIndefArray: true,
}
encodingStructTypeCache.Store(t, structType)
return structType, nil
}

func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) {
encFlds := make(encodingFields, len(flds))
for i, f := range flds {
Expand Down Expand Up @@ -398,3 +426,9 @@ func hasToArrayOption(tag string) bool {
idx := strings.Index(tag, s)
return idx >= 0 && (len(tag) == idx+len(s) || tag[idx+len(s)] == ',')
}

func hasToIndefArrayOption(tag string) bool {
s := ",toindefarray"
idx := strings.Index(tag, s)
return idx >= 0 && (len(tag) == idx+len(s) || tag[idx+len(s)] == ',')
}
44 changes: 44 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,47 @@ func putKeyValues(x *[]keyValue) {
keyValuePool.Put(x)
}

func encodeStructToIndefArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
return err
}

if b := em.encTagBytes(v.Type()); b != nil {
e.Write(b)
}

flds := structType.fields

// Write indefinite-length array header (0x9f)
e.WriteByte(cborArrayWithIndefiniteLengthHead)
for i := 0; i < len(flds); i++ {
f := flds[i]

var fv reflect.Value
if len(f.idx) == 1 {
fv = v.Field(f.idx[0])
} else {
// Get embedded field value. No error is expected.
fv, _ = getFieldValue(v, f.idx, func(reflect.Value) (reflect.Value, error) {
// Write CBOR nil for null pointer to embedded struct
e.Write(cborNil)
return reflect.Value{}, nil
})
if !fv.IsValid() {
continue
}
}

if err := f.ef(e, em, fv); err != nil {
return err
}
}
// Write break code (0xff)
e.WriteByte(cborBreakFlag)
return nil
}

func encodeStructToArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) {
structType, err := getEncodingStructType(v.Type())
if err != nil {
Expand Down Expand Up @@ -2079,6 +2120,9 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc, izf
if f, ok := t.FieldByName("_"); ok {
tag := f.Tag.Get("cbor")
if tag != "-" {
if hasToIndefArrayOption(tag) {
return encodeStructToIndefArray, isEmptyStruct, isZeroFieldStruct
}
if hasToArrayOption(tag) {
return encodeStructToArray, isEmptyStruct, isZeroFieldStruct
}
Expand Down
91 changes: 91 additions & 0 deletions toindefarray_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package cbor

import (
"encoding/hex"
"testing"
)

func TestEncodeStructToIndefArray(t *testing.T) {
type TestStruct struct {
_ struct{} `cbor:",toindefarray"`
Data []byte
Count int
}

ts := TestStruct{
Data: []byte{0x73, 0xf7, 0x2d, 0xbe},
Count: 3,
}

b, err := Marshal(ts)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}

h := hex.EncodeToString(b)
t.Logf("Encoded: %s", h)

// Check indefinite array header
if b[0] != 0x9f {
t.Errorf("Expected first byte 0x9f (indef array), got 0x%02x", b[0])
}

// Check break code at end
if b[len(b)-1] != 0xff {
t.Errorf("Expected last byte 0xff (break), got 0x%02x", b[len(b)-1])
}

// Decode back
var ts2 TestStruct
err = Unmarshal(b, &ts2)
if err != nil {
t.Fatalf("Unmarshal error: %v", err)
}

if hex.EncodeToString(ts2.Data) != hex.EncodeToString(ts.Data) {
t.Errorf("Data mismatch: got %x, want %x", ts2.Data, ts.Data)
}
if ts2.Count != ts.Count {
t.Errorf("Count mismatch: got %d, want %d", ts2.Count, ts.Count)
}
}

func TestEncodeStructToIndefArrayNested(t *testing.T) {
type Inner struct {
_ struct{} `cbor:",toindefarray"`
X []byte
Y int
}

type Outer struct {
_ struct{} `cbor:",toindefarray"`
Inner Inner
Z int
}

o := Outer{
Inner: Inner{X: []byte{0xab, 0xcd}, Y: 42},
Z: 7,
}

b, err := Marshal(o)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
t.Logf("Encoded: %s", hex.EncodeToString(b))

if b[0] != 0x9f {
t.Errorf("Expected outer 0x9f, got 0x%02x", b[0])
}
if b[len(b)-1] != 0xff {
t.Errorf("Expected outer 0xff, got 0x%02x", b[len(b)-1])
}

var o2 Outer
if err := Unmarshal(b, &o2); err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if o2.Z != 7 || o2.Inner.Y != 42 {
t.Errorf("Decoded mismatch: %+v", o2)
}
}
Loading