diff --git a/gen.go b/gen.go index f996a02..038c839 100644 --- a/gen.go +++ b/gen.go @@ -369,6 +369,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { } e := f.Type.Elem() + // Note: this re-slices the slice to deal with arrays. if e.Kind() == reflect.Uint8 { return doTemplate(w, f, ` if len({{ .Name }}) > cbg.ByteArrayMaxLen { @@ -377,7 +378,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { {{ MajorType "w" "cbg.MajByteString" (print "len(" .Name ")" ) }} - if _, err := w.Write({{ .Name }}); err != nil { + if _, err := w.Write({{ .Name }}[:]); err != nil { return err } `) @@ -858,8 +859,18 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { if maj != cbg.MajByteString { return fmt.Errorf("expected byte array") } - {{ .Name }} = make([]byte, extra) - if _, err := io.ReadFull(br, {{ .Name }}); err != nil { + {{if .IsArray}} + if extra != {{ .Len }} { + return fmt.Errorf("expected array to have {{ .Len }} elements") + } + + {{ .Name }} = {{ .TypeName }}{} + {{else}} + if extra > 0 { + {{ .Name }} = make({{ .TypeName }}, extra) + } + {{end}} + if _, err := io.ReadFull(br, {{ .Name }}[:]); err != nil { return err } `) diff --git a/testgen/main.go b/testgen/main.go index e2e1717..c2d7a2e 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -11,6 +11,7 @@ func main() { types.SimpleTypeOne{}, types.SimpleTypeTwo{}, types.DeferredContainer{}, + types.FixedArrays{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 334ab68..2f79dd4 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -135,7 +135,7 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { return err } - if _, err := w.Write(t.Binary); err != nil { + if _, err := w.Write(t.Binary[:]); err != nil { return err } @@ -207,8 +207,12 @@ func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { if maj != cbg.MajByteString { return fmt.Errorf("expected byte array") } - t.Binary = make([]byte, extra) - if _, err := io.ReadFull(br, t.Binary); err != nil { + + if extra > 0 { + t.Binary = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.Binary[:]); err != nil { return err } // t.Signed (int64) (int64) @@ -308,7 +312,7 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return err } - if _, err := w.Write(v); err != nil { + if _, err := w.Write(v[:]); err != nil { return err } } @@ -533,8 +537,12 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { if maj != cbg.MajByteString { return fmt.Errorf("expected byte array") } - t.Test[i] = make([]byte, extra) - if _, err := io.ReadFull(br, t.Test[i]); err != nil { + + if extra > 0 { + t.Test[i] = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.Test[i][:]); err != nil { return err } } @@ -766,3 +774,160 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { } return nil } + +var lengthBufFixedArrays = []byte{131} + +func (t *FixedArrays) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write(lengthBufFixedArrays); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.Bytes ([20]uint8) (array) + if len(t.Bytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.Bytes was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.Bytes))); err != nil { + return err + } + + if _, err := w.Write(t.Bytes[:]); err != nil { + return err + } + + // t.Uint8 ([20]uint8) (array) + if len(t.Uint8) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.Uint8 was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.Uint8))); err != nil { + return err + } + + if _, err := w.Write(t.Uint8[:]); err != nil { + return err + } + + // t.Uint64 ([20]uint64) (array) + if len(t.Uint64) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Uint64 was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajArray, uint64(len(t.Uint64))); err != nil { + return err + } + for _, v := range t.Uint64 { + if err := cbg.CborWriteHeader(w, cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } + return nil +} + +func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error { + *t = FixedArrays{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 3 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.Bytes ([20]uint8) (array) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.ByteArrayMaxLen { + return fmt.Errorf("t.Bytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra != 20 { + return fmt.Errorf("expected array to have 20 elements") + } + + t.Bytes = [20]uint8{} + + if _, err := io.ReadFull(br, t.Bytes[:]); err != nil { + return err + } + // t.Uint8 ([20]uint8) (array) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.ByteArrayMaxLen { + return fmt.Errorf("t.Uint8: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra != 20 { + return fmt.Errorf("expected array to have 20 elements") + } + + t.Uint8 = [20]uint8{} + + if _, err := io.ReadFull(br, t.Uint8[:]); err != nil { + return err + } + // t.Uint64 ([20]uint64) (array) + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.Uint64: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra != 20 { + return fmt.Errorf("expected array to have 20 elements") + } + + t.Uint64 = [20]uint64{} + + for i := 0; i < int(extra); i++ { + + maj, val, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return xerrors.Errorf("failed to read uint64 for t.Uint64 slice: %w", err) + } + + if maj != cbg.MajUnsignedInt { + return xerrors.Errorf("value read for array t.Uint64 was not a uint, instead got %d", maj) + } + + t.Uint64[i] = uint64(val) + } + + return nil +} diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index f618e76..9887549 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -108,7 +108,7 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return err } - if _, err := w.Write(v); err != nil { + if _, err := w.Write(v[:]); err != nil { return err } } @@ -331,8 +331,12 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { if maj != cbg.MajByteString { return fmt.Errorf("expected byte array") } - t.Test[i] = make([]byte, extra) - if _, err := io.ReadFull(br, t.Test[i]); err != nil { + + if extra > 0 { + t.Test[i] = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.Test[i][:]); err != nil { return err } } diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index e127564..b506ff5 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -105,3 +105,9 @@ func TestNilValueDeferredUnmarshaling(t *testing.T) { t.Fatal("shouldnt be nil!") } } + +func TestFixedArrays(t *testing.T) { + zero := &FixedArrays{} + recepticle := &FixedArrays{} + testValueRoundtrip(t, zero, recepticle) +} diff --git a/testing/types.go b/testing/types.go index 4074d8c..fdc87c9 100644 --- a/testing/types.go +++ b/testing/types.go @@ -46,3 +46,9 @@ type DeferredContainer struct { Deferred *cbg.Deferred Value uint64 } + +type FixedArrays struct { + Bytes [20]byte + Uint8 [20]uint8 + Uint64 [20]uint64 +}