diff --git a/cache.go b/cache.go index 7270a258..cc177624 100644 --- a/cache.go +++ b/cache.go @@ -97,6 +97,7 @@ type decodingStructType struct { fieldIndicesByIntKey map[int64]int // Only populated if toArray is false err error toArray bool + toIndefArray bool } func getDecodingStructType(t reflect.Type) (*decodingStructType, error) { @@ -111,9 +112,10 @@ func getDecodingStructType(t reflect.Type) (*decodingStructType, error) { flds, structOptions := getFields(t) toArray := hasToArrayOption(structOptions) + toIndefArray := hasToIndefArrayOption(structOptions) - if toArray { - return getDecodingStructToArrayType(t, flds) + if toArray || toIndefArray { + return getDecodingStructToArrayType(t, flds, toIndefArray) } fieldIndicesByName := make(map[string]int, len(flds)) @@ -163,7 +165,7 @@ func getDecodingStructType(t reflect.Type) (*decodingStructType, error) { return structType, nil } -func getDecodingStructToArrayType(t reflect.Type, flds fields) (*decodingStructType, error) { +func getDecodingStructToArrayType(t reflect.Type, flds fields, toIndefArray bool) (*decodingStructType, error) { decFlds := make(decodingFields, len(flds)) for i, f := range flds { // nameAsInt is set in getFields() except for fields with an unparsable tagged name. @@ -185,8 +187,9 @@ func getDecodingStructToArrayType(t reflect.Type, flds fields) (*decodingStructT } structType := &decodingStructType{ - fields: decFlds, - toArray: true, + fields: decFlds, + toArray: true, + toIndefArray: toIndefArray, } decodingStructTypeCache.Store(t, structType) return structType, nil @@ -199,6 +202,7 @@ type encodingStructType struct { omitEmptyFieldsIdx []int // Only populated if toArray is false err error toArray bool + toIndefArray bool } func (st *encodingStructType) getFields(em *encMode) encodingFields { @@ -258,6 +262,9 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) { flds, structOptions := getFields(t) + if hasToIndefArrayOption(structOptions) { + return getEncodingStructToIndefArrayType(t, flds) + } if hasToArrayOption(structOptions) { return getEncodingStructToArrayType(t, flds) } @@ -354,6 +361,27 @@ func getEncodingStructType(t reflect.Type) (*encodingStructType, error) { return structType, nil } +func getEncodingStructToIndefArrayType(t reflect.Type, flds fields) (*encodingStructType, error) { + encFlds := make(encodingFields, len(flds)) + for i, f := range flds { + encFlds[i] = &encodingField{field: *f} + encFlds[i].ef, encFlds[i].ief, encFlds[i].izf = getEncodeFunc(f.typ) + if encFlds[i].ef == nil { + structType := &encodingStructType{err: &UnsupportedTypeError{t}} + encodingStructTypeCache.Store(t, structType) + return nil, structType.err + } + } + + structType := &encodingStructType{ + fields: encFlds, + toArray: true, + toIndefArray: true, + } + encodingStructTypeCache.Store(t, structType) + return structType, nil +} + func getEncodingStructToArrayType(t reflect.Type, flds fields) (*encodingStructType, error) { encFlds := make(encodingFields, len(flds)) for i, f := range flds { @@ -398,3 +426,9 @@ func hasToArrayOption(tag string) bool { idx := strings.Index(tag, s) return idx >= 0 && (len(tag) == idx+len(s) || tag[idx+len(s)] == ',') } + +func hasToIndefArrayOption(tag string) bool { + s := ",toindefarray" + idx := strings.Index(tag, s) + return idx >= 0 && (len(tag) == idx+len(s) || tag[idx+len(s)] == ',') +} diff --git a/encode.go b/encode.go index 0620bc63..97e1aa16 100644 --- a/encode.go +++ b/encode.go @@ -1480,6 +1480,47 @@ func putKeyValues(x *[]keyValue) { keyValuePool.Put(x) } +func encodeStructToIndefArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) { + structType, err := getEncodingStructType(v.Type()) + if err != nil { + return err + } + + if b := em.encTagBytes(v.Type()); b != nil { + e.Write(b) + } + + flds := structType.fields + + // Write indefinite-length array header (0x9f) + e.WriteByte(cborArrayWithIndefiniteLengthHead) + for i := 0; i < len(flds); i++ { + f := flds[i] + + var fv reflect.Value + if len(f.idx) == 1 { + fv = v.Field(f.idx[0]) + } else { + // Get embedded field value. No error is expected. + fv, _ = getFieldValue(v, f.idx, func(reflect.Value) (reflect.Value, error) { + // Write CBOR nil for null pointer to embedded struct + e.Write(cborNil) + return reflect.Value{}, nil + }) + if !fv.IsValid() { + continue + } + } + + if err := f.ef(e, em, fv); err != nil { + return err + } + } + // Write break code (0xff) + e.WriteByte(cborBreakFlag) + return nil +} + func encodeStructToArray(e *bytes.Buffer, em *encMode, v reflect.Value) (err error) { structType, err := getEncodingStructType(v.Type()) if err != nil { @@ -2079,6 +2120,9 @@ func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc, izf if f, ok := t.FieldByName("_"); ok { tag := f.Tag.Get("cbor") if tag != "-" { + if hasToIndefArrayOption(tag) { + return encodeStructToIndefArray, isEmptyStruct, isZeroFieldStruct + } if hasToArrayOption(tag) { return encodeStructToArray, isEmptyStruct, isZeroFieldStruct } diff --git a/toindefarray_test.go b/toindefarray_test.go new file mode 100644 index 00000000..98e85aa9 --- /dev/null +++ b/toindefarray_test.go @@ -0,0 +1,91 @@ +package cbor + +import ( + "encoding/hex" + "testing" +) + +func TestEncodeStructToIndefArray(t *testing.T) { + type TestStruct struct { + _ struct{} `cbor:",toindefarray"` + Data []byte + Count int + } + + ts := TestStruct{ + Data: []byte{0x73, 0xf7, 0x2d, 0xbe}, + Count: 3, + } + + b, err := Marshal(ts) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + h := hex.EncodeToString(b) + t.Logf("Encoded: %s", h) + + // Check indefinite array header + if b[0] != 0x9f { + t.Errorf("Expected first byte 0x9f (indef array), got 0x%02x", b[0]) + } + + // Check break code at end + if b[len(b)-1] != 0xff { + t.Errorf("Expected last byte 0xff (break), got 0x%02x", b[len(b)-1]) + } + + // Decode back + var ts2 TestStruct + err = Unmarshal(b, &ts2) + if err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + + if hex.EncodeToString(ts2.Data) != hex.EncodeToString(ts.Data) { + t.Errorf("Data mismatch: got %x, want %x", ts2.Data, ts.Data) + } + if ts2.Count != ts.Count { + t.Errorf("Count mismatch: got %d, want %d", ts2.Count, ts.Count) + } +} + +func TestEncodeStructToIndefArrayNested(t *testing.T) { + type Inner struct { + _ struct{} `cbor:",toindefarray"` + X []byte + Y int + } + + type Outer struct { + _ struct{} `cbor:",toindefarray"` + Inner Inner + Z int + } + + o := Outer{ + Inner: Inner{X: []byte{0xab, 0xcd}, Y: 42}, + Z: 7, + } + + b, err := Marshal(o) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + t.Logf("Encoded: %s", hex.EncodeToString(b)) + + if b[0] != 0x9f { + t.Errorf("Expected outer 0x9f, got 0x%02x", b[0]) + } + if b[len(b)-1] != 0xff { + t.Errorf("Expected outer 0xff, got 0x%02x", b[len(b)-1]) + } + + var o2 Outer + if err := Unmarshal(b, &o2); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if o2.Z != 7 || o2.Inner.Y != 42 { + t.Errorf("Decoded mismatch: %+v", o2) + } +}