From 489dea9c909e8e45ed4e3840bf67096e5fd18d9f Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Aug 2025 12:10:47 +0200 Subject: [PATCH 1/7] Add array/map iterators Allows to receive iterators for arrays or maps, so inputs can be directly iterated over. --- msgp/iter.go | 876 ++++++++++++++ msgp/iter_test.go | 2923 +++++++++++++++++++++++++++++++++++++++++++++ msgp/read.go | 2 +- 3 files changed, 3800 insertions(+), 1 deletion(-) create mode 100644 msgp/iter.go create mode 100644 msgp/iter_test.go diff --git a/msgp/iter.go b/msgp/iter.go new file mode 100644 index 00000000..ff5f3884 --- /dev/null +++ b/msgp/iter.go @@ -0,0 +1,876 @@ +package msgp + +import ( + "fmt" + "iter" + "time" +) + +// ArrayExtraTypes is a type constraint that includes all types that can be +// decoded from an array. +// Even though 'any' type can be used, its pointer must implement the +// Decodable when using Reader or Unmarhaler interface when reading from bytes. +type ArrayExtraTypes interface { + bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any +} + +// ReadArray returns an iterator that can be used to iterate over the elements +// of an array in the MessagePack data while being read by the provided Reader. +// The type parameter V specifies the type of the elements in the array. +// The type parameter V must be ArrayExtraTypes or a type whose +// pointer implements the Decodable interface. +// Use ReadNumberArray for numbers. +// The returned iterator implements the iter.Seq[V] interface, +// allowing for sequential access to the array elements. +func ReadArray[V ArrayExtraTypes](m *Reader) iter.Seq2[V, error] { + return func(yield func(V, error) bool) { + // Assuming Reader has a method to read array length + var x V + length, err := m.ReadArrayHeader() + if err != nil { + yield(x, err) + return + } + switch any(x).(type) { + case string: + for range length { + v, err := m.ReadString() + if !yield(any(v).(V), err) { + return + } + } + case []byte: + for range length { + v, err := m.ReadBytes(nil) + if !yield(any(v).(V), err) { + return + } + } + case bool: + for range length { + v, err := m.ReadBool() + if !yield(any(v).(V), err) { + return + } + } + case time.Time: + for range length { + v, err := m.ReadTime() + if !yield(any(v).(V), err) { + return + } + } + case time.Duration: + for range length { + v, err := m.ReadDuration() + if !yield(any(v).(V), err) { + return + } + } + case complex64: + for range length { + v, err := m.ReadComplex64() + if !yield(any(v).(V), err) { + return + } + } + case complex128: + for range length { + v, err := m.ReadComplex128() + if !yield(any(v).(V), err) { + } + } + default: + for range length { + var v V + ptr := &v + if dc, ok := any(ptr).(Decodable); ok { + err = dc.DecodeMsg(m) + if !yield(v, err) { + return + } + } else { + err = fmt.Errorf("cannot decode into type %T", ptr) + return + } + } + } + } +} + +// MapKeyTypes are possible key types. Usually this is a string. +// Even though 'any' type can be used, its pointer must implement the Decodable interface. +type MapKeyTypes interface { + NumberTypes | bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any +} + +// MapValueTypes are possible value types. +// Even though 'any' type can be used, its pointer must implement the Decodable interface. +type MapValueTypes interface { + NumberTypes | bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any +} + +// ReadMap returns an iterator that can be used to iterate over the elements +// of a map in the MessagePack data while being read by the provided Reader. +// The type parameters K and V specify the types of the keys and values in the map. +// The returned iterator implements the iter.Seq2[K, V] interface, +// allowing for sequential access to the map elements. +// The returned function can be used to read any error that occurred during iteration when iteration is done. +func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() error) { + var err error + return func(yield func(K, V) bool) { + // Assuming Reader has a method to read array length + var length uint32 + length, err = m.ReadArrayHeader() + if err != nil { + return + } + for range length { + var key K + switch v := any(key).(type) { + case string: + if v, err = m.ReadString(); err != nil { + return + } + key = (any)(v).(K) + case []byte: + if v, err = m.ReadBytes(nil); err != nil { + return + } + key = (any)(v).(K) + case bool: + if v, err = m.ReadBool(); err != nil { + return + } + key = (any)(v).(K) + case time.Time: + if v, err = m.ReadTime(); err != nil { + return + } + key = (any)(v).(K) + case time.Duration: + if v, err = m.ReadDuration(); err != nil { + return + } + key = (any)(v).(K) + case complex64: + if v, err = m.ReadComplex64(); err != nil { + return + } + key = (any)(v).(K) + case complex128: + if v, err = m.ReadComplex128(); err != nil { + return + } + key = (any)(v).(K) + case uint8: + if v, err = m.ReadUint8(); err != nil { + return + } + key = (any)(v).(K) + case uint16: + if v, err = m.ReadUint16(); err != nil { + return + } + key = (any)(v).(K) + case uint32: + if v, err = m.ReadUint32(); err != nil { + return + } + key = (any)(v).(K) + case uint64: + if v, err = m.ReadUint64(); err != nil { + return + } + key = (any)(v).(K) + case uint: + if v, err = m.ReadUint(); err != nil { + return + } + key = (any)(v).(K) + case int8: + if v, err = m.ReadInt8(); err != nil { + return + } + key = (any)(v).(K) + case int16: + if v, err = m.ReadInt16(); err != nil { + return + } + key = (any)(v).(K) + case int32: + if v, err = m.ReadInt32(); err != nil { + return + } + key = (any)(v).(K) + case int64: + if v, err = m.ReadInt64(); err != nil { + return + } + key = (any)(v).(K) + case int: + if v, err = m.ReadInt(); err != nil { + return + } + key = (any)(v).(K) + case float32: + if v, err = m.ReadFloat32(); err != nil { + return + } + key = (any)(v).(K) + case float64: + if v, err = m.ReadFloat64(); err != nil { + return + } + key = (any)(v).(K) + default: + ptr := &key + if dc, ok := any(ptr).(Decodable); ok { + if err = dc.DecodeMsg(m); err != nil { + return + } + } else { + err = fmt.Errorf("cannot decode key into type %T", ptr) + return + } + } + + var val V + switch v := any(key).(type) { + case string: + if v, err = m.ReadString(); err != nil { + return + } + val = (any)(v).(V) + case []byte: + if v, err = m.ReadBytes(nil); err != nil { + return + } + val = (any)(v).(V) + case bool: + if v, err = m.ReadBool(); err != nil { + return + } + val = (any)(v).(V) + case time.Time: + if v, err = m.ReadTime(); err != nil { + return + } + val = (any)(v).(V) + case time.Duration: + if v, err = m.ReadDuration(); err != nil { + return + } + val = (any)(v).(V) + case complex64: + if v, err = m.ReadComplex64(); err != nil { + return + } + val = (any)(v).(V) + case complex128: + if v, err = m.ReadComplex128(); err != nil { + return + } + val = (any)(v).(V) + case int8: + if v, err = m.ReadInt8(); err != nil { + return + } + val = (any)(v).(V) + case int16: + if v, err = m.ReadInt16(); err != nil { + return + } + val = (any)(v).(V) + case int32: + if v, err = m.ReadInt32(); err != nil { + return + } + val = (any)(v).(V) + case int64: + if v, err = m.ReadInt64(); err != nil { + return + } + val = (any)(v).(V) + case int: + if v, err = m.ReadInt(); err != nil { + return + } + val = (any)(v).(V) + case float32: + if v, err = m.ReadFloat32(); err != nil { + return + } + val = (any)(v).(V) + case float64: + if v, err = m.ReadFloat64(); err != nil { + return + } + val = (any)(v).(V) + case uint8: + if v, err = m.ReadUint8(); err != nil { + return + } + val = (any)(v).(V) + case uint16: + if v, err = m.ReadUint16(); err != nil { + return + } + val = (any)(v).(V) + case uint32: + if v, err = m.ReadUint32(); err != nil { + return + } + val = (any)(v).(V) + case uint64: + if v, err = m.ReadUint64(); err != nil { + return + } + val = (any)(v).(V) + case uint: + if v, err = m.ReadUint(); err != nil { + return + } + val = (any)(v).(V) + + default: + ptr := &val + if dc, ok := any(ptr).(Decodable); ok { + if err = dc.DecodeMsg(m); err != nil { + return + } + } else { + err = fmt.Errorf("cannot decode value into type %T", ptr) + return + } + } + if !yield(key, val) { + return + } + } + }, func() error { return err } +} + +// NumberTypes is a type constraint that includes all numeric types. +type NumberTypes interface { + uint | uint8 | uint16 | uint32 | uint64 | int | int8 | int16 | int32 | int64 | float32 | float64 +} + +// ReadNumberArray returns an iterator that can be used to iterate over the elements +// of an array in the MessagePack data while being read by the provided Reader. +// The type parameter V specifies the type of the elements in the array. +// The returned iterator implements the iter.Seq[V] interface, +// allowing for sequential access to the array elements. +func ReadNumberArray[V NumberTypes](m *Reader) iter.Seq2[V, error] { + return func(yield func(V, error) bool) { + // Assuming Reader has a method to read array length + var x V + length, err := m.ReadArrayHeader() + if err != nil { + yield(x, err) + return + } + + switch any(x).(type) { + case uint8: + for range length { + v, err := m.ReadUint8() + if !yield(V(v), err) { + return + } + } + case uint16: + for range length { + v, err := m.ReadUint16() + if !yield(V(v), err) { + return + } + } + case uint32: + for range length { + v, err := m.ReadUint32() + if !yield(V(v), err) { + return + } + } + case uint64: + for range length { + v, err := m.ReadUint64() + if !yield(V(v), err) { + return + } + } + case uint: + for range length { + v, err := m.ReadUint() + if !yield(V(v), err) { + return + } + } + case int8: + for range length { + v, err := m.ReadInt8() + if !yield(V(v), err) { + return + } + } + case int16: + for range length { + v, err := m.ReadInt16() + if !yield(V(v), err) { + return + } + } + case int32: + for range length { + v, err := m.ReadInt32() + if !yield(V(v), err) { + return + } + } + case int64: + for range length { + v, err := m.ReadInt64() + if !yield(V(v), err) { + return + } + } + case int: + for range length { + v, err := m.ReadInt() + if !yield(V(v), err) { + return + } + } + case float32: + for range length { + v, err := m.ReadFloat32() + if !yield(V(v), err) { + return + } + } + case float64: + for range length { + v, err := m.ReadFloat64() + if !yield(V(v), err) { + return + } + } + default: + panic("unreachable") + } + } +} + +// ReadNumberArrayBytes returns an iterator that can be used to iterate over the elements +// of an array in the MessagePack data. +// The type parameter V specifies the type of the elements in the array. +// The returned iterator implements the iter.Seq[V] interface, +// allowing for sequential access to the array elements. +// After the iterator is exhausted, the remaining bytes in the buffer +// and any error can be read by calling the returned function. +func ReadNumberArrayBytes[V NumberTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { + sz, b, err := ReadArrayHeaderBytes(b) + if err != nil { + return nil, func() ([]byte, error) { return b, err } + } + + var readValue func() (V, error) + var v V + switch any(v).(type) { + case uint8: + readValue = func() (V, error) { + var val uint8 + val, b, err = ReadUint8Bytes(b) + return V(val), err + } + case uint16: + readValue = func() (V, error) { + var val uint16 + val, b, err = ReadUint16Bytes(b) + return V(val), err + } + case uint32: + readValue = func() (V, error) { + var val uint32 + val, b, err = ReadUint32Bytes(b) + return V(val), err + } + case uint64: + readValue = func() (V, error) { + var val uint64 + val, b, err = ReadUint64Bytes(b) + return V(val), err + } + case uint: + readValue = func() (V, error) { + var val uint + val, b, err = ReadUintBytes(b) + return V(val), err + } + case int8: + readValue = func() (V, error) { + var val int8 + val, b, err = ReadInt8Bytes(b) + return V(val), err + } + case int16: + readValue = func() (V, error) { + var val int16 + val, b, err = ReadInt16Bytes(b) + return V(val), err + } + case int32: + readValue = func() (V, error) { + var val int32 + val, b, err = ReadInt32Bytes(b) + return V(val), err + } + case int64: + readValue = func() (V, error) { + var val int64 + val, b, err = ReadInt64Bytes(b) + return V(val), err + } + case int: + readValue = func() (V, error) { + var val int + val, b, err = ReadIntBytes(b) + return V(val), err + } + case float32: + readValue = func() (V, error) { + var val float32 + val, b, err = ReadFloat32Bytes(b) + return V(val), err + } + case float64: + readValue = func() (V, error) { + var val float64 + val, b, err = ReadFloat64Bytes(b) + return V(val), err + } + default: + panic("unreachable") + } + return func(yield func(V) bool) { + for sz > 0 { + v, err = readValue() + if err != nil || !yield(v) { + return + } + sz-- + } + }, func() ([]byte, error) { return b, err } +} + +// ReadArrayBytes returns an iterator that can be used to iterate over the elements +// of an array in the MessagePack data while being read by the provided Reader. +// The type parameter V specifies the type of the elements in the array. +// The type parameter V must be bool, string, []byte or a type whose +// pointer implements the Unmarshaler interface. +// Use ReadNumberArrayBytes for numbers. +// The returned iterator implements the iter.Seq[V] interface, +// allowing for sequential access to the array elements. +// Byte slices will reference the same underlying data. +// After the iterator is exhausted, the remaining bytes in the buffer +// and any error can be read by calling the returned function. +func ReadArrayBytes[V ArrayExtraTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { + sz, b, err := ReadArrayHeaderBytes(b) + if err != nil { + return nil, func() ([]byte, error) { return b, err } + } + return func(yield func(V) bool) { + var x V + switch any(x).(type) { + case string: + for range sz { + var v string + v, b, err = ReadStringBytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case []byte: + for range sz { + var v []byte + v, b, err = ReadBytesZC(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case bool: + for range sz { + var v bool + v, b, err = ReadBoolBytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case time.Time: + for range sz { + var v time.Time + v, b, err = ReadTimeBytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case time.Duration: + for range sz { + var v time.Duration + v, b, err = ReadDurationBytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case complex64: + for range sz { + var v complex64 + v, b, err = ReadComplex64Bytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + case complex128: + for range sz { + var v complex128 + v, b, err = ReadComplex128Bytes(b) + if err != nil || !yield(any(v).(V)) { + return + } + } + default: + for range sz { + var v V + ptr := &v + if um, ok := any(ptr).(Unmarshaler); ok { + b, err = um.UnmarshalMsg(b) + if err != nil || !yield(v) { + return + } + } else { + err = fmt.Errorf("cannot unmarshal into type %T", ptr) + return + } + } + } + }, func() (remain []byte, err error) { + return b, err + } +} + +// ReadMapBytes returns an iterator over key-value pairs of a map encoded in MessagePack bytes. +// The type parameters K and V specify the types of the keys and values in the map. +// The iterator yields pairs in wire order. After iteration completes (or stops early), +// call the returned tail function to get any remaining bytes and the first error encountered (if any). +// K and V must be one of the supported built-in types, or pointers to types implementing Unmarshaler. +// Byte slices will reference the same underlying data. +func ReadMapBytes[K MapKeyTypes, V MapValueTypes](b []byte) (iter.Seq2[K, V], func() (remain []byte, err error)) { + sz, b, err := ReadMapHeaderBytes(b) + if err != nil { + return nil, func() ([]byte, error) { return b, err } + } + + readKey := func() (K, error) { + var key K + switch any(key).(type) { + case string: + var v string + v, b, err = ReadStringBytes(b) + key = any(v).(K) + case []byte: + var v []byte + v, b, err = ReadBytesZC(b) + key = any(v).(K) + case bool: + var v bool + v, b, err = ReadBoolBytes(b) + key = any(v).(K) + case time.Time: + var v time.Time + v, b, err = ReadTimeBytes(b) + key = any(v).(K) + case time.Duration: + var v time.Duration + v, b, err = ReadDurationBytes(b) + key = any(v).(K) + case complex64: + var v complex64 + v, b, err = ReadComplex64Bytes(b) + key = any(v).(K) + case complex128: + var v complex128 + v, b, err = ReadComplex128Bytes(b) + key = any(v).(K) + case uint8: + var v uint8 + v, b, err = ReadUint8Bytes(b) + key = any(v).(K) + case uint16: + var v uint16 + v, b, err = ReadUint16Bytes(b) + key = any(v).(K) + case uint32: + var v uint32 + v, b, err = ReadUint32Bytes(b) + key = any(v).(K) + case uint64: + var v uint64 + v, b, err = ReadUint64Bytes(b) + key = any(v).(K) + case uint: + var v uint + v, b, err = ReadUintBytes(b) + key = any(v).(K) + case int8: + var v int8 + v, b, err = ReadInt8Bytes(b) + key = any(v).(K) + case int16: + var v int16 + v, b, err = ReadInt16Bytes(b) + key = any(v).(K) + case int32: + var v int32 + v, b, err = ReadInt32Bytes(b) + key = any(v).(K) + case int64: + var v int64 + v, b, err = ReadInt64Bytes(b) + key = any(v).(K) + case int: + var v int + v, b, err = ReadIntBytes(b) + key = any(v).(K) + case float32: + var v float32 + v, b, err = ReadFloat32Bytes(b) + key = any(v).(K) + case float64: + var v float64 + v, b, err = ReadFloat64Bytes(b) + key = any(v).(K) + default: + // Fallback for custom types implementing Unmarshaler + ptr := &key + if um, ok := any(ptr).(Unmarshaler); ok { + b, err = um.UnmarshalMsg(b) + } else { + err = fmt.Errorf("cannot unmarshal key into type %T", ptr) + } + } + return key, err + } + + readVal := func() (V, error) { + var val V + switch any(val).(type) { + case string: + var v string + v, b, err = ReadStringBytes(b) + val = any(v).(V) + case []byte: + var v []byte + v, b, err = ReadBytesZC(b) + val = any(v).(V) + case bool: + var v bool + v, b, err = ReadBoolBytes(b) + val = any(v).(V) + case time.Time: + var v time.Time + v, b, err = ReadTimeBytes(b) + val = any(v).(V) + case time.Duration: + var v time.Duration + v, b, err = ReadDurationBytes(b) + val = any(v).(V) + case complex64: + var v complex64 + v, b, err = ReadComplex64Bytes(b) + val = any(v).(V) + case complex128: + var v complex128 + v, b, err = ReadComplex128Bytes(b) + val = any(v).(V) + case uint8: + var v uint8 + v, b, err = ReadUint8Bytes(b) + val = any(v).(V) + case uint16: + var v uint16 + v, b, err = ReadUint16Bytes(b) + val = any(v).(V) + case uint32: + var v uint32 + v, b, err = ReadUint32Bytes(b) + val = any(v).(V) + case uint64: + var v uint64 + v, b, err = ReadUint64Bytes(b) + val = any(v).(V) + case uint: + var v uint + v, b, err = ReadUintBytes(b) + val = any(v).(V) + case int8: + var v int8 + v, b, err = ReadInt8Bytes(b) + val = any(v).(V) + case int16: + var v int16 + v, b, err = ReadInt16Bytes(b) + val = any(v).(V) + case int32: + var v int32 + v, b, err = ReadInt32Bytes(b) + val = any(v).(V) + case int64: + var v int64 + v, b, err = ReadInt64Bytes(b) + val = any(v).(V) + case int: + var v int + v, b, err = ReadIntBytes(b) + val = any(v).(V) + case float32: + var v float32 + v, b, err = ReadFloat32Bytes(b) + val = any(v).(V) + case float64: + var v float64 + v, b, err = ReadFloat64Bytes(b) + val = any(v).(V) + default: + // Fallback for custom types implementing Unmarshaler + ptr := &val + if um, ok := any(ptr).(Unmarshaler); ok { + b, err = um.UnmarshalMsg(b) + } else { + err = fmt.Errorf("cannot unmarshal value into type %T", ptr) + } + } + return val, err + } + + return func(yield func(K, V) bool) { + for sz > 0 { + k, e := readKey() + if e != nil { + err = e + return + } + v, e := readVal() + if e != nil { + err = e + return + } + if !yield(k, v) { + return + } + sz-- + } + }, func() ([]byte, error) { + return b, err + } +} diff --git a/msgp/iter_test.go b/msgp/iter_test.go new file mode 100644 index 00000000..15bbaaac --- /dev/null +++ b/msgp/iter_test.go @@ -0,0 +1,2923 @@ +//go:build go1.23 + +package msgp + +import ( + "bytes" + "math" + "testing" + "time" +) + +// collectSeq2 collects values from an iter.Seq2[V, error] into a slice. +// It stops at the first non-nil error and returns it together with the collected values. +func collectSeq2[V any](seq func(func(V, error) bool)) (vals []V, err error) { + seq(func(v V, e error) bool { + if e != nil { + err = e + return false + } + vals = append(vals, v) + return true + }) + return +} + +func TestReadNumberArray_Int(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + want := []int{1, -2, 3, 0, 42} + if err := w.WriteArrayHeader(uint32(len(want))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for _, v := range want { + if err := w.WriteInt(v); err != nil { + t.Fatalf("WriteInt: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int](r)) + if err != nil { + t.Fatalf("iteration error: %v", err) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d: got %v want %v", i, got[i], want[i]) + } + } +} + +func TestReadNumberArray_Float64(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + want := []float64{0, 1.5, -2.25, 1e9} + if err := w.WriteArrayHeader(uint32(len(want))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for _, v := range want { + if err := w.WriteFloat64(v); err != nil { + t.Fatalf("WriteFloat64: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[float64](r)) + if err != nil { + t.Fatalf("iteration error: %v", err) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d: got %v want %v", i, got[i], want[i]) + } + } +} + +func TestReadArray_String(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + want := []string{"", "a", "hello", "世界"} + if err := w.WriteArrayHeader(uint32(len(want))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for _, v := range want { + if err := w.WriteString(v); err != nil { + t.Fatalf("WriteString: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[string](r)) + if err != nil { + t.Fatalf("iteration error: %v", err) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d: got %q want %q", i, got[i], want[i]) + } + } +} + +func TestReadArray_Bool(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + want := []bool{true, false, true} + if err := w.WriteArrayHeader(uint32(len(want))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for _, v := range want { + if err := w.WriteBool(v); err != nil { + t.Fatalf("WriteBool: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[bool](r)) + if err != nil { + t.Fatalf("iteration error: %v", err) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d: got %v want %v", i, got[i], want[i]) + } + } +} + +// A decodable type to exercise the default branch of ReadArray. +type testDec struct { + A int + B string +} + +func (t *testDec) MarshalMsg(i []byte) ([]byte, error) { + i = AppendInt(i, t.A) + i = AppendString(i, t.B) + return i, nil +} + +func (t *testDec) UnmarshalMsg(i []byte) ([]byte, error) { + var err error + if t.A, i, err = ReadIntBytes(i); err != nil { + return nil, err + } + if t.B, i, err = ReadStringBytes(i); err != nil { + return nil, err + } + return i, nil +} + +func (t *testDec) EncodeMsg(w *Writer) error { + if err := w.WriteInt(t.A); err != nil { + return err + } + return w.WriteString(t.B) +} + +func (t *testDec) DecodeMsg(r *Reader) error { + var err error + if t.A, err = r.ReadInt(); err != nil { + return err + } + t.B, err = r.ReadString() + return err +} + +func TestReadArray_Decodable(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + want := []testDec{ + {A: 1, B: "x"}, + {A: -5, B: "yz"}, + } + if err := w.WriteArrayHeader(uint32(len(want))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for i := range want { + if err := (&want[i]).EncodeMsg(w); err != nil { + t.Fatalf("EncodeMsg: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[testDec](r)) + if err != nil { + t.Fatalf("iteration error: %v", err) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i].A != want[i].A || got[i].B != want[i].B { + t.Fatalf("index %d: got %+v want %+v", i, got[i], want[i]) + } + } +} + +func TestReadArray_TimeAndDuration(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + + now := time.Unix(1700000000, 123456789).UTC() + durs := []time.Duration{0, time.Second, -5 * time.Millisecond} + + // time.Time + if err := w.WriteArrayHeader(2); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + if err := w.WriteTime(now); err != nil { + t.Fatalf("WriteTime: %v", err) + } + if err := w.WriteTime(now.Add(time.Minute)); err != nil { + t.Fatalf("WriteTime: %v", err) + } + + // time.Duration + if err := w.WriteArrayHeader(uint32(len(durs))); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + for _, d := range durs { + if err := w.WriteDuration(d); err != nil { + t.Fatalf("WriteDuration: %v", err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + timesGot, err := collectSeq2(ReadArray[time.Time](r)) + if err != nil { + t.Fatalf("times iteration error: %v", err) + } + if len(timesGot) != 2 || !timesGot[0].Equal(now) || !timesGot[1].Equal(now.Add(time.Minute)) { + t.Fatalf("times mismatch: got %v", timesGot) + } + + dursGot, err := collectSeq2(ReadArray[time.Duration](r)) + if err != nil { + t.Fatalf("durations iteration error: %v", err) + } + if len(dursGot) != len(durs) { + t.Fatalf("length mismatch: got %d want %d", len(dursGot), len(durs)) + } + for i := range durs { + if dursGot[i] != durs[i] { + t.Fatalf("index %d: got %v want %v", i, dursGot[i], durs[i]) + } + } +} + +func TestReadNumberArrayBytes_Uint16(t *testing.T) { + var msg []byte + want := []uint16{0, 1, 255, 256, 65535} + + msg = AppendArrayHeader(msg, uint32(len(want))) + for _, v := range want { + msg = AppendUint16(msg, v) + } + + seq, tail := ReadNumberArrayBytes[uint16](msg) + var got []uint16 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("tail err: %v", err) + } + if len(remain) != 0 { + t.Fatalf("expected no remaining bytes, got %d", len(remain)) + } + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("index %d: got %v want %v", i, got[i], want[i]) + } + } +} + +func TestReadNumberArrayBytes_ErrOnTruncated(t *testing.T) { + var msg []byte + // 2 elements, but we will truncate the second one + msg = AppendArrayHeader(msg, 2) + msg = AppendInt32(msg, 123) + full := AppendInt32(msg, 456) + + // Truncate to cause an error when reading the second element. + trunc := full[:len(full)-2] + + seq, tail := ReadNumberArrayBytes[int32](trunc) + var got []int32 + for v := range seq { + got = append(got, v) + // The second element should fail before yielding + } + if len(got) != 1 || got[0] != 123 { + t.Fatalf("expected to read only first element (123), got %v", got) + } + remain, err := tail() + if err == nil { + t.Fatalf("expected an error from tail() on truncated input") + } + // remain can be partial bytes; ensure it's from the truncated buffer + if len(remain) != 0 { + // Not strictly required, but checks contract that tail returns the remaining unread slice. + _ = remain + } +} + +func TestReadArray_ErrorOnTooFewElements(t *testing.T) { + // Array header says 2, but only 1 element provided. + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(2); err != nil { + t.Fatalf("WriteArrayHeader: %v", err) + } + if err := w.WriteInt(7); err != nil { + t.Fatalf("WriteInt: %v", err) + } + if err := w.Flush(); err != nil { + t.Fatalf("Flush: %v", err) + } + + r := NewReader(&buf) + var got []int + var firstErr error + ReadNumberArray[int](r)(func(v int, err error) bool { + if err != nil { + firstErr = err + return false + } + got = append(got, v) + return true + }) + if firstErr == nil { + t.Fatalf("expected error due to missing second element, got nil") + } + if len(got) != 1 || got[0] != 7 { + t.Fatalf("unexpected values read before error: %v", got) + } +} + +// approxEqual checks approximate equality for float32/float64 +func approxEqual[T ~float32 | ~float64](a, b T) bool { + af := float64(a) + bf := float64(b) + const eps = 1e-6 + return math.Abs(af-bf) <= eps*(1+math.Max(math.Abs(af), math.Abs(bf))) +} + +func TestRoundtripNumberArray_AllTypes(t *testing.T) { + type testcase[V NumberTypes] struct { + name string + vals []V + write func(w *Writer, v V) error + } + now := time.Now() // not used here, avoid unused import warnings if refactoring + _ = now + + tests := []any{ + testcase[uint]{name: "uint", vals: []uint{0, 1, 255, 1 << 20, math.MaxUint32}, write: func(w *Writer, v uint) error { return w.WriteUint(v) }}, + testcase[uint8]{name: "uint8", vals: []uint8{0, 1, 127, 128, 255}, write: func(w *Writer, v uint8) error { return w.WriteUint8(v) }}, + testcase[uint16]{name: "uint16", vals: []uint16{0, 255, 256, 65535}, write: func(w *Writer, v uint16) error { return w.WriteUint16(v) }}, + testcase[uint32]{name: "uint32", vals: []uint32{0, 65535, 1 << 20, math.MaxUint32}, write: func(w *Writer, v uint32) error { return w.WriteUint32(v) }}, + testcase[uint64]{name: "uint64", vals: []uint64{0, 1, 1 << 40, math.MaxUint32 + 1, math.MaxUint64 >> 1}, write: func(w *Writer, v uint64) error { return w.WriteUint64(v) }}, + testcase[int]{name: "int", vals: []int{0, 1, -1, 1 << 20, -(1 << 20)}, write: func(w *Writer, v int) error { return w.WriteInt(v) }}, + testcase[int8]{name: "int8", vals: []int8{0, 1, -1, 127, -128}, write: func(w *Writer, v int8) error { return w.WriteInt8(v) }}, + testcase[int16]{name: "int16", vals: []int16{0, 1, -1, 32767, -32768}, write: func(w *Writer, v int16) error { return w.WriteInt16(v) }}, + testcase[int32]{name: "int32", vals: []int32{0, 1, -1, math.MaxInt32, math.MinInt32}, write: func(w *Writer, v int32) error { return w.WriteInt32(v) }}, + testcase[int64]{name: "int64", vals: []int64{0, 1, -1, math.MaxInt32 + 1, math.MinInt32 - 1}, write: func(w *Writer, v int64) error { return w.WriteInt64(v) }}, + testcase[float32]{name: "float32", vals: []float32{0, 1.5, -2.25, 3.14159, 1e20}, write: func(w *Writer, v float32) error { return w.WriteFloat32(v) }}, + testcase[float64]{name: "float64", vals: []float64{0, 1.5, -2.25, math.Pi, 1e308}, write: func(w *Writer, v float64) error { return w.WriteFloat64(v) }}, + } + + for _, anytc := range tests { + switch tc := anytc.(type) { + case testcase[uint]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[uint](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[uint8]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[uint8](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[uint16]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[uint16](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[uint32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[uint32](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[uint64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[uint64](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[int]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[int8]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int8](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[int16]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int16](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[int32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int32](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[int64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[int64](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if got[i] != tc.vals[i] { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[float32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[float32](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !approxEqual(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case testcase[float64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("WriteArrayHeader %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadNumberArray[float64](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !approxEqual(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + } + } +} + +func TestRoundtripArray_AllTypes(t *testing.T) { + type regCase[V ArrayExtraTypes] struct { + name string + vals []V + write func(*Writer, V) error + eq func(a, b V) bool + } + + now := time.Unix(1700000000, 123456789).UTC() + later := now.Add(5 * time.Minute) + + rcases := []any{ + regCase[bool]{name: "bool", vals: []bool{true, false, true}, write: func(w *Writer, v bool) error { return w.WriteBool(v) }, eq: func(a, b bool) bool { return a == b }}, + regCase[string]{name: "string", vals: []string{"", "a", "hello", "世界"}, write: func(w *Writer, v string) error { return w.WriteString(v) }, eq: func(a, b string) bool { return a == b }}, + regCase[[]byte]{name: "bytes", vals: [][]byte{nil, {}, {0x00}, {0x01, 0x02, 0x03}}, write: func(w *Writer, v []byte) error { return w.WriteBytes(v) }, eq: func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true + }}, + regCase[time.Time]{name: "time", vals: []time.Time{now, later}, write: func(w *Writer, v time.Time) error { return w.WriteTime(v) }, eq: func(a, b time.Time) bool { return a.Equal(b) }}, + regCase[time.Duration]{name: "duration", vals: []time.Duration{0, time.Second, -5 * time.Millisecond}, write: func(w *Writer, v time.Duration) error { return w.WriteDuration(v) }, eq: func(a, b time.Duration) bool { return a == b }}, + regCase[complex64]{name: "complex64", vals: []complex64{0, 1 + 2i, -3.5 + 4.25i}, write: func(w *Writer, v complex64) error { return w.WriteComplex64(v) }, eq: func(a, b complex64) bool { return a == b }}, + regCase[complex128]{name: "complex128", vals: []complex128{0, 1 + 2i, -3.5 + 4.25i}, write: func(w *Writer, v complex128) error { return w.WriteComplex128(v) }, eq: func(a, b complex128) bool { return a == b }}, + regCase[testDec]{name: "decoder", vals: []testDec{{A: 1, B: "abc"}, {A: 2, B: "def"}}, write: func(w *Writer, v testDec) error { return v.EncodeMsg(w) }, eq: func(a, b testDec) bool { return a == b }}, + } + + for _, anytc := range rcases { + switch tc := anytc.(type) { + case regCase[bool]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[bool](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[string]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[string](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %q want %q", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[[]byte]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[[]byte](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[time.Time]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[time.Time](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[time.Duration]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[time.Duration](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[complex64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[complex64](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[complex128]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[complex128](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case regCase[testDec]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.vals))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for _, v := range tc.vals { + if err := tc.write(w, v); err != nil { + t.Fatalf("%s write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + got, err := collectSeq2(ReadArray[testDec](r)) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), len(tc.vals)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + + } + } +} + +func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { + type tb[V NumberTypes] struct { + name string + vals []V + append func([]byte, V) []byte + eq func(a, b V) bool + } + + tests := []any{ + tb[uint]{name: "uint", vals: []uint{0, 1, 255, 1 << 20}, append: func(b []byte, v uint) []byte { return AppendUint(b, v) }, eq: func(a, b uint) bool { return a == b }}, + tb[uint8]{name: "uint8", vals: []uint8{0, 1, 127, 128, 255}, append: func(b []byte, v uint8) []byte { return AppendUint8(b, v) }, eq: func(a, b uint8) bool { return a == b }}, + tb[uint16]{name: "uint16", vals: []uint16{0, 255, 256, 65535}, append: func(b []byte, v uint16) []byte { return AppendUint16(b, v) }, eq: func(a, b uint16) bool { return a == b }}, + tb[uint32]{name: "uint32", vals: []uint32{0, 65535, 1 << 20}, append: func(b []byte, v uint32) []byte { return AppendUint32(b, v) }, eq: func(a, b uint32) bool { return a == b }}, + tb[uint64]{name: "uint64", vals: []uint64{0, 1, 1 << 40}, append: func(b []byte, v uint64) []byte { return AppendUint64(b, v) }, eq: func(a, b uint64) bool { return a == b }}, + tb[int]{name: "int", vals: []int{0, 1, -1, 1 << 20, -(1 << 20)}, append: func(b []byte, v int) []byte { return AppendInt(b, v) }, eq: func(a, b int) bool { return a == b }}, + tb[int8]{name: "int8", vals: []int8{0, 1, -1, 127, -128}, append: func(b []byte, v int8) []byte { return AppendInt8(b, v) }, eq: func(a, b int8) bool { return a == b }}, + tb[int16]{name: "int16", vals: []int16{0, 1, -1, 32767, -32768}, append: func(b []byte, v int16) []byte { return AppendInt16(b, v) }, eq: func(a, b int16) bool { return a == b }}, + tb[int32]{name: "int32", vals: []int32{0, 1, -1, math.MaxInt32, math.MinInt32}, append: func(b []byte, v int32) []byte { return AppendInt32(b, v) }, eq: func(a, b int32) bool { return a == b }}, + tb[int64]{name: "int64", vals: []int64{0, 1, -1, math.MaxInt32 + 1, math.MinInt32 - 1}, append: func(b []byte, v int64) []byte { return AppendInt64(b, v) }, eq: func(a, b int64) bool { return a == b }}, + tb[float32]{name: "float32", vals: []float32{0, 1.5, -2.25, 3.14159, 1e20}, append: func(b []byte, v float32) []byte { return AppendFloat32(b, v) }, eq: func(a, b float32) bool { return approxEqual(a, b) }}, + tb[float64]{name: "float64", vals: []float64{0, 1.5, -2.25, math.Pi, 1e308}, append: func(b []byte, v float64) []byte { return AppendFloat(b, v) }, eq: func(a, b float64) bool { return approxEqual(a, b) }}, + } + + for _, anytc := range tests { + switch tc := anytc.(type) { + case tb[uint]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[uint](msg) + var got []uint + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[uint8]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[uint8](msg) + var got []uint8 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[uint16]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[uint16](msg) + var got []uint16 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[uint32]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[uint32](msg) + var got []uint32 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[uint64]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[uint64](msg) + var got []uint64 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[int]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[int](msg) + var got []int + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[int8]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[int8](msg) + var got []int8 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[int16]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[int16](msg) + var got []int16 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[int32]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[int32](msg) + var got []int32 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[int64]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[int64](msg) + var got []int64 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[float32]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[float32](msg) + var got []float32 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case tb[float64]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadNumberArrayBytes[float64](msg) + var got []float64 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len: %d", tc.name, len(got)) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + } + } +} + +func TestRoundtripArrayBytes_AllTypes(t *testing.T) { + type rb[V ArrayExtraTypes] struct { + name string + vals []V + append func([]byte, V) []byte + eq func(a, b V) bool + } + + now := time.Unix(1700000000, 123456789).UTC() + later := now.Add(7 * time.Second) + + rtests := []any{ + rb[bool]{name: "bool", vals: []bool{true, false, true}, append: func(b []byte, v bool) []byte { return AppendBool(b, v) }, eq: func(a, b bool) bool { return a == b }}, + rb[string]{name: "string", vals: []string{"", "hi", "世界"}, append: func(b []byte, v string) []byte { return AppendString(b, v) }, eq: func(a, b string) bool { return a == b }}, + rb[[]byte]{name: "bytes", vals: [][]byte{nil, {}, {0x00}, {0x01, 0x02}}, append: func(b []byte, v []byte) []byte { return AppendBytes(b, v) }, eq: func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true + }}, + rb[time.Time]{name: "time", vals: []time.Time{now, later}, append: func(b []byte, v time.Time) []byte { return AppendTime(b, v) }, eq: func(a, b time.Time) bool { return a.Equal(b) }}, + rb[time.Duration]{name: "duration", vals: []time.Duration{0, time.Second, -3 * time.Millisecond}, append: func(b []byte, v time.Duration) []byte { return AppendDuration(b, v) }, eq: func(a, b time.Duration) bool { return a == b }}, + rb[complex64]{name: "complex64", vals: []complex64{0, 1 + 2i, -3.5 + 4.25i}, append: func(b []byte, v complex64) []byte { return AppendComplex64(b, v) }, eq: func(a, b complex64) bool { return a == b }}, + rb[complex128]{name: "complex128", vals: []complex128{0, 1 + 2i, -3.5 + 4.25i}, append: func(b []byte, v complex128) []byte { return AppendComplex128(b, v) }, eq: func(a, b complex128) bool { return a == b }}, + rb[testDec]{name: "unmarshal", vals: []testDec{{A: 1, B: "abc"}, {A: 2, B: "def"}}, append: func(b []byte, v testDec) []byte { b, _ = v.MarshalMsg(b); return b }, eq: func(a, b testDec) bool { return a == b }}, + } + + for _, anytc := range rtests { + switch tc := anytc.(type) { + case rb[bool]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[bool](msg) + var got []bool + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[string]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[string](msg) + var got []string + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %q want %q", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[[]byte]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[[]byte](msg) + var got [][]byte + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[time.Time]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[time.Time](msg) + var got []time.Time + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[time.Duration]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[time.Duration](msg) + var got []time.Duration + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[complex64]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[complex64](msg) + var got []complex64 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[complex128]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[complex128](msg) + var got []complex128 + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + case rb[testDec]: + msg := AppendArrayHeader(nil, uint32(len(tc.vals))) + for _, v := range tc.vals { + msg = tc.append(msg, v) + } + seq, tail := ReadArrayBytes[testDec](msg) + var got []testDec + for v := range seq { + got = append(got, v) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + if len(got) != len(tc.vals) { + t.Fatalf("%s len mismatch", tc.name) + } + for i := range got { + if !tc.eq(got[i], tc.vals[i]) { + t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) + } + } + + } + } +} + +func eqNum[T comparable](a, b T) bool { return a == b } + +func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { + type numCase[T NumberTypes] struct { + name string + keys []T + vals []T + write func(*Writer, T) error + eq func(a, b T) bool + } + // Equality helpers + eqF32 := func(a, b float32) bool { return approxEqual(a, b) } + eqF64 := func(a, b float64) bool { return approxEqual(a, b) } + + cases := []any{ + numCase[uint]{name: "uint", keys: []uint{1, 2, 3}, vals: []uint{10, 20, 30}, write: func(w *Writer, v uint) error { return w.WriteUint(v) }, eq: eqNum[uint]}, + numCase[uint8]{name: "uint8", keys: []uint8{1, 2, 255}, vals: []uint8{9, 8, 7}, write: func(w *Writer, v uint8) error { return w.WriteUint8(v) }, eq: eqNum[uint8]}, + numCase[uint16]{name: "uint16", keys: []uint16{1, 300, 65535}, vals: []uint16{100, 200, 300}, write: func(w *Writer, v uint16) error { return w.WriteUint16(v) }, eq: eqNum[uint16]}, + numCase[uint32]{name: "uint32", keys: []uint32{1, 1 << 20, 4000000000}, vals: []uint32{42, 43, 44}, write: func(w *Writer, v uint32) error { return w.WriteUint32(v) }, eq: eqNum[uint32]}, + numCase[uint64]{name: "uint64", keys: []uint64{1, 1 << 40, 1<<50 + 123}, vals: []uint64{5, 6, 7}, write: func(w *Writer, v uint64) error { return w.WriteUint64(v) }, eq: eqNum[uint64]}, + numCase[int]{name: "int", keys: []int{-1, 0, 2}, vals: []int{100, -100, 0}, write: func(w *Writer, v int) error { return w.WriteInt(v) }, eq: eqNum[int]}, + numCase[int8]{name: "int8", keys: []int8{-128, 0, 127}, vals: []int8{1, -1, 0}, write: func(w *Writer, v int8) error { return w.WriteInt8(v) }, eq: eqNum[int8]}, + numCase[int16]{name: "int16", keys: []int16{-32768, 0, 32767}, vals: []int16{2, -2, 3}, write: func(w *Writer, v int16) error { return w.WriteInt16(v) }, eq: eqNum[int16]}, + numCase[int32]{name: "int32", keys: []int32{-1, 0, 1}, vals: []int32{7, 8, 9}, write: func(w *Writer, v int32) error { return w.WriteInt32(v) }, eq: eqNum[int32]}, + numCase[int64]{name: "int64", keys: []int64{-1, 0, 1<<40 + 5}, vals: []int64{9, 8, 7}, write: func(w *Writer, v int64) error { return w.WriteInt64(v) }, eq: eqNum[int64]}, + numCase[float32]{name: "float32", keys: []float32{-2.5, 0, 3.25}, vals: []float32{1.5, -0.25, 10}, write: func(w *Writer, v float32) error { return w.WriteFloat32(v) }, eq: eqF32}, + numCase[float64]{name: "float64", keys: []float64{-2.5, 0, 3.25}, vals: []float64{1.5, -0.25, 10}, write: func(w *Writer, v float64) error { return w.WriteFloat64(v) }, eq: eqF64}, + } + + for _, anytc := range cases { + switch tc := anytc.(type) { + case numCase[uint]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[uint, uint](r) + got := make(map[uint]uint, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint8]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[uint8, uint8](r) + got := make(map[uint8]uint8, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint16]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[uint16, uint16](r) + got := make(map[uint16]uint16, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[uint32, uint32](r) + got := make(map[uint32]uint32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[uint64, uint64](r) + got := make(map[uint64]uint64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[int, int](r) + got := make(map[int]int, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int8]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[int8, int8](r) + got := make(map[int8]int8, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int16]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[int16, int16](r) + got := make(map[int16]int16, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[int32, int32](r) + got := make(map[int32]int32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[int64, int64](r) + got := make(map[int64]int64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[float32]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[float32, float32](r) + got := make(map[float32]float32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[float64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[float64, float64](r) + got := make(map[float64]float64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size: got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + } + } +} + +func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { + type regCase[T any] struct { + name string + keys []T + vals []T + write func(*Writer, T) error + eq func(a, b T) bool + } + eqBool := func(a, b bool) bool { return a == b } + eqStr := func(a, b string) bool { return a == b } + eqBytes := func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true + } + eqTime := func(a, b time.Time) bool { return a.Equal(b) } + eqDur := func(a, b time.Duration) bool { return a == b } + eqC64 := func(a, b complex64) bool { return a == b } + eqC128 := func(a, b complex128) bool { return a == b } + + now := time.Unix(1700000000, 123456789).UTC() + later := now.Add(123 * time.Second) + + cases := []any{ + regCase[bool]{name: "bool", keys: []bool{false, true}, vals: []bool{false, true, false}, write: func(w *Writer, v bool) error { return w.WriteBool(v) }, eq: eqBool}, + regCase[string]{name: "string", keys: []string{"a", "b", "c"}, vals: []string{"x", "y", "z"}, write: func(w *Writer, v string) error { return w.WriteString(v) }, eq: eqStr}, + // Note: []byte keys are not comparable in Go; we validate by pair matching instead of using a map. + regCase[[]byte]{name: "bytes", keys: [][]byte{{0x01}, {0x02, 0x03}, nil}, vals: [][]byte{{0x09}, {}, {0xFF}}, write: func(w *Writer, v []byte) error { return w.WriteBytes(v) }, eq: eqBytes}, + regCase[time.Time]{name: "time", keys: []time.Time{now, later}, vals: []time.Time{later, now}, write: func(w *Writer, v time.Time) error { return w.WriteTime(v) }, eq: eqTime}, + regCase[time.Duration]{name: "duration", keys: []time.Duration{0, time.Second, -5 * time.Millisecond}, vals: []time.Duration{time.Minute, -time.Second, 0}, write: func(w *Writer, v time.Duration) error { return w.WriteDuration(v) }, eq: eqDur}, + regCase[complex64]{name: "complex64", keys: []complex64{1 + 2i, -3 + 4.5i, 0}, vals: []complex64{9 - 2i, 3 - 4.5i, 7}, write: func(w *Writer, v complex64) error { return w.WriteComplex64(v) }, eq: eqC64}, + regCase[complex128]{name: "complex128", keys: []complex128{1 + 2i, -3 + 4.5i, 0}, vals: []complex128{9 - 2i, 3 - 4.5i, 7}, write: func(w *Writer, v complex128) error { return w.WriteComplex128(v) }, eq: eqC128}, + } + + for _, anytc := range cases { + switch tc := anytc.(type) { + case regCase[bool]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[bool, bool](r) + got := make(map[bool]bool, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[string]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[string, string](r) + got := make(map[string]string, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %q got %q want %q", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[[]byte]: + // For []byte keys (not comparable), validate by pair presence. + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + + type pair struct{ k, v []byte } + expected := make([]pair, len(tc.keys)) + for i := range tc.keys { + expected[i] = pair{append([]byte(nil), tc.keys[i]...), append([]byte(nil), tc.vals[i]...)} + } + + r := NewReader(&buf) + seq, tail := ReadMap[[]byte, []byte](r) + var got []pair + for k, v := range seq { + kk := append([]byte(nil), k...) + vv := append([]byte(nil), v...) + got = append(got, pair{kk, vv}) + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(expected) { + t.Fatalf("%s size mismatch got %d want %d", tc.name, len(got), len(expected)) + } + + // Match expected pairs + match := func(p pair, set []pair) bool { + for _, q := range set { + if eqBytes(p.k, q.k) && eqBytes(p.v, q.v) { + return true + } + } + return false + } + for _, p := range expected { + if !match(p, got) { + t.Fatalf("%s missing pair key=%v val=%v", tc.name, p.k, p.v) + } + } + case regCase[time.Time]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[time.Time, time.Time](r) + got := make(map[time.Time]time.Time, len(tc.keys)) + for k, v := range seq { + got[k.UTC()] = v.UTC() + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[time.Duration]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[time.Duration, time.Duration](r) + got := make(map[time.Duration]time.Duration, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Logf("got: %#v", got) + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[complex64]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[complex64, complex64](r) + got := make(map[complex64]complex64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[complex128]: + var buf bytes.Buffer + w := NewWriter(&buf) + if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + t.Fatalf("hdr %s: %v", tc.name, err) + } + for i := range tc.keys { + if err := tc.write(w, tc.keys[i]); err != nil { + t.Fatalf("%s key write: %v", tc.name, err) + } + if err := tc.write(w, tc.vals[i]); err != nil { + t.Fatalf("%s val write: %v", tc.name, err) + } + } + if err := w.Flush(); err != nil { + t.Fatalf("flush: %v", err) + } + r := NewReader(&buf) + seq, tail := ReadMap[complex128, complex128](r) + got := make(map[complex128]complex128, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + } + } +} + +func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { + type numCase[T NumberTypes] struct { + name string + keys []T + vals []T + append func([]byte, T) []byte + eq func(a, b T) bool + } + + eqF32 := func(a, b float32) bool { return approxEqual(a, b) } + eqF64 := func(a, b float64) bool { return approxEqual(a, b) } + + cases := []any{ + numCase[uint]{name: "uint", keys: []uint{1, 2, 3}, vals: []uint{10, 20, 30}, append: AppendUint, eq: eqNum[uint]}, + numCase[uint8]{name: "uint8", keys: []uint8{1, 2, 255}, vals: []uint8{9, 8, 7}, append: AppendUint8, eq: eqNum[uint8]}, + numCase[uint16]{name: "uint16", keys: []uint16{1, 300, 65535}, vals: []uint16{100, 200, 300}, append: AppendUint16, eq: eqNum[uint16]}, + numCase[uint32]{name: "uint32", keys: []uint32{1, 1 << 20, 4000000000}, vals: []uint32{42, 43, 44}, append: AppendUint32, eq: eqNum[uint32]}, + numCase[uint64]{name: "uint64", keys: []uint64{1, 1 << 40, 1<<50 + 123}, vals: []uint64{5, 6, 7}, append: AppendUint64, eq: eqNum[uint64]}, + numCase[int]{name: "int", keys: []int{-1, 0, 2}, vals: []int{100, -100, 0}, append: AppendInt, eq: eqNum[int]}, + numCase[int8]{name: "int8", keys: []int8{-128, 0, 127}, vals: []int8{1, -1, 0}, append: AppendInt8, eq: eqNum[int8]}, + numCase[int16]{name: "int16", keys: []int16{-32768, 0, 32767}, vals: []int16{2, -2, 3}, append: AppendInt16, eq: eqNum[int16]}, + numCase[int32]{name: "int32", keys: []int32{-1, 0, 1}, vals: []int32{7, 8, 9}, append: AppendInt32, eq: eqNum[int32]}, + numCase[int64]{name: "int64", keys: []int64{-1, 0, 1<<40 + 5}, vals: []int64{9, 8, 7}, append: AppendInt64, eq: eqNum[int64]}, + numCase[float32]{name: "float32", keys: []float32{-2.5, 0, 3.25}, vals: []float32{1.5, -0.25, 10}, append: AppendFloat32, eq: eqF32}, + numCase[float64]{name: "float64", keys: []float64{-2.5, 0, 3.25}, vals: []float64{1.5, -0.25, 10}, append: AppendFloat, eq: eqF64}, + } + + for _, anytc := range cases { + switch tc := anytc.(type) { + case numCase[uint]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[uint, uint](msg) + got := make(map[uint]uint, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint8]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[uint8, uint8](msg) + got := make(map[uint8]uint8, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint16]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[uint16, uint16](msg) + got := make(map[uint16]uint16, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint32]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[uint32, uint32](msg) + got := make(map[uint32]uint32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[uint64]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[uint64, uint64](msg) + got := make(map[uint64]uint64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[int, int](msg) + got := make(map[int]int, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int8]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[int8, int8](msg) + got := make(map[int8]int8, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int16]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[int16, int16](msg) + got := make(map[int16]int16, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int32]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[int32, int32](msg) + got := make(map[int32]int32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[int64]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[int64, int64](msg) + got := make(map[int64]int64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[float32]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[float32, float32](msg) + got := make(map[float32]float32, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case numCase[float64]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[float64, float64](msg) + got := make(map[float64]float64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size got %d want %d", tc.name, len(got), len(tc.keys)) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + } + } +} + +func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { + type regCase[T any] struct { + name string + keys []T + vals []T + append func([]byte, T) []byte + eq func(a, b T) bool + } + + eqBool := func(a, b bool) bool { return a == b } + eqStr := func(a, b string) bool { return a == b } + eqBytes := func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true + } + eqTime := func(a, b time.Time) bool { return a.Equal(b) } + eqDur := func(a, b time.Duration) bool { return a == b } + eqC64 := func(a, b complex64) bool { return a == b } + eqC128 := func(a, b complex128) bool { return a == b } + + now := time.Unix(1700000000, 123456789).UTC() + later := now.Add(99 * time.Second) + + cases := []any{ + regCase[bool]{name: "bool", keys: []bool{false, true}, vals: []bool{false, true}, append: AppendBool, eq: eqBool}, + regCase[string]{name: "string", keys: []string{"a", "b", "c"}, vals: []string{"x", "y", "z"}, append: AppendString, eq: eqStr}, + regCase[[]byte]{name: "bytes", keys: [][]byte{{0x01}, {0x02, 0x03}, nil}, vals: [][]byte{{0x09}, {}, {0xFF}}, append: AppendBytes, eq: eqBytes}, + regCase[time.Time]{name: "time", keys: []time.Time{now, later}, vals: []time.Time{later, now}, append: AppendTime, eq: eqTime}, + regCase[time.Duration]{name: "duration", keys: []time.Duration{0, time.Second, -5 * time.Millisecond}, vals: []time.Duration{time.Minute, -time.Second, 0}, append: AppendDuration, eq: eqDur}, + regCase[complex64]{name: "complex64", keys: []complex64{1 + 2i, -3 + 4.5i, 0}, vals: []complex64{9 - 2i, 3 - 4.5i, 7}, append: AppendComplex64, eq: eqC64}, + regCase[complex128]{name: "complex128", keys: []complex128{1 + 2i, -3 + 4.5i, 0}, vals: []complex128{9 - 2i, 3 - 4.5i, 7}, append: AppendComplex128, eq: eqC128}, + } + + for _, anytc := range cases { + switch tc := anytc.(type) { + case regCase[bool]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[bool, bool](msg) + got := make(map[bool]bool, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[string]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[string, string](msg) + got := make(map[string]string, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %q got %q want %q", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[[]byte]: + // For []byte keys (not comparable), validate by pair presence. + type pair struct{ k, v []byte } + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, append([]byte(nil), tc.keys[i]...)) + msg = tc.append(msg, append([]byte(nil), tc.vals[i]...)) + } + expected := make([]pair, len(tc.keys)) + for i := range tc.keys { + expected[i] = pair{append([]byte(nil), tc.keys[i]...), append([]byte(nil), tc.vals[i]...)} + } + + seq, tail := ReadMapBytes[[]byte, []byte](msg) + var got []pair + for k, v := range seq { + kk := append([]byte(nil), k...) + vv := append([]byte(nil), v...) + got = append(got, pair{kk, vv}) + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(expected) { + t.Fatalf("%s count got %d want %d", tc.name, len(got), len(expected)) + } + match := func(p pair, set []pair) bool { + for _, q := range set { + if eqBytes(p.k, q.k) && eqBytes(p.v, q.v) { + return true + } + } + return false + } + for _, p := range expected { + if !match(p, got) { + t.Fatalf("%s missing pair key=%v val=%v", tc.name, p.k, p.v) + } + } + case regCase[time.Time]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[time.Time, time.Time](msg) + got := make(map[time.Time]time.Time, len(tc.keys)) + for k, v := range seq { + got[k.UTC()] = v.UTC() + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[time.Duration]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[time.Duration, time.Duration](msg) + got := make(map[time.Duration]time.Duration, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[complex64]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[complex64, complex64](msg) + got := make(map[complex64]complex64, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + case regCase[complex128]: + msg := AppendMapHeader(nil, uint32(len(tc.keys))) + for i := range tc.keys { + msg = tc.append(msg, tc.keys[i]) + msg = tc.append(msg, tc.vals[i]) + } + seq, tail := ReadMapBytes[complex128, complex128](msg) + got := make(map[complex128]complex128, len(tc.keys)) + for k, v := range seq { + got[k] = v + } + remain, err := tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + if len(got) != len(tc.keys) { + t.Fatalf("%s size mismatch", tc.name) + } + for i := range tc.keys { + if !tc.eq(got[tc.keys[i]], tc.vals[i]) { + t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) + } + } + } + } +} + +func TestReadMapBytes_TailErrorOnTruncated(t *testing.T) { + // Prepare a map with 2 entries: {1: 10, 2: 20} + // Then truncate after encoding the second key to induce an error while reading the second value. + msg := AppendMapHeader(nil, 2) + msg = AppendInt(msg, 1) + msg = AppendInt(msg, 10) + msg = AppendInt(msg, 2) + full := AppendInt(msg, 20) + trunc := full[:len(full)-2] // truncate some bytes from the last value + + seq, tail := ReadMapBytes[int, int](trunc) + got := make(map[int]int) + for k, v := range seq { + got[k] = v + } + // We expect only the first pair to be read + if len(got) != 1 || got[1] != 10 { + t.Fatalf("expected only first pair (1:10), got %v", got) + } + remain, err := tail() + if err == nil { + t.Fatalf("expected tail error due to truncation") + } + _ = remain // remaining bytes content is not strictly asserted here +} diff --git a/msgp/read.go b/msgp/read.go index 409dbcec..280eb1aa 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -582,7 +582,7 @@ func (m *Reader) ReadFloat64() (f float64, err error) { var p []byte p, err = m.R.Peek(9) if err != nil { - // we'll allow a coversion from float32 to float64, + // we'll allow a conversion from float32 to float64, // since we don't lose any precision if err == io.EOF && len(p) > 0 && p[0] == mfloat32 { ef, err := m.ReadFloat32() From 53c063982618165c77facb764c88dc533a7f124d Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Aug 2025 12:38:20 +0200 Subject: [PATCH 2/7] Build guard from Go 1.22 - softer than removing support --- msgp/iter.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/msgp/iter.go b/msgp/iter.go index ff5f3884..1de78651 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -1,3 +1,5 @@ +//go:build go1.23 + package msgp import ( From 48987e9ffd3dd98c60e821773e8bde24a116bfb2 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 27 Aug 2025 12:45:49 +0200 Subject: [PATCH 3/7] linter feedback --- msgp/iter.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/msgp/iter.go b/msgp/iter.go index 1de78651..2d6c2242 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -80,6 +80,7 @@ func ReadArray[V ArrayExtraTypes](m *Reader) iter.Seq2[V, error] { for range length { v, err := m.ReadComplex128() if !yield(any(v).(V), err) { + return } } default: @@ -92,7 +93,7 @@ func ReadArray[V ArrayExtraTypes](m *Reader) iter.Seq2[V, error] { return } } else { - err = fmt.Errorf("cannot decode into type %T", ptr) + yield(v, fmt.Errorf("cannot decode into type %T", ptr)) return } } @@ -226,6 +227,7 @@ func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() } key = (any)(v).(K) default: + _ = v ptr := &key if dc, ok := any(ptr).(Decodable); ok { if err = dc.DecodeMsg(m); err != nil { @@ -336,6 +338,7 @@ func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() val = (any)(v).(V) default: + _ = v ptr := &val if dc, ok := any(ptr).(Decodable); ok { if err = dc.DecodeMsg(m); err != nil { From b6eca0cb3b8312093f552d106e85e5d05ef230f8 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Thu, 28 Aug 2025 19:42:59 +0200 Subject: [PATCH 4/7] Allow nil maps and arrays. Refactor to make all type switching happen once. Add nil tests. --- msgp/iter.go | 743 +++++++++++++++++++++++++++------------------- msgp/iter_test.go | 121 ++++++-- 2 files changed, 536 insertions(+), 328 deletions(-) diff --git a/msgp/iter.go b/msgp/iter.go index 2d6c2242..9cbb6015 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -26,11 +26,16 @@ type ArrayExtraTypes interface { // allowing for sequential access to the array elements. func ReadArray[V ArrayExtraTypes](m *Reader) iter.Seq2[V, error] { return func(yield func(V, error) bool) { - // Assuming Reader has a method to read array length + // Check if nil + if m.IsNil() { + m.ReadNil() + return + } + // Regular array. var x V length, err := m.ReadArrayHeader() if err != nil { - yield(x, err) + yield(x, fmt.Errorf("cannot read array header: %w", err)) return } switch any(x).(type) { @@ -122,234 +127,256 @@ type MapValueTypes interface { func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() error) { var err error return func(yield func(K, V) bool) { - // Assuming Reader has a method to read array length - var length uint32 - length, err = m.ReadArrayHeader() + var sz uint32 + if m.IsNil() { + err = m.ReadNil() + return + } + sz, err = m.ReadMapHeader() if err != nil { + err = fmt.Errorf("cannot read map header: %w", err) return } - for range length { - var key K - switch v := any(key).(type) { + + // Prepare per-type readers once, avoid switching for each element. + // Key reader. + var readKey func() (K, error) + { + var keyZero K + switch any(keyZero).(type) { case string: - if v, err = m.ReadString(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadString() + return any(v).(K), e } - key = (any)(v).(K) case []byte: - if v, err = m.ReadBytes(nil); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadBytes(nil) + return any(v).(K), e } - key = (any)(v).(K) case bool: - if v, err = m.ReadBool(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadBool() + return any(v).(K), e } - key = (any)(v).(K) case time.Time: - if v, err = m.ReadTime(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadTime() + return any(v).(K), e } - key = (any)(v).(K) case time.Duration: - if v, err = m.ReadDuration(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadDuration() + return any(v).(K), e } - key = (any)(v).(K) case complex64: - if v, err = m.ReadComplex64(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadComplex64() + return any(v).(K), e } - key = (any)(v).(K) case complex128: - if v, err = m.ReadComplex128(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadComplex128() + return any(v).(K), e } - key = (any)(v).(K) case uint8: - if v, err = m.ReadUint8(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadUint8() + return any(v).(K), e } - key = (any)(v).(K) case uint16: - if v, err = m.ReadUint16(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadUint16() + return any(v).(K), e } - key = (any)(v).(K) case uint32: - if v, err = m.ReadUint32(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadUint32() + return any(v).(K), e } - key = (any)(v).(K) case uint64: - if v, err = m.ReadUint64(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadUint64() + return any(v).(K), e } - key = (any)(v).(K) case uint: - if v, err = m.ReadUint(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadUint() + return any(v).(K), e } - key = (any)(v).(K) case int8: - if v, err = m.ReadInt8(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadInt8() + return any(v).(K), e } - key = (any)(v).(K) case int16: - if v, err = m.ReadInt16(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadInt16() + return any(v).(K), e } - key = (any)(v).(K) case int32: - if v, err = m.ReadInt32(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadInt32() + return any(v).(K), e } - key = (any)(v).(K) case int64: - if v, err = m.ReadInt64(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadInt64() + return any(v).(K), e } - key = (any)(v).(K) case int: - if v, err = m.ReadInt(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadInt() + return any(v).(K), e } - key = (any)(v).(K) case float32: - if v, err = m.ReadFloat32(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadFloat32() + return any(v).(K), e } - key = (any)(v).(K) case float64: - if v, err = m.ReadFloat64(); err != nil { - return + readKey = func() (K, error) { + v, e := m.ReadFloat64() + return any(v).(K), e } - key = (any)(v).(K) default: - _ = v - ptr := &key - if dc, ok := any(ptr).(Decodable); ok { - if err = dc.DecodeMsg(m); err != nil { - return + readKey = func() (K, error) { + var k K + ptr := &k + if dc, ok := any(ptr).(Decodable); ok { + return k, dc.DecodeMsg(m) } - } else { - err = fmt.Errorf("cannot decode key into type %T", ptr) - return + return k, fmt.Errorf("cannot decode key into type %T", ptr) } } + } - var val V - switch v := any(key).(type) { + // Value reader. + var readVal func() (V, error) + { + var valZero V + switch any(valZero).(type) { case string: - if v, err = m.ReadString(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadString() + return any(v).(V), e } - val = (any)(v).(V) case []byte: - if v, err = m.ReadBytes(nil); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadBytes(nil) + return any(v).(V), e } - val = (any)(v).(V) case bool: - if v, err = m.ReadBool(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadBool() + return any(v).(V), e } - val = (any)(v).(V) case time.Time: - if v, err = m.ReadTime(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadTime() + return any(v).(V), e } - val = (any)(v).(V) case time.Duration: - if v, err = m.ReadDuration(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadDuration() + return any(v).(V), e } - val = (any)(v).(V) case complex64: - if v, err = m.ReadComplex64(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadComplex64() + return any(v).(V), e } - val = (any)(v).(V) case complex128: - if v, err = m.ReadComplex128(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadComplex128() + return any(v).(V), e } - val = (any)(v).(V) case int8: - if v, err = m.ReadInt8(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadInt8() + return any(v).(V), e } - val = (any)(v).(V) case int16: - if v, err = m.ReadInt16(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadInt16() + return any(v).(V), e } - val = (any)(v).(V) case int32: - if v, err = m.ReadInt32(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadInt32() + return any(v).(V), e } - val = (any)(v).(V) case int64: - if v, err = m.ReadInt64(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadInt64() + return any(v).(V), e } - val = (any)(v).(V) case int: - if v, err = m.ReadInt(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadInt() + return any(v).(V), e } - val = (any)(v).(V) case float32: - if v, err = m.ReadFloat32(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadFloat32() + return any(v).(V), e } - val = (any)(v).(V) case float64: - if v, err = m.ReadFloat64(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadFloat64() + return any(v).(V), e } - val = (any)(v).(V) case uint8: - if v, err = m.ReadUint8(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadUint8() + return any(v).(V), e } - val = (any)(v).(V) case uint16: - if v, err = m.ReadUint16(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadUint16() + return any(v).(V), e } - val = (any)(v).(V) case uint32: - if v, err = m.ReadUint32(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadUint32() + return any(v).(V), e } - val = (any)(v).(V) case uint64: - if v, err = m.ReadUint64(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadUint64() + return any(v).(V), e } - val = (any)(v).(V) case uint: - if v, err = m.ReadUint(); err != nil { - return + readVal = func() (V, error) { + v, e := m.ReadUint() + return any(v).(V), e } - val = (any)(v).(V) - default: - _ = v - ptr := &val - if dc, ok := any(ptr).(Decodable); ok { - if err = dc.DecodeMsg(m); err != nil { - return + readVal = func() (V, error) { + var v V + ptr := &v + if dc, ok := any(ptr).(Decodable); ok { + return v, dc.DecodeMsg(m) } - } else { - err = fmt.Errorf("cannot decode value into type %T", ptr) - return + return v, fmt.Errorf("cannot decode value into type %T", ptr) } } - if !yield(key, val) { + } + + for range sz { + var k K + k, err = readKey() + if err != nil { + err = fmt.Errorf("cannot read key: %w", err) + return + } + var v V + v, err = readVal() + if err != nil { + err = fmt.Errorf("cannot read value: %w", err) + return + } + if !yield(k, v) { return } } @@ -368,11 +395,14 @@ type NumberTypes interface { // allowing for sequential access to the array elements. func ReadNumberArray[V NumberTypes](m *Reader) iter.Seq2[V, error] { return func(yield func(V, error) bool) { - // Assuming Reader has a method to read array length + if m.IsNil() { + m.ReadNil() + return + } var x V length, err := m.ReadArrayHeader() if err != nil { - yield(x, err) + yield(x, fmt.Errorf("cannot read array header: %w", err)) return } @@ -475,9 +505,14 @@ func ReadNumberArray[V NumberTypes](m *Reader) iter.Seq2[V, error] { // After the iterator is exhausted, the remaining bytes in the buffer // and any error can be read by calling the returned function. func ReadNumberArrayBytes[V NumberTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { + if IsNil(b) { + b, err := ReadNilBytes(b) + return func(yield func(V) bool) {}, func() ([]byte, error) { return b, err } + } + // Regular array. sz, b, err := ReadArrayHeaderBytes(b) if err != nil { - return nil, func() ([]byte, error) { return b, err } + return func(yield func(V) bool) {}, func() ([]byte, error) { return b, fmt.Errorf("cannot read array header: %w", err) } } var readValue func() (V, error) @@ -581,6 +616,10 @@ func ReadNumberArrayBytes[V NumberTypes](b []byte) (iter.Seq[V], func() (remain // After the iterator is exhausted, the remaining bytes in the buffer // and any error can be read by calling the returned function. func ReadArrayBytes[V ArrayExtraTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { + if IsNil(b) { + b, err := ReadNilBytes(b) + return func(yield func(V) bool) {}, func() ([]byte, error) { return b, err } + } sz, b, err := ReadArrayHeaderBytes(b) if err != nil { return nil, func() ([]byte, error) { return b, err } @@ -664,218 +703,304 @@ func ReadArrayBytes[V ArrayExtraTypes](b []byte) (iter.Seq[V], func() (remain [] } } -// ReadMapBytes returns an iterator over key-value pairs of a map encoded in MessagePack bytes. -// The type parameters K and V specify the types of the keys and values in the map. -// The iterator yields pairs in wire order. After iteration completes (or stops early), -// call the returned tail function to get any remaining bytes and the first error encountered (if any). -// K and V must be one of the supported built-in types, or pointers to types implementing Unmarshaler. -// Byte slices will reference the same underlying data. +// ReadMapBytes returns an iterator over key/value pairs from a MessagePack map encoded in b. +// The iterator yields K,V pairs and this function also returns a closure to obtain the remaining bytes and any error. +// It avoids per-element type switches by precomputing readKey/readVal funcs based on K and V. func ReadMapBytes[K MapKeyTypes, V MapValueTypes](b []byte) (iter.Seq2[K, V], func() (remain []byte, err error)) { - sz, b, err := ReadMapHeaderBytes(b) + var err error + var sz uint32 + if IsNil(b) { + b, err = ReadNilBytes(b) + return func(yield func(K, V) bool) {}, func() ([]byte, error) { return b, err } + } + sz, b, err = ReadMapHeaderBytes(b) if err != nil { - return nil, func() ([]byte, error) { return b, err } + return func(yield func(K, V) bool) {}, func() ([]byte, error) { return b, err } } - readKey := func() (K, error) { - var key K - switch any(key).(type) { + // Precompute key reader + var readKey func() (K, error) + { + var keyZero K + switch any(keyZero).(type) { case string: - var v string - v, b, err = ReadStringBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadStringBytes(b) + b, err = e, er + return any(v).(K), err + } case []byte: - var v []byte - v, b, err = ReadBytesZC(b) - key = any(v).(K) + // Map keys can be str or bin; use specialized helper that accepts both. + readKey = func() (K, error) { + v, e, er := ReadMapKeyZC(b) + b, err = e, er + return any(v).(K), err + } case bool: - var v bool - v, b, err = ReadBoolBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadBoolBytes(b) + b, err = e, er + return any(v).(K), err + } case time.Time: - var v time.Time - v, b, err = ReadTimeBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadTimeBytes(b) + b, err = e, er + return any(v).(K), err + } case time.Duration: - var v time.Duration - v, b, err = ReadDurationBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadDurationBytes(b) + b, err = e, er + return any(v).(K), err + } case complex64: - var v complex64 - v, b, err = ReadComplex64Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadComplex64Bytes(b) + b, err = e, er + return any(v).(K), err + } case complex128: - var v complex128 - v, b, err = ReadComplex128Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadComplex128Bytes(b) + b, err = e, er + return any(v).(K), err + } case uint8: - var v uint8 - v, b, err = ReadUint8Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadUint8Bytes(b) + b, err = e, er + return any(v).(K), err + } case uint16: - var v uint16 - v, b, err = ReadUint16Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadUint16Bytes(b) + b, err = e, er + return any(v).(K), err + } case uint32: - var v uint32 - v, b, err = ReadUint32Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadUint32Bytes(b) + b, err = e, er + return any(v).(K), err + } case uint64: - var v uint64 - v, b, err = ReadUint64Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadUint64Bytes(b) + b, err = e, er + return any(v).(K), err + } case uint: - var v uint - v, b, err = ReadUintBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadUintBytes(b) + b, err = e, er + return any(v).(K), err + } case int8: - var v int8 - v, b, err = ReadInt8Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadInt8Bytes(b) + b, err = e, er + return any(v).(K), err + } case int16: - var v int16 - v, b, err = ReadInt16Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadInt16Bytes(b) + b, err = e, er + return any(v).(K), err + } case int32: - var v int32 - v, b, err = ReadInt32Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadInt32Bytes(b) + b, err = e, er + return any(v).(K), err + } case int64: - var v int64 - v, b, err = ReadInt64Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadInt64Bytes(b) + b, err = e, er + return any(v).(K), err + } case int: - var v int - v, b, err = ReadIntBytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadIntBytes(b) + b, err = e, er + return any(v).(K), err + } case float32: - var v float32 - v, b, err = ReadFloat32Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadFloat32Bytes(b) + b, err = e, er + return any(v).(K), err + } case float64: - var v float64 - v, b, err = ReadFloat64Bytes(b) - key = any(v).(K) + readKey = func() (K, error) { + v, e, er := ReadFloat64Bytes(b) + b, err = e, er + return any(v).(K), err + } default: - // Fallback for custom types implementing Unmarshaler - ptr := &key - if um, ok := any(ptr).(Unmarshaler); ok { - b, err = um.UnmarshalMsg(b) - } else { - err = fmt.Errorf("cannot unmarshal key into type %T", ptr) + readKey = func() (K, error) { + var k K + ptr := &k + if um, ok := any(ptr).(Unmarshaler); ok { + var e error + b, e = um.UnmarshalMsg(b) + return k, e + } + return k, fmt.Errorf("cannot unmarshal key into type %T", ptr) } } - return key, err } - readVal := func() (V, error) { - var val V - switch any(val).(type) { + // Precompute value reader + var readVal func() (V, error) + { + var valZero V + switch any(valZero).(type) { case string: - var v string - v, b, err = ReadStringBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadStringBytes(b) + b, err = e, er + return any(v).(V), err + } case []byte: - var v []byte - v, b, err = ReadBytesZC(b) - val = any(v).(V) + // For values, zero-copy read of bin/str payload. + readVal = func() (V, error) { + v, e, er := ReadBytesZC(b) + b, err = e, er + return any(v).(V), err + } case bool: - var v bool - v, b, err = ReadBoolBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadBoolBytes(b) + b, err = e, er + return any(v).(V), err + } case time.Time: - var v time.Time - v, b, err = ReadTimeBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadTimeBytes(b) + b, err = e, er + return any(v).(V), err + } case time.Duration: - var v time.Duration - v, b, err = ReadDurationBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadDurationBytes(b) + b, err = e, er + return any(v).(V), err + } case complex64: - var v complex64 - v, b, err = ReadComplex64Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadComplex64Bytes(b) + b, err = e, er + return any(v).(V), err + } case complex128: - var v complex128 - v, b, err = ReadComplex128Bytes(b) - val = any(v).(V) - case uint8: - var v uint8 - v, b, err = ReadUint8Bytes(b) - val = any(v).(V) - case uint16: - var v uint16 - v, b, err = ReadUint16Bytes(b) - val = any(v).(V) - case uint32: - var v uint32 - v, b, err = ReadUint32Bytes(b) - val = any(v).(V) - case uint64: - var v uint64 - v, b, err = ReadUint64Bytes(b) - val = any(v).(V) - case uint: - var v uint - v, b, err = ReadUintBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadComplex128Bytes(b) + b, err = e, er + return any(v).(V), err + } case int8: - var v int8 - v, b, err = ReadInt8Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadInt8Bytes(b) + b, err = e, er + return any(v).(V), err + } case int16: - var v int16 - v, b, err = ReadInt16Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadInt16Bytes(b) + b, err = e, er + return any(v).(V), err + } case int32: - var v int32 - v, b, err = ReadInt32Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadInt32Bytes(b) + b, err = e, er + return any(v).(V), err + } case int64: - var v int64 - v, b, err = ReadInt64Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadInt64Bytes(b) + b, err = e, er + return any(v).(V), err + } case int: - var v int - v, b, err = ReadIntBytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadIntBytes(b) + b, err = e, er + return any(v).(V), err + } case float32: - var v float32 - v, b, err = ReadFloat32Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadFloat32Bytes(b) + b, err = e, er + return any(v).(V), err + } case float64: - var v float64 - v, b, err = ReadFloat64Bytes(b) - val = any(v).(V) + readVal = func() (V, error) { + v, e, er := ReadFloat64Bytes(b) + b, err = e, er + return any(v).(V), err + } + case uint8: + readVal = func() (V, error) { + v, e, er := ReadUint8Bytes(b) + b, err = e, er + return any(v).(V), err + } + case uint16: + readVal = func() (V, error) { + v, e, er := ReadUint16Bytes(b) + b, err = e, er + return any(v).(V), err + } + case uint32: + readVal = func() (V, error) { + v, e, er := ReadUint32Bytes(b) + b, err = e, er + return any(v).(V), err + } + case uint64: + readVal = func() (V, error) { + v, e, er := ReadUint64Bytes(b) + b, err = e, er + return any(v).(V), err + } + case uint: + readVal = func() (V, error) { + v, e, er := ReadUintBytes(b) + b, err = e, er + return any(v).(V), err + } default: - // Fallback for custom types implementing Unmarshaler - ptr := &val - if um, ok := any(ptr).(Unmarshaler); ok { - b, err = um.UnmarshalMsg(b) - } else { - err = fmt.Errorf("cannot unmarshal value into type %T", ptr) + readVal = func() (V, error) { + var v V + ptr := &v + if um, ok := any(ptr).(Unmarshaler); ok { + var e error + b, e = um.UnmarshalMsg(b) + return v, e + } + return v, fmt.Errorf("cannot unmarshal value into type %T", ptr) } } - return val, err } return func(yield func(K, V) bool) { - for sz > 0 { - k, e := readKey() - if e != nil { - err = e - return - } - v, e := readVal() - if e != nil { - err = e - return - } - if !yield(k, v) { - return - } - sz-- + for range sz { + k, er := readKey() + if er != nil { + err = fmt.Errorf("cannot read map key: %w", er) + return + } + v, er := readVal() + if er != nil { + err = fmt.Errorf("cannot read map value: %w", er) + return + } + if !yield(k, v) { + return } - }, func() ([]byte, error) { - return b, err } + }, func() ([]byte, error) { return b, err } } diff --git a/msgp/iter_test.go b/msgp/iter_test.go index 15bbaaac..be7ff7b4 100644 --- a/msgp/iter_test.go +++ b/msgp/iter_test.go @@ -9,6 +9,8 @@ import ( "time" ) +var nilMsg = AppendNil(nil) + // collectSeq2 collects values from an iter.Seq2[V, error] into a slice. // It stops at the first non-nil error and returns it together with the collected values. func collectSeq2[V any](seq func(func(V, error) bool)) (vals []V, err error) { @@ -434,6 +436,14 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } + got, err = collectSeq2(ReadNumberArray[uint](NewReader(bytes.NewReader(nilMsg)))) + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + if len(got) != 0 { + t.Fatalf("%s len: got %d want %d", tc.name, len(got), 0) + } + case testcase[uint8]: var buf bytes.Buffer w := NewWriter(&buf) @@ -796,6 +806,15 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } + r.Reset(bytes.NewReader(nilMsg)) + got, err = collectSeq2(ReadArray[bool](r)) + if len(got) != 0 { + t.Fatalf("%s len: got %d want 0", tc.name, len(got)) + } + if err != nil { + t.Fatalf("%s iterate: %v", tc.name, err) + } + case regCase[string]: var buf bytes.Buffer w := NewWriter(&buf) @@ -1040,6 +1059,18 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } + seq, tail = ReadNumberArrayBytes[uint](nilMsg) + for range seq { + t.Fatalf("%s: got entries on nil", tc.name) + } + remain, err = tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } + case tb[uint8]: msg := AppendArrayHeader(nil, uint32(len(tc.vals))) for _, v := range tc.vals { @@ -1378,6 +1409,17 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } + seq, tail = ReadArrayBytes[bool](nilMsg) + for range seq { + t.Fatalf("%s: got entries on nil", tc.name) + } + remain, err = tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain: %d", tc.name, len(remain)) + } case rb[string]: msg := AppendArrayHeader(nil, uint32(len(tc.vals))) for _, v := range tc.vals { @@ -1592,7 +1634,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[uint]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1623,10 +1665,19 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v: got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } + // Test nil + r = NewReader(bytes.NewReader(nilMsg)) + seq, tail = ReadMap[uint, uint](r) + for k, v := range seq { + t.Fatalf("nil %s: got key %v val %v", tc.name, k, v) + } + if err := tail(); err != nil { + t.Fatalf("nil %s: tail: %v", tc.name, err) + } case numCase[uint8]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1660,7 +1711,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[uint16]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1694,7 +1745,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[uint32]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1728,7 +1779,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[uint64]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1762,7 +1813,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[int]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1796,7 +1847,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[int8]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1830,7 +1881,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[int16]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1864,7 +1915,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[int32]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1898,7 +1949,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[int64]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1932,7 +1983,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[float32]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -1966,7 +2017,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { case numCase[float64]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2046,7 +2097,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { case regCase[bool]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2077,10 +2128,19 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s: key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } + r = NewReader(bytes.NewReader(nilMsg)) + seq, tail = ReadMap[bool, bool](r) + for k, v := range seq { + t.Fatalf("%s:expected ni results, got %v:%v", tc.name, k, v) + } + if err := tail(); err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + case regCase[string]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2115,7 +2175,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { // For []byte keys (not comparable), validate by pair presence. var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2168,7 +2228,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { case regCase[time.Time]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2202,7 +2262,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { case regCase[time.Duration]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2237,7 +2297,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { case regCase[complex64]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2271,7 +2331,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { case regCase[complex128]: var buf bytes.Buffer w := NewWriter(&buf) - if err := w.WriteArrayHeader(uint32(len(tc.keys))); err != nil { + if err := w.WriteMapHeader(uint32(len(tc.keys))); err != nil { t.Fatalf("hdr %s: %v", tc.name, err) } for i := range tc.keys { @@ -2361,6 +2421,17 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } + seq, tail = ReadMapBytes[uint, uint](nilMsg) + for k, v := range seq { + t.Fatalf("%s: got %v:%v want nothing", tc.name, k, v) + } + remain, err = tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } case numCase[uint8]: msg := AppendMapHeader(nil, uint32(len(tc.keys))) for i := range tc.keys { @@ -2719,6 +2790,18 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } + seq, tail = ReadMapBytes[bool, bool](nilMsg) + for k, v := range seq { + t.Fatalf("%s key %v:%v want nothing", tc.name, k, v) + } + remain, err = tail() + if err != nil { + t.Fatalf("%s tail: %v", tc.name, err) + } + if len(remain) != 0 { + t.Fatalf("%s remain %d", tc.name, len(remain)) + } + case regCase[string]: msg := AppendMapHeader(nil, uint32(len(tc.keys))) for i := range tc.keys { From 917044b32409695e31a0df89434bc227dfff7525 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sat, 30 Aug 2025 13:20:51 +0200 Subject: [PATCH 5/7] Completely rewrite and add write functions. --- msgp/iter.go | 1140 ++++++++++----------------------------------- msgp/iter_test.go | 433 +++++++++++++---- 2 files changed, 587 insertions(+), 986 deletions(-) diff --git a/msgp/iter.go b/msgp/iter.go index 9cbb6015..0b3a2349 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -3,119 +3,66 @@ package msgp import ( + "cmp" "fmt" "iter" - "time" + "maps" + "math" + "slices" ) -// ArrayExtraTypes is a type constraint that includes all types that can be -// decoded from an array. -// Even though 'any' type can be used, its pointer must implement the -// Decodable when using Reader or Unmarhaler interface when reading from bytes. -type ArrayExtraTypes interface { - bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any -} - // ReadArray returns an iterator that can be used to iterate over the elements // of an array in the MessagePack data while being read by the provided Reader. // The type parameter V specifies the type of the elements in the array. -// The type parameter V must be ArrayExtraTypes or a type whose -// pointer implements the Decodable interface. -// Use ReadNumberArray for numbers. // The returned iterator implements the iter.Seq[V] interface, // allowing for sequential access to the array elements. -func ReadArray[V ArrayExtraTypes](m *Reader) iter.Seq2[V, error] { - return func(yield func(V, error) bool) { +func ReadArray[T any](m *Reader, readFn func() (T, error)) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { // Check if nil if m.IsNil() { m.ReadNil() return } // Regular array. - var x V + var empty T length, err := m.ReadArrayHeader() if err != nil { - yield(x, fmt.Errorf("cannot read array header: %w", err)) + yield(empty, fmt.Errorf("cannot read array header: %w", err)) return } - switch any(x).(type) { - case string: - for range length { - v, err := m.ReadString() - if !yield(any(v).(V), err) { - return - } - } - case []byte: - for range length { - v, err := m.ReadBytes(nil) - if !yield(any(v).(V), err) { - return - } - } - case bool: - for range length { - v, err := m.ReadBool() - if !yield(any(v).(V), err) { - return - } - } - case time.Time: - for range length { - v, err := m.ReadTime() - if !yield(any(v).(V), err) { - return - } - } - case time.Duration: - for range length { - v, err := m.ReadDuration() - if !yield(any(v).(V), err) { - return - } - } - case complex64: - for range length { - v, err := m.ReadComplex64() - if !yield(any(v).(V), err) { - return - } - } - case complex128: - for range length { - v, err := m.ReadComplex128() - if !yield(any(v).(V), err) { - return - } - } - default: - for range length { - var v V - ptr := &v - if dc, ok := any(ptr).(Decodable); ok { - err = dc.DecodeMsg(m) - if !yield(v, err) { - return - } - } else { - yield(v, fmt.Errorf("cannot decode into type %T", ptr)) - return - } + for range length { + var v T + v, err = readFn() + if !yield(v, err) { + return } } } } -// MapKeyTypes are possible key types. Usually this is a string. -// Even though 'any' type can be used, its pointer must implement the Decodable interface. -type MapKeyTypes interface { - NumberTypes | bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any -} - -// MapValueTypes are possible value types. -// Even though 'any' type can be used, its pointer must implement the Decodable interface. -type MapValueTypes interface { - NumberTypes | bool | string | []byte | complex64 | complex128 | time.Time | time.Duration | any +// WriteArray writes an array to the provided Writer. +// The writeFn parameter specifies the function to use to write each element of the array. +func WriteArray[T any](w *Writer, a []T, writeFn func(T) error) error { + // Check if nil + if a == nil { + return w.WriteNil() + } + if uint64(len(a)) > math.MaxUint32 { + return fmt.Errorf("array too large to encode: %d elements", len(a)) + } + // Write array header + err := w.WriteArrayHeader(uint32(len(a))) + if err != nil { + return err + } + // Write elements + for _, v := range a { + err = writeFn(v) + if err != nil { + return err + } + } + return nil } // ReadMap returns an iterator that can be used to iterate over the elements @@ -124,7 +71,7 @@ type MapValueTypes interface { // The returned iterator implements the iter.Seq2[K, V] interface, // allowing for sequential access to the map elements. // The returned function can be used to read any error that occurred during iteration when iteration is done. -func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() error) { +func ReadMap[K, V any](m *Reader, readKey func() (K, error), readVal func() (V, error)) (iter.Seq2[K, V], func() error) { var err error return func(yield func(K, V) bool) { var sz uint32 @@ -138,231 +85,6 @@ func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() return } - // Prepare per-type readers once, avoid switching for each element. - // Key reader. - var readKey func() (K, error) - { - var keyZero K - switch any(keyZero).(type) { - case string: - readKey = func() (K, error) { - v, e := m.ReadString() - return any(v).(K), e - } - case []byte: - readKey = func() (K, error) { - v, e := m.ReadBytes(nil) - return any(v).(K), e - } - case bool: - readKey = func() (K, error) { - v, e := m.ReadBool() - return any(v).(K), e - } - case time.Time: - readKey = func() (K, error) { - v, e := m.ReadTime() - return any(v).(K), e - } - case time.Duration: - readKey = func() (K, error) { - v, e := m.ReadDuration() - return any(v).(K), e - } - case complex64: - readKey = func() (K, error) { - v, e := m.ReadComplex64() - return any(v).(K), e - } - case complex128: - readKey = func() (K, error) { - v, e := m.ReadComplex128() - return any(v).(K), e - } - case uint8: - readKey = func() (K, error) { - v, e := m.ReadUint8() - return any(v).(K), e - } - case uint16: - readKey = func() (K, error) { - v, e := m.ReadUint16() - return any(v).(K), e - } - case uint32: - readKey = func() (K, error) { - v, e := m.ReadUint32() - return any(v).(K), e - } - case uint64: - readKey = func() (K, error) { - v, e := m.ReadUint64() - return any(v).(K), e - } - case uint: - readKey = func() (K, error) { - v, e := m.ReadUint() - return any(v).(K), e - } - case int8: - readKey = func() (K, error) { - v, e := m.ReadInt8() - return any(v).(K), e - } - case int16: - readKey = func() (K, error) { - v, e := m.ReadInt16() - return any(v).(K), e - } - case int32: - readKey = func() (K, error) { - v, e := m.ReadInt32() - return any(v).(K), e - } - case int64: - readKey = func() (K, error) { - v, e := m.ReadInt64() - return any(v).(K), e - } - case int: - readKey = func() (K, error) { - v, e := m.ReadInt() - return any(v).(K), e - } - case float32: - readKey = func() (K, error) { - v, e := m.ReadFloat32() - return any(v).(K), e - } - case float64: - readKey = func() (K, error) { - v, e := m.ReadFloat64() - return any(v).(K), e - } - default: - readKey = func() (K, error) { - var k K - ptr := &k - if dc, ok := any(ptr).(Decodable); ok { - return k, dc.DecodeMsg(m) - } - return k, fmt.Errorf("cannot decode key into type %T", ptr) - } - } - } - - // Value reader. - var readVal func() (V, error) - { - var valZero V - switch any(valZero).(type) { - case string: - readVal = func() (V, error) { - v, e := m.ReadString() - return any(v).(V), e - } - case []byte: - readVal = func() (V, error) { - v, e := m.ReadBytes(nil) - return any(v).(V), e - } - case bool: - readVal = func() (V, error) { - v, e := m.ReadBool() - return any(v).(V), e - } - case time.Time: - readVal = func() (V, error) { - v, e := m.ReadTime() - return any(v).(V), e - } - case time.Duration: - readVal = func() (V, error) { - v, e := m.ReadDuration() - return any(v).(V), e - } - case complex64: - readVal = func() (V, error) { - v, e := m.ReadComplex64() - return any(v).(V), e - } - case complex128: - readVal = func() (V, error) { - v, e := m.ReadComplex128() - return any(v).(V), e - } - case int8: - readVal = func() (V, error) { - v, e := m.ReadInt8() - return any(v).(V), e - } - case int16: - readVal = func() (V, error) { - v, e := m.ReadInt16() - return any(v).(V), e - } - case int32: - readVal = func() (V, error) { - v, e := m.ReadInt32() - return any(v).(V), e - } - case int64: - readVal = func() (V, error) { - v, e := m.ReadInt64() - return any(v).(V), e - } - case int: - readVal = func() (V, error) { - v, e := m.ReadInt() - return any(v).(V), e - } - case float32: - readVal = func() (V, error) { - v, e := m.ReadFloat32() - return any(v).(V), e - } - case float64: - readVal = func() (V, error) { - v, e := m.ReadFloat64() - return any(v).(V), e - } - case uint8: - readVal = func() (V, error) { - v, e := m.ReadUint8() - return any(v).(V), e - } - case uint16: - readVal = func() (V, error) { - v, e := m.ReadUint16() - return any(v).(V), e - } - case uint32: - readVal = func() (V, error) { - v, e := m.ReadUint32() - return any(v).(V), e - } - case uint64: - readVal = func() (V, error) { - v, e := m.ReadUint64() - return any(v).(V), e - } - case uint: - readVal = func() (V, error) { - v, e := m.ReadUint() - return any(v).(V), e - } - default: - readVal = func() (V, error) { - var v V - ptr := &v - if dc, ok := any(ptr).(Decodable); ok { - return v, dc.DecodeMsg(m) - } - return v, fmt.Errorf("cannot decode value into type %T", ptr) - } - } - } - for range sz { var k K k, err = readKey() @@ -383,330 +105,119 @@ func ReadMap[K MapKeyTypes, V MapValueTypes](m *Reader) (iter.Seq2[K, V], func() }, func() error { return err } } -// NumberTypes is a type constraint that includes all numeric types. -type NumberTypes interface { - uint | uint8 | uint16 | uint32 | uint64 | int | int8 | int16 | int32 | int64 | float32 | float64 -} +// WriteMap writes a map to the provided Writer. +// The writeKey and writeVal parameters specify the functions to use to write each key and value of the map. +func WriteMap[K comparable, V any](w *Writer, m map[K]V, writeKey func(K) error, writeVal func(V) error) error { + if m == nil { + return w.WriteNil() + } + if uint64(len(m)) > math.MaxUint32 { + return fmt.Errorf("map too large to encode: %d elements", len(m)) + } -// ReadNumberArray returns an iterator that can be used to iterate over the elements -// of an array in the MessagePack data while being read by the provided Reader. -// The type parameter V specifies the type of the elements in the array. -// The returned iterator implements the iter.Seq[V] interface, -// allowing for sequential access to the array elements. -func ReadNumberArray[V NumberTypes](m *Reader) iter.Seq2[V, error] { - return func(yield func(V, error) bool) { - if m.IsNil() { - m.ReadNil() - return - } - var x V - length, err := m.ReadArrayHeader() + // Write map header + err := w.WriteMapHeader(uint32(len(m))) + if err != nil { + return err + } + // Write elements + for k, v := range m { + err = writeKey(k) if err != nil { - yield(x, fmt.Errorf("cannot read array header: %w", err)) - return + return err } - - switch any(x).(type) { - case uint8: - for range length { - v, err := m.ReadUint8() - if !yield(V(v), err) { - return - } - } - case uint16: - for range length { - v, err := m.ReadUint16() - if !yield(V(v), err) { - return - } - } - case uint32: - for range length { - v, err := m.ReadUint32() - if !yield(V(v), err) { - return - } - } - case uint64: - for range length { - v, err := m.ReadUint64() - if !yield(V(v), err) { - return - } - } - case uint: - for range length { - v, err := m.ReadUint() - if !yield(V(v), err) { - return - } - } - case int8: - for range length { - v, err := m.ReadInt8() - if !yield(V(v), err) { - return - } - } - case int16: - for range length { - v, err := m.ReadInt16() - if !yield(V(v), err) { - return - } - } - case int32: - for range length { - v, err := m.ReadInt32() - if !yield(V(v), err) { - return - } - } - case int64: - for range length { - v, err := m.ReadInt64() - if !yield(V(v), err) { - return - } - } - case int: - for range length { - v, err := m.ReadInt() - if !yield(V(v), err) { - return - } - } - case float32: - for range length { - v, err := m.ReadFloat32() - if !yield(V(v), err) { - return - } - } - case float64: - for range length { - v, err := m.ReadFloat64() - if !yield(V(v), err) { - return - } - } - default: - panic("unreachable") + err = writeVal(v) + if err != nil { + return err } } + return nil } -// ReadNumberArrayBytes returns an iterator that can be used to iterate over the elements -// of an array in the MessagePack data. -// The type parameter V specifies the type of the elements in the array. -// The returned iterator implements the iter.Seq[V] interface, -// allowing for sequential access to the array elements. -// After the iterator is exhausted, the remaining bytes in the buffer -// and any error can be read by calling the returned function. -func ReadNumberArrayBytes[V NumberTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { - if IsNil(b) { - b, err := ReadNilBytes(b) - return func(yield func(V) bool) {}, func() ([]byte, error) { return b, err } +// WriteMapSorted writes a map to the provided Writer. +// The keys of the map are sorted before writing. +// This provides deterministic output. +// The writeKey and writeVal parameters specify the functions +// to use to write each key and value of the map. +func WriteMapSorted[K cmp.Ordered, V any](w *Writer, m map[K]V, writeKey func(K) error, writeVal func(V) error) error { + if m == nil { + return w.WriteNil() } - // Regular array. - sz, b, err := ReadArrayHeaderBytes(b) - if err != nil { - return func(yield func(V) bool) {}, func() ([]byte, error) { return b, fmt.Errorf("cannot read array header: %w", err) } + if uint64(len(m)) > math.MaxUint32 { + return fmt.Errorf("map too large to encode: %d elements", len(m)) } - var readValue func() (V, error) - var v V - switch any(v).(type) { - case uint8: - readValue = func() (V, error) { - var val uint8 - val, b, err = ReadUint8Bytes(b) - return V(val), err - } - case uint16: - readValue = func() (V, error) { - var val uint16 - val, b, err = ReadUint16Bytes(b) - return V(val), err - } - case uint32: - readValue = func() (V, error) { - var val uint32 - val, b, err = ReadUint32Bytes(b) - return V(val), err - } - case uint64: - readValue = func() (V, error) { - var val uint64 - val, b, err = ReadUint64Bytes(b) - return V(val), err - } - case uint: - readValue = func() (V, error) { - var val uint - val, b, err = ReadUintBytes(b) - return V(val), err - } - case int8: - readValue = func() (V, error) { - var val int8 - val, b, err = ReadInt8Bytes(b) - return V(val), err - } - case int16: - readValue = func() (V, error) { - var val int16 - val, b, err = ReadInt16Bytes(b) - return V(val), err - } - case int32: - readValue = func() (V, error) { - var val int32 - val, b, err = ReadInt32Bytes(b) - return V(val), err - } - case int64: - readValue = func() (V, error) { - var val int64 - val, b, err = ReadInt64Bytes(b) - return V(val), err - } - case int: - readValue = func() (V, error) { - var val int - val, b, err = ReadIntBytes(b) - return V(val), err - } - case float32: - readValue = func() (V, error) { - var val float32 - val, b, err = ReadFloat32Bytes(b) - return V(val), err + // Write map header + err := w.WriteMapHeader(uint32(len(m))) + if err != nil { + return err + } + // Write elements + for _, k := range slices.Sorted(maps.Keys(m)) { + err = writeKey(k) + if err != nil { + return err } - case float64: - readValue = func() (V, error) { - var val float64 - val, b, err = ReadFloat64Bytes(b) - return V(val), err + err = writeVal(m[k]) + if err != nil { + return err } - default: - panic("unreachable") } - return func(yield func(V) bool) { - for sz > 0 { - v, err = readValue() - if err != nil || !yield(v) { - return - } - sz-- - } - }, func() ([]byte, error) { return b, err } + return nil } // ReadArrayBytes returns an iterator that can be used to iterate over the elements // of an array in the MessagePack data while being read by the provided Reader. // The type parameter V specifies the type of the elements in the array. -// The type parameter V must be bool, string, []byte or a type whose -// pointer implements the Unmarshaler interface. -// Use ReadNumberArrayBytes for numbers. -// The returned iterator implements the iter.Seq[V] interface, -// allowing for sequential access to the array elements. -// Byte slices will reference the same underlying data. // After the iterator is exhausted, the remaining bytes in the buffer // and any error can be read by calling the returned function. -func ReadArrayBytes[V ArrayExtraTypes](b []byte) (iter.Seq[V], func() (remain []byte, err error)) { +func ReadArrayBytes[T any](b []byte, readFn func([]byte) (T, []byte, error)) (iter.Seq[T], func() (remain []byte, err error)) { if IsNil(b) { b, err := ReadNilBytes(b) - return func(yield func(V) bool) {}, func() ([]byte, error) { return b, err } + return func(yield func(T) bool) {}, func() ([]byte, error) { return b, err } } sz, b, err := ReadArrayHeaderBytes(b) - if err != nil { - return nil, func() ([]byte, error) { return b, err } + if err != nil || sz == 0 { + return func(yield func(T) bool) {}, func() ([]byte, error) { return b, err } } - return func(yield func(V) bool) { - var x V - switch any(x).(type) { - case string: - for range sz { - var v string - v, b, err = ReadStringBytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case []byte: - for range sz { - var v []byte - v, b, err = ReadBytesZC(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case bool: - for range sz { - var v bool - v, b, err = ReadBoolBytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case time.Time: - for range sz { - var v time.Time - v, b, err = ReadTimeBytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case time.Duration: - for range sz { - var v time.Duration - v, b, err = ReadDurationBytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case complex64: - for range sz { - var v complex64 - v, b, err = ReadComplex64Bytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - case complex128: - for range sz { - var v complex128 - v, b, err = ReadComplex128Bytes(b) - if err != nil || !yield(any(v).(V)) { - return - } - } - default: - for range sz { - var v V - ptr := &v - if um, ok := any(ptr).(Unmarshaler); ok { - b, err = um.UnmarshalMsg(b) - if err != nil || !yield(v) { - return - } - } else { - err = fmt.Errorf("cannot unmarshal into type %T", ptr) - return - } + return func(yield func(T) bool) { + for range sz { + var v T + v, b, err = readFn(b) + if err != nil || !yield(v) { + return } } - }, func() (remain []byte, err error) { + }, func() ([]byte, error) { return b, err } } -// ReadMapBytes returns an iterator over key/value pairs from a MessagePack map encoded in b. -// The iterator yields K,V pairs and this function also returns a closure to obtain the remaining bytes and any error. -// It avoids per-element type switches by precomputing readKey/readVal funcs based on K and V. -func ReadMapBytes[K MapKeyTypes, V MapValueTypes](b []byte) (iter.Seq2[K, V], func() (remain []byte, err error)) { +// AppendArray writes an array to the provided buffer. +// The writeFn parameter specifies the function to use to write each element of the array. +// The returned buffer contains the encoded array. +// The function panics if the map is larger than math.MaxUint32 elements. +func AppendArray[T any](b []byte, a []T, writeFn func(b []byte, v T) []byte) []byte { + if a == nil { + return AppendNil(b) + } + if uint64(len(a)) > math.MaxUint32 { + panic(fmt.Sprintf("array too large to encode: %d elements", len(a))) + } + b = AppendArrayHeader(b, uint32(len(a))) + for _, v := range a { + b = writeFn(b, v) + } + return b +} + +// ReadMapBytes returns an iterator over key/value +// pairs from a MessagePack map encoded in b. +// The iterator yields K,V pairs and this function also returns +// a closure to get the remaining bytes and any error. +func ReadMapBytes[K any, V any](b []byte, + readK func([]byte) (K, []byte, error), + readV func([]byte) (V, []byte, error)) (iter.Seq2[K, V], func() (remain []byte, err error)) { var err error var sz uint32 if IsNil(b) { @@ -714,288 +225,22 @@ func ReadMapBytes[K MapKeyTypes, V MapValueTypes](b []byte) (iter.Seq2[K, V], fu return func(yield func(K, V) bool) {}, func() ([]byte, error) { return b, err } } sz, b, err = ReadMapHeaderBytes(b) - if err != nil { + if err != nil || sz == 0 { return func(yield func(K, V) bool) {}, func() ([]byte, error) { return b, err } } - // Precompute key reader - var readKey func() (K, error) - { - var keyZero K - switch any(keyZero).(type) { - case string: - readKey = func() (K, error) { - v, e, er := ReadStringBytes(b) - b, err = e, er - return any(v).(K), err - } - case []byte: - // Map keys can be str or bin; use specialized helper that accepts both. - readKey = func() (K, error) { - v, e, er := ReadMapKeyZC(b) - b, err = e, er - return any(v).(K), err - } - case bool: - readKey = func() (K, error) { - v, e, er := ReadBoolBytes(b) - b, err = e, er - return any(v).(K), err - } - case time.Time: - readKey = func() (K, error) { - v, e, er := ReadTimeBytes(b) - b, err = e, er - return any(v).(K), err - } - case time.Duration: - readKey = func() (K, error) { - v, e, er := ReadDurationBytes(b) - b, err = e, er - return any(v).(K), err - } - case complex64: - readKey = func() (K, error) { - v, e, er := ReadComplex64Bytes(b) - b, err = e, er - return any(v).(K), err - } - case complex128: - readKey = func() (K, error) { - v, e, er := ReadComplex128Bytes(b) - b, err = e, er - return any(v).(K), err - } - case uint8: - readKey = func() (K, error) { - v, e, er := ReadUint8Bytes(b) - b, err = e, er - return any(v).(K), err - } - case uint16: - readKey = func() (K, error) { - v, e, er := ReadUint16Bytes(b) - b, err = e, er - return any(v).(K), err - } - case uint32: - readKey = func() (K, error) { - v, e, er := ReadUint32Bytes(b) - b, err = e, er - return any(v).(K), err - } - case uint64: - readKey = func() (K, error) { - v, e, er := ReadUint64Bytes(b) - b, err = e, er - return any(v).(K), err - } - case uint: - readKey = func() (K, error) { - v, e, er := ReadUintBytes(b) - b, err = e, er - return any(v).(K), err - } - case int8: - readKey = func() (K, error) { - v, e, er := ReadInt8Bytes(b) - b, err = e, er - return any(v).(K), err - } - case int16: - readKey = func() (K, error) { - v, e, er := ReadInt16Bytes(b) - b, err = e, er - return any(v).(K), err - } - case int32: - readKey = func() (K, error) { - v, e, er := ReadInt32Bytes(b) - b, err = e, er - return any(v).(K), err - } - case int64: - readKey = func() (K, error) { - v, e, er := ReadInt64Bytes(b) - b, err = e, er - return any(v).(K), err - } - case int: - readKey = func() (K, error) { - v, e, er := ReadIntBytes(b) - b, err = e, er - return any(v).(K), err - } - case float32: - readKey = func() (K, error) { - v, e, er := ReadFloat32Bytes(b) - b, err = e, er - return any(v).(K), err - } - case float64: - readKey = func() (K, error) { - v, e, er := ReadFloat64Bytes(b) - b, err = e, er - return any(v).(K), err - } - default: - readKey = func() (K, error) { - var k K - ptr := &k - if um, ok := any(ptr).(Unmarshaler); ok { - var e error - b, e = um.UnmarshalMsg(b) - return k, e - } - return k, fmt.Errorf("cannot unmarshal key into type %T", ptr) - } - } - } - - // Precompute value reader - var readVal func() (V, error) - { - var valZero V - switch any(valZero).(type) { - case string: - readVal = func() (V, error) { - v, e, er := ReadStringBytes(b) - b, err = e, er - return any(v).(V), err - } - case []byte: - // For values, zero-copy read of bin/str payload. - readVal = func() (V, error) { - v, e, er := ReadBytesZC(b) - b, err = e, er - return any(v).(V), err - } - case bool: - readVal = func() (V, error) { - v, e, er := ReadBoolBytes(b) - b, err = e, er - return any(v).(V), err - } - case time.Time: - readVal = func() (V, error) { - v, e, er := ReadTimeBytes(b) - b, err = e, er - return any(v).(V), err - } - case time.Duration: - readVal = func() (V, error) { - v, e, er := ReadDurationBytes(b) - b, err = e, er - return any(v).(V), err - } - case complex64: - readVal = func() (V, error) { - v, e, er := ReadComplex64Bytes(b) - b, err = e, er - return any(v).(V), err - } - case complex128: - readVal = func() (V, error) { - v, e, er := ReadComplex128Bytes(b) - b, err = e, er - return any(v).(V), err - } - case int8: - readVal = func() (V, error) { - v, e, er := ReadInt8Bytes(b) - b, err = e, er - return any(v).(V), err - } - case int16: - readVal = func() (V, error) { - v, e, er := ReadInt16Bytes(b) - b, err = e, er - return any(v).(V), err - } - case int32: - readVal = func() (V, error) { - v, e, er := ReadInt32Bytes(b) - b, err = e, er - return any(v).(V), err - } - case int64: - readVal = func() (V, error) { - v, e, er := ReadInt64Bytes(b) - b, err = e, er - return any(v).(V), err - } - case int: - readVal = func() (V, error) { - v, e, er := ReadIntBytes(b) - b, err = e, er - return any(v).(V), err - } - case float32: - readVal = func() (V, error) { - v, e, er := ReadFloat32Bytes(b) - b, err = e, er - return any(v).(V), err - } - case float64: - readVal = func() (V, error) { - v, e, er := ReadFloat64Bytes(b) - b, err = e, er - return any(v).(V), err - } - case uint8: - readVal = func() (V, error) { - v, e, er := ReadUint8Bytes(b) - b, err = e, er - return any(v).(V), err - } - case uint16: - readVal = func() (V, error) { - v, e, er := ReadUint16Bytes(b) - b, err = e, er - return any(v).(V), err - } - case uint32: - readVal = func() (V, error) { - v, e, er := ReadUint32Bytes(b) - b, err = e, er - return any(v).(V), err - } - case uint64: - readVal = func() (V, error) { - v, e, er := ReadUint64Bytes(b) - b, err = e, er - return any(v).(V), err - } - case uint: - readVal = func() (V, error) { - v, e, er := ReadUintBytes(b) - b, err = e, er - return any(v).(V), err - } - default: - readVal = func() (V, error) { - var v V - ptr := &v - if um, ok := any(ptr).(Unmarshaler); ok { - var e error - b, e = um.UnmarshalMsg(b) - return v, e - } - return v, fmt.Errorf("cannot unmarshal value into type %T", ptr) - } - } - } - return func(yield func(K, V) bool) { for range sz { - k, er := readKey() - if er != nil { - err = fmt.Errorf("cannot read map key: %w", er) + var k K + k, b, err = readK(b) + if err != nil { + err = fmt.Errorf("cannot read map key: %w", err) return } - v, er := readVal() - if er != nil { - err = fmt.Errorf("cannot read map value: %w", er) + var v V + v, b, err = readV(b) + if err != nil { + err = fmt.Errorf("cannot read map value: %w", err) return } if !yield(k, v) { @@ -1004,3 +249,132 @@ func ReadMapBytes[K MapKeyTypes, V MapValueTypes](b []byte) (iter.Seq2[K, V], fu } }, func() ([]byte, error) { return b, err } } + +// AppendMap writes a map to the provided buffer. +// The writeK and writeV parameters specify the functions to use to write each key and value of the map. +// The returned buffer contains the encoded map. +// The function panics if the map is larger than math.MaxUint32 elements. +func AppendMap[K comparable, V any](b []byte, m map[K]V, + writeK func(b []byte, k K) []byte, + writeV func(b []byte, v V) []byte) []byte { + if m == nil { + return AppendNil(b) + } + if uint64(len(m)) > math.MaxUint32 { + panic(fmt.Sprintf("map too large to encode: %d elements", len(m))) + } + b = AppendMapHeader(b, uint32(len(m))) + for k, v := range m { + b = writeK(b, k) + b = writeV(b, v) + } + return b +} + +// AppendMapSorted writes a map to the provided buffer. +// Keys are sorted before writing. This provides deterministic output. +// The writeK and writeV parameters specify the functions to use to write each key and value of the map. +// The returned buffer contains the encoded map. +// The function panics if the map is larger than math.MaxUint32 elements. +func AppendMapSorted[K cmp.Ordered, V any](b []byte, m map[K]V, + writeK func(b []byte, k K) []byte, + writeV func(b []byte, v V) []byte) []byte { + if m == nil { + return AppendNil(b) + } + if uint64(len(m)) > math.MaxUint32 { + panic(fmt.Sprintf("map too large to encode: %d elements", len(m))) + } + b = AppendMapHeader(b, uint32(len(m))) + for _, k := range slices.Sorted(maps.Keys(m)) { + b = writeK(b, k) + b = writeV(b, m[k]) + } + return b +} + +// DecodePtr is a convenience type for decoding into a pointer. +type DecodePtr[T any] interface { + *T + Decodable +} + +// DecoderFrom allows augmenting any type with a DecodeMsg method into a method +// that reads from Reader and returns a T. +// Provide an instance of T. This value isn't used. +func DecoderFrom[T any, PT DecodePtr[T]](r *Reader, _ T) func() (T, error) { + return func() (T, error) { + var t T + tPtr := PT(&t) + err := tPtr.DecodeMsg(r) + return t, err + } +} + +// FlexibleEncoder is a constraint for types where either T or *T implements Encodable +type FlexibleEncoder[T any] interface { + Encodable + *T // Include *T in the interface +} + +// EncoderTo allows augmenting any type with a EncodeMsg method into a method +// that writes to Writer on each call. +// Provide an instance of T. This value isn't used.' +func EncoderTo[T any, PT FlexibleEncoder[T]](w *Writer, _ T) func(T) error { + return func(t T) error { + // Check if T implements Marshaler + if marshaler, ok := any(t).(Encodable); ok { + return marshaler.EncodeMsg(w) + } + // Check if *T implements Marshaler + if ptrMarshaler, ok := any(&t).(Encodable); ok { + return ptrMarshaler.EncodeMsg(w) + } + // The compiler should have asserted this. + panic("type does not implement Marshaler") + } +} + +// UnmarshalPtr is a convenience type for unmarshaling into a pointer. +type UnmarshalPtr[T any] interface { + *T + Unmarshaler +} + +// DecoderFromBytes allows augmenting any type with a UnmarshalMsg method into a method +// that reads from []byte and returns a T. +// Provide an instance of T. This value isn't used. +func DecoderFromBytes[T any, PT UnmarshalPtr[T]](_ T) func([]byte) (T, []byte, error) { + return func(b []byte) (T, []byte, error) { + var t T + tPtr := PT(&t) + b, err := tPtr.UnmarshalMsg(b) + return t, b, err + } +} + +// FlexibleMarshaler is a constraint for types where either T or *T implements Marshaler +type FlexibleMarshaler[T any] interface { + Marshaler + *T // Include *T in the interface +} + +// EncoderToBytes allows augmenting any type with a MarshalMsg method into a method +// that reads from T and returns a []byte. +// Provide an instance of T. This value isn't used. +func EncoderToBytes[T any, PT FlexibleMarshaler[T]](_ T) func([]byte, T) []byte { + return func(b []byte, t T) []byte { + // Check if T implements Marshaler + if marshaler, ok := any(t).(Marshaler); ok { + b, _ = marshaler.MarshalMsg(b) + return b + } + // Check if *T implements Marshaler + if ptrMarshaler, ok := any(&t).(Marshaler); ok { + b, _ = ptrMarshaler.MarshalMsg(b) + return b + } + // The compiler should have asserted this. + panic("type does not implement Marshaler") + } +} diff --git a/msgp/iter_test.go b/msgp/iter_test.go index be7ff7b4..0ab1d974 100644 --- a/msgp/iter_test.go +++ b/msgp/iter_test.go @@ -4,11 +4,223 @@ package msgp import ( "bytes" + "fmt" "math" "testing" "time" ) +// Example: reading an array of ints using ReadArray with a *Reader. +// It prints each element in order. +func ExampleReadArray() { + var buf bytes.Buffer + w := NewWriter(&buf) + + // Write an array [10, 20, 30] using WriteArray + _ = WriteArray[int](w, []int{10, 20, 30}, w.WriteInt) + _ = w.Flush() + + r := NewReader(&buf) + + seq := ReadArray[int](r, r.ReadInt) + seq(func(v int, err error) bool { + if err != nil { + fmt.Println("err:", err) + return false + } + fmt.Println(v) + return true + }) + + // Output: + // 10 + // 20 + // 30 +} + +// Example: reading a map[string]int using ReadMap with a *Reader. +// It prints key=value pairs in the same order they were written. +func ExampleReadMap() { + var buf bytes.Buffer + w := NewWriter(&buf) + + // Write a map {"a":1, "b":2} using WriteMap - we use the sorted version so output is predictable. + _ = WriteMapSorted[string, int](w, map[string]int{"a": 1, "b": 2}, w.WriteString, w.WriteInt) + _ = w.Flush() + + r := NewReader(&buf) + + seq, done := ReadMap[string, int](r, r.ReadString, r.ReadInt) + seq(func(k string, v int) bool { + fmt.Printf("%s=%d\n", k, v) + return true + }) + if err := done(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // a=1 + // b=2 +} + +// Example: reading an array of strings from a byte slice using ReadArrayBytes. +// It prints each element and then checks for a final error from the returned closure. +func ExampleReadArrayBytes() { + var b []byte + // Append ["x","y","z"] using AppendArray + b = AppendArray[string](b, []string{"x", "y", "z"}, AppendString) + + seq, finish := ReadArrayBytes[string](b, ReadStringBytes) + seq(func(s string) bool { + fmt.Println(s) + return true + }) + if _, err := finish(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // x + // y + // z +} + +// Example: reading a map[string]float64 from a byte slice using ReadMapBytes. +// It prints key=value pairs and then checks the remaining bytes/error from the returned closure. +func ExampleReadMapBytes() { + var b []byte + // Append {"pi":3.14, "e":2.718} using AppendMap - we use the sorted version for the example + b = AppendMapSorted[string, float64](b, map[string]float64{"pi": 3.14, "e": 2.718}, AppendString, AppendFloat64) + + seq, finish := ReadMapBytes[string, float64](b, ReadStringBytes, ReadFloat64Bytes) + seq(func(k string, v float64) bool { + fmt.Printf("%s=%.3f\n", k, v) + return true + }) + if _, err := finish(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // e=2.718 + // pi=3.140 +} + +// Example: slice roundtrip with struct elements using WriteArray/ReadArray. +// Uses testDec as the element type, with EncoderTo/DecoderFrom helpers. +func ExampleReadArray_struct() { + var buf bytes.Buffer + w := NewWriter(&buf) + + in := []testDec{{A: 1, B: "x"}, {A: 2, B: "y"}} + // Write []testDec using EncoderTo as the per-element writer. + _ = WriteArray[testDec](w, in, EncoderTo(w, testDec{})) + _ = w.Flush() + + r := NewReader(&buf) + // Read []testDec using DecoderFrom as the per-element reader. + seq := ReadArray[testDec](r, DecoderFrom(r, testDec{})) + + seq(func(v testDec, err error) bool { + if err != nil { + fmt.Println("err:", err) + return false + } + fmt.Printf("%d %s\n", v.A, v.B) + return true + }) + + // Output: + // 1 x + // 2 y +} + +// Example: map roundtrip with struct values using WriteMapSorted/ReadMap. +// Uses testDec as the value type and sorts keys for deterministic output. +// Employs EncoderTo for values and DecoderFrom for values when reading. +func ExampleReadMap_struct() { + var buf bytes.Buffer + w := NewWriter(&buf) + + in := map[string]testDec{ + "a": {A: 1, B: "x"}, + "b": {A: 2, B: "y"}, + } + // Write map[string]testDec using sorted keys for stable example output. + _ = WriteMapSorted[string, testDec](w, in, w.WriteString, EncoderTo(w, testDec{})) + _ = w.Flush() + + r := NewReader(&buf) + seq, done := ReadMap[string, testDec](r, r.ReadString, DecoderFrom(r, testDec{})) + + seq(func(k string, v testDec) bool { + fmt.Printf("%s=%d,%s\n", k, v.A, v.B) + return true + }) + if err := done(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // a=1,x + // b=2,y +} + +// Example: slice roundtrip with struct elements in a byte slice using AppendArray/ReadArrayBytes. +// Uses testDec as the element type, with EncoderToBytes/DecoderFromBytes helpers. +func ExampleReadArrayBytes_struct() { + in := []testDec{{A: 1, B: "x"}, {A: 2, B: "y"}} + var b []byte + + // Append []testDec using EncoderToBytes as the per-element appender. + b = AppendArray[testDec](b, in, EncoderToBytes(testDec{})) + + // Read back using DecoderFromBytes as the per-element reader. + seq, finish := ReadArrayBytes[testDec](b, DecoderFromBytes(testDec{})) + + seq(func(v testDec) bool { + fmt.Printf("%d %s\n", v.A, v.B) + return true + }) + if _, err := finish(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // 1 x + // 2 y +} + +// Example: map roundtrip with struct values in a byte slice using AppendMapSorted/ReadMapBytes. +// Uses testDec as the value type and sorts keys for deterministic output. +// Employs EncoderToBytes for values and DecoderFromBytes for values when reading. +func ExampleReadMapBytes_struct() { + in := map[string]testDec{ + "a": {A: 1, B: "x"}, + "b": {A: 2, B: "y"}, + } + var b []byte + + // Append map[string]testDec with sorted keys for stable example output. + b = AppendMapSorted[string, testDec](b, in, AppendString, EncoderToBytes(testDec{})) + + // Read back using DecoderFromBytes for values. + seq, finish := ReadMapBytes[string, testDec](b, ReadStringBytes, DecoderFromBytes(testDec{})) + + seq(func(k string, v testDec) bool { + fmt.Printf("%s=%d,%s\n", k, v.A, v.B) + return true + }) + if _, err := finish(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // a=1,x + // b=2,y +} + var nilMsg = AppendNil(nil) // collectSeq2 collects values from an iter.Seq2[V, error] into a slice. @@ -43,7 +255,7 @@ func TestReadNumberArray_Int(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int](r)) + got, err := collectSeq2(ReadArray[int](r, r.ReadInt)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -75,7 +287,7 @@ func TestReadNumberArray_Float64(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[float64](r)) + got, err := collectSeq2(ReadArray[float64](r, r.ReadFloat64)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -107,7 +319,7 @@ func TestReadArray_String(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[string](r)) + got, err := collectSeq2(ReadArray[string](r, r.ReadString)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -139,7 +351,7 @@ func TestReadArray_Bool(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[bool](r)) + got, err := collectSeq2(ReadArray[bool](r, r.ReadBool)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -213,7 +425,13 @@ func TestReadArray_Decodable(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[testDec](r)) + got, err := collectSeq2(ReadArray[testDec](r, func() (testDec, error) { + var t testDec + if err := t.DecodeMsg(r); err != nil { + return testDec{}, err + } + return t, nil + })) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -259,7 +477,7 @@ func TestReadArray_TimeAndDuration(t *testing.T) { } r := NewReader(&buf) - timesGot, err := collectSeq2(ReadArray[time.Time](r)) + timesGot, err := collectSeq2(ReadArray[time.Time](r, r.ReadTime)) if err != nil { t.Fatalf("times iteration error: %v", err) } @@ -267,7 +485,7 @@ func TestReadArray_TimeAndDuration(t *testing.T) { t.Fatalf("times mismatch: got %v", timesGot) } - dursGot, err := collectSeq2(ReadArray[time.Duration](r)) + dursGot, err := collectSeq2(ReadArray[time.Duration](r, r.ReadDuration)) if err != nil { t.Fatalf("durations iteration error: %v", err) } @@ -290,7 +508,7 @@ func TestReadNumberArrayBytes_Uint16(t *testing.T) { msg = AppendUint16(msg, v) } - seq, tail := ReadNumberArrayBytes[uint16](msg) + seq, tail := ReadArrayBytes[uint16](msg, ReadUint16Bytes) var got []uint16 for v := range seq { got = append(got, v) @@ -322,7 +540,7 @@ func TestReadNumberArrayBytes_ErrOnTruncated(t *testing.T) { // Truncate to cause an error when reading the second element. trunc := full[:len(full)-2] - seq, tail := ReadNumberArrayBytes[int32](trunc) + seq, tail := ReadArrayBytes[int32](trunc, ReadInt32Bytes) var got []int32 for v := range seq { got = append(got, v) @@ -359,7 +577,7 @@ func TestReadArray_ErrorOnTooFewElements(t *testing.T) { r := NewReader(&buf) var got []int var firstErr error - ReadNumberArray[int](r)(func(v int, err error) bool { + ReadArray[int](r, r.ReadInt)(func(v int, err error) bool { if err != nil { firstErr = err return false @@ -384,7 +602,7 @@ func approxEqual[T ~float32 | ~float64](a, b T) bool { } func TestRoundtripNumberArray_AllTypes(t *testing.T) { - type testcase[V NumberTypes] struct { + type testcase[V any] struct { name string vals []V write func(w *Writer, v V) error @@ -424,7 +642,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[uint](r)) + got, err := collectSeq2(ReadArray[uint](r, r.ReadUint)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -436,7 +654,8 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } - got, err = collectSeq2(ReadNumberArray[uint](NewReader(bytes.NewReader(nilMsg)))) + r = NewReader(bytes.NewReader(nilMsg)) + got, err = collectSeq2(ReadArray[uint](r, r.ReadUint)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -459,7 +678,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[uint8](r)) + got, err := collectSeq2(ReadArray[uint8](r, r.ReadUint8)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -486,7 +705,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[uint16](r)) + got, err := collectSeq2(ReadArray[uint16](r, r.ReadUint16)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -513,7 +732,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[uint32](r)) + got, err := collectSeq2(ReadArray[uint32](r, r.ReadUint32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -540,7 +759,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[uint64](r)) + got, err := collectSeq2(ReadArray[uint64](r, r.ReadUint64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -567,7 +786,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int](r)) + got, err := collectSeq2(ReadArray[int](r, r.ReadInt)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -594,7 +813,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int8](r)) + got, err := collectSeq2(ReadArray[int8](r, r.ReadInt8)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -621,7 +840,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int16](r)) + got, err := collectSeq2(ReadArray[int16](r, r.ReadInt16)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -648,7 +867,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int32](r)) + got, err := collectSeq2(ReadArray[int32](r, r.ReadInt32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -675,7 +894,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[int64](r)) + got, err := collectSeq2(ReadArray[int64](r, r.ReadInt64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -702,7 +921,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[float32](r)) + got, err := collectSeq2(ReadArray[float32](r, r.ReadFloat32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -729,7 +948,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadNumberArray[float64](r)) + got, err := collectSeq2(ReadArray[float64](r, r.ReadFloat64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -746,7 +965,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { } func TestRoundtripArray_AllTypes(t *testing.T) { - type regCase[V ArrayExtraTypes] struct { + type regCase[V any] struct { name string vals []V write func(*Writer, V) error @@ -794,7 +1013,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[bool](r)) + got, err := collectSeq2(ReadArray[bool](r, r.ReadBool)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -807,7 +1026,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { } } r.Reset(bytes.NewReader(nilMsg)) - got, err = collectSeq2(ReadArray[bool](r)) + got, err = collectSeq2(ReadArray[bool](r, r.ReadBool)) if len(got) != 0 { t.Fatalf("%s len: got %d want 0", tc.name, len(got)) } @@ -830,7 +1049,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[string](r)) + got, err := collectSeq2(ReadArray[string](r, r.ReadString)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -857,7 +1076,9 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[[]byte](r)) + got, err := collectSeq2(ReadArray[[]byte](r, func() ([]byte, error) { + return r.ReadBytes(nil) + })) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -884,7 +1105,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[time.Time](r)) + got, err := collectSeq2(ReadArray[time.Time](r, r.ReadTime)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -911,7 +1132,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[time.Duration](r)) + got, err := collectSeq2(ReadArray[time.Duration](r, r.ReadDuration)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -938,7 +1159,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[complex64](r)) + got, err := collectSeq2(ReadArray[complex64](r, r.ReadComplex64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -965,7 +1186,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[complex128](r)) + got, err := collectSeq2(ReadArray[complex128](r, r.ReadComplex128)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -992,7 +1213,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[testDec](r)) + got, err := collectSeq2(ReadArray[testDec](r, DecoderFrom(r, testDec{}))) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1010,7 +1231,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { } func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { - type tb[V NumberTypes] struct { + type tb[V any] struct { name string vals []V append func([]byte, V) []byte @@ -1039,7 +1260,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[uint](msg) + seq, tail := ReadArrayBytes[uint](msg, ReadUintBytes) var got []uint for v := range seq { got = append(got, v) @@ -1059,7 +1280,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } - seq, tail = ReadNumberArrayBytes[uint](nilMsg) + seq, tail = ReadArrayBytes[uint](nilMsg, ReadUintBytes) for range seq { t.Fatalf("%s: got entries on nil", tc.name) } @@ -1076,7 +1297,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[uint8](msg) + seq, tail := ReadArrayBytes[uint8](msg, ReadUint8Bytes) var got []uint8 for v := range seq { got = append(got, v) @@ -1101,7 +1322,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[uint16](msg) + seq, tail := ReadArrayBytes[uint16](msg, ReadUint16Bytes) var got []uint16 for v := range seq { got = append(got, v) @@ -1126,7 +1347,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[uint32](msg) + seq, tail := ReadArrayBytes[uint32](msg, ReadUint32Bytes) var got []uint32 for v := range seq { got = append(got, v) @@ -1151,7 +1372,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[uint64](msg) + seq, tail := ReadArrayBytes[uint64](msg, ReadUint64Bytes) var got []uint64 for v := range seq { got = append(got, v) @@ -1176,7 +1397,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[int](msg) + seq, tail := ReadArrayBytes[int](msg, ReadIntBytes) var got []int for v := range seq { got = append(got, v) @@ -1201,7 +1422,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[int8](msg) + seq, tail := ReadArrayBytes[int8](msg, ReadInt8Bytes) var got []int8 for v := range seq { got = append(got, v) @@ -1226,7 +1447,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[int16](msg) + seq, tail := ReadArrayBytes[int16](msg, ReadInt16Bytes) var got []int16 for v := range seq { got = append(got, v) @@ -1251,7 +1472,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[int32](msg) + seq, tail := ReadArrayBytes[int32](msg, ReadInt32Bytes) var got []int32 for v := range seq { got = append(got, v) @@ -1276,7 +1497,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[int64](msg) + seq, tail := ReadArrayBytes[int64](msg, ReadInt64Bytes) var got []int64 for v := range seq { got = append(got, v) @@ -1301,7 +1522,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[float32](msg) + seq, tail := ReadArrayBytes[float32](msg, ReadFloat32Bytes) var got []float32 for v := range seq { got = append(got, v) @@ -1326,7 +1547,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadNumberArrayBytes[float64](msg) + seq, tail := ReadArrayBytes[float64](msg, ReadFloat64Bytes) var got []float64 for v := range seq { got = append(got, v) @@ -1351,7 +1572,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { } func TestRoundtripArrayBytes_AllTypes(t *testing.T) { - type rb[V ArrayExtraTypes] struct { + type rb[V any] struct { name string vals []V append func([]byte, V) []byte @@ -1389,7 +1610,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[bool](msg) + seq, tail := ReadArrayBytes[bool](msg, ReadBoolBytes) var got []bool for v := range seq { got = append(got, v) @@ -1409,7 +1630,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } - seq, tail = ReadArrayBytes[bool](nilMsg) + seq, tail = ReadArrayBytes[bool](nilMsg, ReadBoolBytes) for range seq { t.Fatalf("%s: got entries on nil", tc.name) } @@ -1425,7 +1646,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[string](msg) + seq, tail := ReadArrayBytes[string](msg, ReadStringBytes) var got []string for v := range seq { got = append(got, v) @@ -1450,7 +1671,9 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[[]byte](msg) + seq, tail := ReadArrayBytes[[]byte](msg, func(i []byte) ([]byte, []byte, error) { + return ReadBytesBytes(i, nil) + }) var got [][]byte for v := range seq { got = append(got, v) @@ -1475,7 +1698,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[time.Time](msg) + seq, tail := ReadArrayBytes[time.Time](msg, ReadTimeBytes) var got []time.Time for v := range seq { got = append(got, v) @@ -1500,7 +1723,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[time.Duration](msg) + seq, tail := ReadArrayBytes[time.Duration](msg, ReadDurationBytes) var got []time.Duration for v := range seq { got = append(got, v) @@ -1525,7 +1748,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[complex64](msg) + seq, tail := ReadArrayBytes[complex64](msg, ReadComplex64Bytes) var got []complex64 for v := range seq { got = append(got, v) @@ -1550,7 +1773,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[complex128](msg) + seq, tail := ReadArrayBytes[complex128](msg, ReadComplex128Bytes) var got []complex128 for v := range seq { got = append(got, v) @@ -1575,7 +1798,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[testDec](msg) + seq, tail := ReadArrayBytes[testDec](msg, DecoderFromBytes(testDec{})) var got []testDec for v := range seq { got = append(got, v) @@ -1603,7 +1826,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { func eqNum[T comparable](a, b T) bool { return a == b } func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { - type numCase[T NumberTypes] struct { + type numCase[T any] struct { name string keys []T vals []T @@ -1649,7 +1872,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint, uint](r) + seq, tail := ReadMap[uint, uint](r, r.ReadUint, r.ReadUint) got := make(map[uint]uint, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1667,7 +1890,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { } // Test nil r = NewReader(bytes.NewReader(nilMsg)) - seq, tail = ReadMap[uint, uint](r) + seq, tail = ReadMap[uint, uint](r, r.ReadUint, r.ReadUint) for k, v := range seq { t.Fatalf("nil %s: got key %v val %v", tc.name, k, v) } @@ -1692,7 +1915,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint8, uint8](r) + seq, tail := ReadMap[uint8, uint8](r, r.ReadUint8, r.ReadUint8) got := make(map[uint8]uint8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1726,7 +1949,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint16, uint16](r) + seq, tail := ReadMap[uint16, uint16](r, r.ReadUint16, r.ReadUint16) got := make(map[uint16]uint16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1760,7 +1983,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint32, uint32](r) + seq, tail := ReadMap[uint32, uint32](r, r.ReadUint32, r.ReadUint32) got := make(map[uint32]uint32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1794,7 +2017,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint64, uint64](r) + seq, tail := ReadMap[uint64, uint64](r, r.ReadUint64, r.ReadUint64) got := make(map[uint64]uint64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1828,7 +2051,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int, int](r) + seq, tail := ReadMap[int, int](r, r.ReadInt, r.ReadInt) got := make(map[int]int, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1862,7 +2085,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int8, int8](r) + seq, tail := ReadMap[int8, int8](r, r.ReadInt8, r.ReadInt8) got := make(map[int8]int8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1896,7 +2119,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int16, int16](r) + seq, tail := ReadMap[int16, int16](r, r.ReadInt16, r.ReadInt16) got := make(map[int16]int16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1930,7 +2153,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int32, int32](r) + seq, tail := ReadMap[int32, int32](r, r.ReadInt32, r.ReadInt32) got := make(map[int32]int32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1964,7 +2187,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int64, int64](r) + seq, tail := ReadMap[int64, int64](r, r.ReadInt64, r.ReadInt64) got := make(map[int64]int64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1998,7 +2221,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[float32, float32](r) + seq, tail := ReadMap[float32, float32](r, r.ReadFloat32, r.ReadFloat32) got := make(map[float32]float32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2032,7 +2255,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[float64, float64](r) + seq, tail := ReadMap[float64, float64](r, r.ReadFloat64, r.ReadFloat64) got := make(map[float64]float64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2112,7 +2335,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[bool, bool](r) + seq, tail := ReadMap[bool, bool](r, r.ReadBool, r.ReadBool) got := make(map[bool]bool, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2129,7 +2352,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { } } r = NewReader(bytes.NewReader(nilMsg)) - seq, tail = ReadMap[bool, bool](r) + seq, tail = ReadMap[bool, bool](r, r.ReadBool, r.ReadBool) for k, v := range seq { t.Fatalf("%s:expected ni results, got %v:%v", tc.name, k, v) } @@ -2155,7 +2378,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[string, string](r) + seq, tail := ReadMap[string, string](r, r.ReadString, r.ReadString) got := make(map[string]string, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2197,7 +2420,11 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { } r := NewReader(&buf) - seq, tail := ReadMap[[]byte, []byte](r) + seq, tail := ReadMap[[]byte, []byte](r, func() ([]byte, error) { + return r.ReadBytes(nil) + }, func() ([]byte, error) { + return r.ReadBytes(nil) + }) var got []pair for k, v := range seq { kk := append([]byte(nil), k...) @@ -2243,7 +2470,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[time.Time, time.Time](r) + seq, tail := ReadMap[time.Time, time.Time](r, r.ReadTime, r.ReadTime) got := make(map[time.Time]time.Time, len(tc.keys)) for k, v := range seq { got[k.UTC()] = v.UTC() @@ -2277,7 +2504,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[time.Duration, time.Duration](r) + seq, tail := ReadMap[time.Duration, time.Duration](r, r.ReadDuration, r.ReadDuration) got := make(map[time.Duration]time.Duration, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2312,7 +2539,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[complex64, complex64](r) + seq, tail := ReadMap[complex64, complex64](r, r.ReadComplex64, r.ReadComplex64) got := make(map[complex64]complex64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2346,7 +2573,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[complex128, complex128](r) + seq, tail := ReadMap[complex128, complex128](r, r.ReadComplex128, r.ReadComplex128) got := make(map[complex128]complex128, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2367,7 +2594,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { } func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { - type numCase[T NumberTypes] struct { + type numCase[T any] struct { name string keys []T vals []T @@ -2401,7 +2628,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint, uint](msg) + seq, tail := ReadMapBytes[uint, uint](msg, ReadUintBytes, ReadUintBytes) got := make(map[uint]uint, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2421,7 +2648,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } - seq, tail = ReadMapBytes[uint, uint](nilMsg) + seq, tail = ReadMapBytes[uint, uint](nilMsg, ReadUintBytes, ReadUintBytes) for k, v := range seq { t.Fatalf("%s: got %v:%v want nothing", tc.name, k, v) } @@ -2438,7 +2665,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint8, uint8](msg) + seq, tail := ReadMapBytes[uint8, uint8](msg, ReadUint8Bytes, ReadUint8Bytes) got := make(map[uint8]uint8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2464,7 +2691,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint16, uint16](msg) + seq, tail := ReadMapBytes[uint16, uint16](msg, ReadUint16Bytes, ReadUint16Bytes) got := make(map[uint16]uint16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2490,7 +2717,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint32, uint32](msg) + seq, tail := ReadMapBytes[uint32, uint32](msg, ReadUint32Bytes, ReadUint32Bytes) got := make(map[uint32]uint32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2516,7 +2743,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint64, uint64](msg) + seq, tail := ReadMapBytes[uint64, uint64](msg, ReadUint64Bytes, ReadUint64Bytes) got := make(map[uint64]uint64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2542,7 +2769,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int, int](msg) + seq, tail := ReadMapBytes[int, int](msg, ReadIntBytes, ReadIntBytes) got := make(map[int]int, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2568,7 +2795,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int8, int8](msg) + seq, tail := ReadMapBytes[int8, int8](msg, ReadInt8Bytes, ReadInt8Bytes) got := make(map[int8]int8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2594,7 +2821,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int16, int16](msg) + seq, tail := ReadMapBytes[int16, int16](msg, ReadInt16Bytes, ReadInt16Bytes) got := make(map[int16]int16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2620,7 +2847,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int32, int32](msg) + seq, tail := ReadMapBytes[int32, int32](msg, ReadInt32Bytes, ReadInt32Bytes) got := make(map[int32]int32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2646,7 +2873,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int64, int64](msg) + seq, tail := ReadMapBytes[int64, int64](msg, ReadInt64Bytes, ReadInt64Bytes) got := make(map[int64]int64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2672,7 +2899,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[float32, float32](msg) + seq, tail := ReadMapBytes[float32, float32](msg, ReadFloat32Bytes, ReadFloat32Bytes) got := make(map[float32]float32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2698,7 +2925,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[float64, float64](msg) + seq, tail := ReadMapBytes[float64, float64](msg, ReadFloat64Bytes, ReadFloat64Bytes) got := make(map[float64]float64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2770,7 +2997,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[bool, bool](msg) + seq, tail := ReadMapBytes[bool, bool](msg, ReadBoolBytes, ReadBoolBytes) got := make(map[bool]bool, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2790,7 +3017,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } - seq, tail = ReadMapBytes[bool, bool](nilMsg) + seq, tail = ReadMapBytes[bool, bool](nilMsg, ReadBoolBytes, ReadBoolBytes) for k, v := range seq { t.Fatalf("%s key %v:%v want nothing", tc.name, k, v) } @@ -2808,7 +3035,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[string, string](msg) + seq, tail := ReadMapBytes[string, string](msg, ReadStringBytes, ReadStringBytes) got := make(map[string]string, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2841,7 +3068,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { expected[i] = pair{append([]byte(nil), tc.keys[i]...), append([]byte(nil), tc.vals[i]...)} } - seq, tail := ReadMapBytes[[]byte, []byte](msg) + seq, tail := ReadMapBytes[[]byte, []byte](msg, ReadBytesZC, ReadBytesZC) var got []pair for k, v := range seq { kk := append([]byte(nil), k...) @@ -2877,7 +3104,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[time.Time, time.Time](msg) + seq, tail := ReadMapBytes[time.Time, time.Time](msg, ReadTimeBytes, ReadTimeBytes) got := make(map[time.Time]time.Time, len(tc.keys)) for k, v := range seq { got[k.UTC()] = v.UTC() @@ -2903,7 +3130,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[time.Duration, time.Duration](msg) + seq, tail := ReadMapBytes[time.Duration, time.Duration](msg, ReadDurationBytes, ReadDurationBytes) got := make(map[time.Duration]time.Duration, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2929,7 +3156,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[complex64, complex64](msg) + seq, tail := ReadMapBytes[complex64, complex64](msg, ReadComplex64Bytes, ReadComplex64Bytes) got := make(map[complex64]complex64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2955,7 +3182,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[complex128, complex128](msg) + seq, tail := ReadMapBytes[complex128, complex128](msg, ReadComplex128Bytes, ReadComplex128Bytes) got := make(map[complex128]complex128, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2989,7 +3216,7 @@ func TestReadMapBytes_TailErrorOnTruncated(t *testing.T) { full := AppendInt(msg, 20) trunc := full[:len(full)-2] // truncate some bytes from the last value - seq, tail := ReadMapBytes[int, int](trunc) + seq, tail := ReadMapBytes[int, int](trunc, ReadIntBytes, ReadIntBytes) got := make(map[int]int) for k, v := range seq { got[k] = v From c23802fd0348964b955e5c378cf988361ae8499c Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Mon, 1 Sep 2025 10:52:16 +0200 Subject: [PATCH 6/7] Add more examples, remove explicit types from all example calls. Fix copy+paste --- msgp/iter.go | 9 +- msgp/iter_test.go | 285 +++++++++++++++++++++++++++++----------------- 2 files changed, 186 insertions(+), 108 deletions(-) diff --git a/msgp/iter.go b/msgp/iter.go index 0b3a2349..f72df24e 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -106,7 +106,8 @@ func ReadMap[K, V any](m *Reader, readKey func() (K, error), readVal func() (V, } // WriteMap writes a map to the provided Writer. -// The writeKey and writeVal parameters specify the functions to use to write each key and value of the map. +// The writeKey and writeVal parameters specify the functions +// to use to write each key and value of the map. func WriteMap[K comparable, V any](w *Writer, m map[K]V, writeKey func(K) error, writeVal func(V) error) error { if m == nil { return w.WriteNil() @@ -196,7 +197,7 @@ func ReadArrayBytes[T any](b []byte, readFn func([]byte) (T, []byte, error)) (it // AppendArray writes an array to the provided buffer. // The writeFn parameter specifies the function to use to write each element of the array. // The returned buffer contains the encoded array. -// The function panics if the map is larger than math.MaxUint32 elements. +// The function panics if the array is larger than math.MaxUint32 elements. func AppendArray[T any](b []byte, a []T, writeFn func(b []byte, v T) []byte) []byte { if a == nil { return AppendNil(b) @@ -341,8 +342,8 @@ type UnmarshalPtr[T any] interface { Unmarshaler } -// DecoderFromBytes allows augmenting any type with a UnmarshalMsg method into a method -// that reads from []byte and returns a T. +// DecoderFromBytes allows augmenting any type with an UnmarshalMsg +// method into a method that reads from []byte and returns a T. // Provide an instance of T. This value isn't used. func DecoderFromBytes[T any, PT UnmarshalPtr[T]](_ T) func([]byte) (T, []byte, error) { return func(b []byte) (T, []byte, error) { diff --git a/msgp/iter_test.go b/msgp/iter_test.go index 0ab1d974..71b1ec32 100644 --- a/msgp/iter_test.go +++ b/msgp/iter_test.go @@ -17,12 +17,41 @@ func ExampleReadArray() { w := NewWriter(&buf) // Write an array [10, 20, 30] using WriteArray - _ = WriteArray[int](w, []int{10, 20, 30}, w.WriteInt) + _ = WriteArray(w, []int{10, 20, 30}, w.WriteInt) _ = w.Flush() r := NewReader(&buf) - seq := ReadArray[int](r, r.ReadInt) + seq := ReadArray(r, r.ReadInt) + seq(func(v int, err error) bool { + if err != nil { + fmt.Println("err:", err) + return false + } + fmt.Println(v) + return true + }) + + // Output: + // 10 + // 20 + // 30 +} + +// Example: Writing and array with WriteArray, +// then reading back using ReadArray with a *Reader. +// It prints each element in order. +func ExampleWriteArray() { + var buf bytes.Buffer + w := NewWriter(&buf) + + // Write an array [10, 20, 30] using WriteArray + _ = WriteArray(w, []int8{10, 20, 30}, w.WriteInt8) + _ = w.Flush() + + r := NewReader(&buf) + + seq := ReadArray(r, r.ReadInt) seq(func(v int, err error) bool { if err != nil { fmt.Println("err:", err) @@ -45,12 +74,12 @@ func ExampleReadMap() { w := NewWriter(&buf) // Write a map {"a":1, "b":2} using WriteMap - we use the sorted version so output is predictable. - _ = WriteMapSorted[string, int](w, map[string]int{"a": 1, "b": 2}, w.WriteString, w.WriteInt) + _ = WriteMapSorted(w, map[string]int{"a": 1, "b": 2}, w.WriteString, w.WriteInt) _ = w.Flush() r := NewReader(&buf) - seq, done := ReadMap[string, int](r, r.ReadString, r.ReadInt) + seq, done := ReadMap(r, r.ReadString, r.ReadInt) seq(func(k string, v int) bool { fmt.Printf("%s=%d\n", k, v) return true @@ -64,14 +93,39 @@ func ExampleReadMap() { // b=2 } +// Example: writing a map using WriteMap (non-sorted). +// Uses a single-entry map to keep output deterministic. +// Use WriteMapSorted to write a sorted map. +func ExampleWriteMap() { + var buf bytes.Buffer + w := NewWriter(&buf) + + // Write a map {"only":1} using WriteMap + _ = WriteMap(w, map[string]int{"only": 1}, w.WriteString, w.WriteInt) + _ = w.Flush() + + r := NewReader(&buf) + seq, done := ReadMap(r, r.ReadString, r.ReadInt) + seq(func(k string, v int) bool { + fmt.Printf("%s=%d\n", k, v) + return true + }) + if err := done(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // only=1 +} + // Example: reading an array of strings from a byte slice using ReadArrayBytes. // It prints each element and then checks for a final error from the returned closure. func ExampleReadArrayBytes() { var b []byte // Append ["x","y","z"] using AppendArray - b = AppendArray[string](b, []string{"x", "y", "z"}, AppendString) + b = AppendArray(b, []string{"x", "y", "z"}, AppendString) - seq, finish := ReadArrayBytes[string](b, ReadStringBytes) + seq, finish := ReadArrayBytes(b, ReadStringBytes) seq(func(s string) bool { fmt.Println(s) return true @@ -91,9 +145,9 @@ func ExampleReadArrayBytes() { func ExampleReadMapBytes() { var b []byte // Append {"pi":3.14, "e":2.718} using AppendMap - we use the sorted version for the example - b = AppendMapSorted[string, float64](b, map[string]float64{"pi": 3.14, "e": 2.718}, AppendString, AppendFloat64) + b = AppendMapSorted(b, map[string]float64{"pi": 3.14, "e": 2.718}, AppendString, AppendFloat64) - seq, finish := ReadMapBytes[string, float64](b, ReadStringBytes, ReadFloat64Bytes) + seq, finish := ReadMapBytes(b, ReadStringBytes, ReadFloat64Bytes) seq(func(k string, v float64) bool { fmt.Printf("%s=%.3f\n", k, v) return true @@ -115,12 +169,12 @@ func ExampleReadArray_struct() { in := []testDec{{A: 1, B: "x"}, {A: 2, B: "y"}} // Write []testDec using EncoderTo as the per-element writer. - _ = WriteArray[testDec](w, in, EncoderTo(w, testDec{})) + _ = WriteArray(w, in, EncoderTo(w, testDec{})) _ = w.Flush() r := NewReader(&buf) // Read []testDec using DecoderFrom as the per-element reader. - seq := ReadArray[testDec](r, DecoderFrom(r, testDec{})) + seq := ReadArray(r, DecoderFrom(r, testDec{})) seq(func(v testDec, err error) bool { if err != nil { @@ -148,11 +202,11 @@ func ExampleReadMap_struct() { "b": {A: 2, B: "y"}, } // Write map[string]testDec using sorted keys for stable example output. - _ = WriteMapSorted[string, testDec](w, in, w.WriteString, EncoderTo(w, testDec{})) + _ = WriteMapSorted(w, in, w.WriteString, EncoderTo(w, testDec{})) _ = w.Flush() r := NewReader(&buf) - seq, done := ReadMap[string, testDec](r, r.ReadString, DecoderFrom(r, testDec{})) + seq, done := ReadMap(r, r.ReadString, DecoderFrom(r, testDec{})) seq(func(k string, v testDec) bool { fmt.Printf("%s=%d,%s\n", k, v.A, v.B) @@ -174,10 +228,10 @@ func ExampleReadArrayBytes_struct() { var b []byte // Append []testDec using EncoderToBytes as the per-element appender. - b = AppendArray[testDec](b, in, EncoderToBytes(testDec{})) + b = AppendArray(b, in, EncoderToBytes(testDec{})) // Read back using DecoderFromBytes as the per-element reader. - seq, finish := ReadArrayBytes[testDec](b, DecoderFromBytes(testDec{})) + seq, finish := ReadArrayBytes(b, DecoderFromBytes(testDec{})) seq(func(v testDec) bool { fmt.Printf("%d %s\n", v.A, v.B) @@ -203,10 +257,10 @@ func ExampleReadMapBytes_struct() { var b []byte // Append map[string]testDec with sorted keys for stable example output. - b = AppendMapSorted[string, testDec](b, in, AppendString, EncoderToBytes(testDec{})) + b = AppendMapSorted(b, in, AppendString, EncoderToBytes(testDec{})) // Read back using DecoderFromBytes for values. - seq, finish := ReadMapBytes[string, testDec](b, ReadStringBytes, DecoderFromBytes(testDec{})) + seq, finish := ReadMapBytes(b, ReadStringBytes, DecoderFromBytes(testDec{})) seq(func(k string, v testDec) bool { fmt.Printf("%s=%d,%s\n", k, v.A, v.B) @@ -221,6 +275,29 @@ func ExampleReadMapBytes_struct() { // b=2,y } +// Example: appending a map to a byte slice using AppendMap (non-sorted). +// Uses a single-entry map to keep output deterministic. +// Use AppendMapSorted to write a sorted map. +func ExampleAppendMap() { + var b []byte + + // Append {"only":1} using AppendMap + b = AppendMap(b, map[string]int{"only": 1}, AppendString, AppendInt) + + // Read back and print + seq, finish := ReadMapBytes(b, ReadStringBytes, ReadIntBytes) + seq(func(k string, v int) bool { + fmt.Printf("%s=%d\n", k, v) + return true + }) + if _, err := finish(); err != nil { + fmt.Println("err:", err) + } + + // Output: + // only=1 +} + var nilMsg = AppendNil(nil) // collectSeq2 collects values from an iter.Seq2[V, error] into a slice. @@ -255,7 +332,7 @@ func TestReadNumberArray_Int(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int](r, r.ReadInt)) + got, err := collectSeq2(ReadArray(r, r.ReadInt)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -287,7 +364,7 @@ func TestReadNumberArray_Float64(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[float64](r, r.ReadFloat64)) + got, err := collectSeq2(ReadArray(r, r.ReadFloat64)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -319,7 +396,7 @@ func TestReadArray_String(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[string](r, r.ReadString)) + got, err := collectSeq2(ReadArray(r, r.ReadString)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -351,7 +428,7 @@ func TestReadArray_Bool(t *testing.T) { } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[bool](r, r.ReadBool)) + got, err := collectSeq2(ReadArray(r, r.ReadBool)) if err != nil { t.Fatalf("iteration error: %v", err) } @@ -508,7 +585,7 @@ func TestReadNumberArrayBytes_Uint16(t *testing.T) { msg = AppendUint16(msg, v) } - seq, tail := ReadArrayBytes[uint16](msg, ReadUint16Bytes) + seq, tail := ReadArrayBytes(msg, ReadUint16Bytes) var got []uint16 for v := range seq { got = append(got, v) @@ -540,7 +617,7 @@ func TestReadNumberArrayBytes_ErrOnTruncated(t *testing.T) { // Truncate to cause an error when reading the second element. trunc := full[:len(full)-2] - seq, tail := ReadArrayBytes[int32](trunc, ReadInt32Bytes) + seq, tail := ReadArrayBytes(trunc, ReadInt32Bytes) var got []int32 for v := range seq { got = append(got, v) @@ -577,7 +654,7 @@ func TestReadArray_ErrorOnTooFewElements(t *testing.T) { r := NewReader(&buf) var got []int var firstErr error - ReadArray[int](r, r.ReadInt)(func(v int, err error) bool { + ReadArray(r, r.ReadInt)(func(v int, err error) bool { if err != nil { firstErr = err return false @@ -642,7 +719,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[uint](r, r.ReadUint)) + got, err := collectSeq2(ReadArray(r, r.ReadUint)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -655,7 +732,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { } } r = NewReader(bytes.NewReader(nilMsg)) - got, err = collectSeq2(ReadArray[uint](r, r.ReadUint)) + got, err = collectSeq2(ReadArray(r, r.ReadUint)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -678,7 +755,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[uint8](r, r.ReadUint8)) + got, err := collectSeq2(ReadArray(r, r.ReadUint8)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -705,7 +782,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[uint16](r, r.ReadUint16)) + got, err := collectSeq2(ReadArray(r, r.ReadUint16)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -732,7 +809,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[uint32](r, r.ReadUint32)) + got, err := collectSeq2(ReadArray(r, r.ReadUint32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -759,7 +836,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[uint64](r, r.ReadUint64)) + got, err := collectSeq2(ReadArray(r, r.ReadUint64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -786,7 +863,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int](r, r.ReadInt)) + got, err := collectSeq2(ReadArray(r, r.ReadInt)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -813,7 +890,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int8](r, r.ReadInt8)) + got, err := collectSeq2(ReadArray(r, r.ReadInt8)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -840,7 +917,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int16](r, r.ReadInt16)) + got, err := collectSeq2(ReadArray(r, r.ReadInt16)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -867,7 +944,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int32](r, r.ReadInt32)) + got, err := collectSeq2(ReadArray(r, r.ReadInt32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -894,7 +971,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[int64](r, r.ReadInt64)) + got, err := collectSeq2(ReadArray(r, r.ReadInt64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -921,7 +998,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[float32](r, r.ReadFloat32)) + got, err := collectSeq2(ReadArray(r, r.ReadFloat32)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -948,7 +1025,7 @@ func TestRoundtripNumberArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[float64](r, r.ReadFloat64)) + got, err := collectSeq2(ReadArray(r, r.ReadFloat64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1013,7 +1090,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[bool](r, r.ReadBool)) + got, err := collectSeq2(ReadArray(r, r.ReadBool)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1026,7 +1103,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { } } r.Reset(bytes.NewReader(nilMsg)) - got, err = collectSeq2(ReadArray[bool](r, r.ReadBool)) + got, err = collectSeq2(ReadArray(r, r.ReadBool)) if len(got) != 0 { t.Fatalf("%s len: got %d want 0", tc.name, len(got)) } @@ -1049,7 +1126,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[string](r, r.ReadString)) + got, err := collectSeq2(ReadArray(r, r.ReadString)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1076,7 +1153,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[[]byte](r, func() ([]byte, error) { + got, err := collectSeq2(ReadArray(r, func() ([]byte, error) { return r.ReadBytes(nil) })) if err != nil { @@ -1105,7 +1182,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[time.Time](r, r.ReadTime)) + got, err := collectSeq2(ReadArray(r, r.ReadTime)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1132,7 +1209,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[time.Duration](r, r.ReadDuration)) + got, err := collectSeq2(ReadArray(r, r.ReadDuration)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1159,7 +1236,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[complex64](r, r.ReadComplex64)) + got, err := collectSeq2(ReadArray(r, r.ReadComplex64)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1186,7 +1263,7 @@ func TestRoundtripArray_AllTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - got, err := collectSeq2(ReadArray[complex128](r, r.ReadComplex128)) + got, err := collectSeq2(ReadArray(r, r.ReadComplex128)) if err != nil { t.Fatalf("%s iterate: %v", tc.name, err) } @@ -1260,7 +1337,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[uint](msg, ReadUintBytes) + seq, tail := ReadArrayBytes(msg, ReadUintBytes) var got []uint for v := range seq { got = append(got, v) @@ -1280,7 +1357,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } - seq, tail = ReadArrayBytes[uint](nilMsg, ReadUintBytes) + seq, tail = ReadArrayBytes(nilMsg, ReadUintBytes) for range seq { t.Fatalf("%s: got entries on nil", tc.name) } @@ -1297,7 +1374,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[uint8](msg, ReadUint8Bytes) + seq, tail := ReadArrayBytes(msg, ReadUint8Bytes) var got []uint8 for v := range seq { got = append(got, v) @@ -1322,7 +1399,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[uint16](msg, ReadUint16Bytes) + seq, tail := ReadArrayBytes(msg, ReadUint16Bytes) var got []uint16 for v := range seq { got = append(got, v) @@ -1347,7 +1424,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[uint32](msg, ReadUint32Bytes) + seq, tail := ReadArrayBytes(msg, ReadUint32Bytes) var got []uint32 for v := range seq { got = append(got, v) @@ -1372,7 +1449,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[uint64](msg, ReadUint64Bytes) + seq, tail := ReadArrayBytes(msg, ReadUint64Bytes) var got []uint64 for v := range seq { got = append(got, v) @@ -1397,7 +1474,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[int](msg, ReadIntBytes) + seq, tail := ReadArrayBytes(msg, ReadIntBytes) var got []int for v := range seq { got = append(got, v) @@ -1422,7 +1499,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[int8](msg, ReadInt8Bytes) + seq, tail := ReadArrayBytes(msg, ReadInt8Bytes) var got []int8 for v := range seq { got = append(got, v) @@ -1447,7 +1524,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[int16](msg, ReadInt16Bytes) + seq, tail := ReadArrayBytes(msg, ReadInt16Bytes) var got []int16 for v := range seq { got = append(got, v) @@ -1472,7 +1549,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[int32](msg, ReadInt32Bytes) + seq, tail := ReadArrayBytes(msg, ReadInt32Bytes) var got []int32 for v := range seq { got = append(got, v) @@ -1497,7 +1574,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[int64](msg, ReadInt64Bytes) + seq, tail := ReadArrayBytes(msg, ReadInt64Bytes) var got []int64 for v := range seq { got = append(got, v) @@ -1522,7 +1599,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[float32](msg, ReadFloat32Bytes) + seq, tail := ReadArrayBytes(msg, ReadFloat32Bytes) var got []float32 for v := range seq { got = append(got, v) @@ -1547,7 +1624,7 @@ func TestRoundtripNumberArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[float64](msg, ReadFloat64Bytes) + seq, tail := ReadArrayBytes(msg, ReadFloat64Bytes) var got []float64 for v := range seq { got = append(got, v) @@ -1610,7 +1687,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[bool](msg, ReadBoolBytes) + seq, tail := ReadArrayBytes(msg, ReadBoolBytes) var got []bool for v := range seq { got = append(got, v) @@ -1630,7 +1707,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { t.Fatalf("%s[%d]: got %v want %v", tc.name, i, got[i], tc.vals[i]) } } - seq, tail = ReadArrayBytes[bool](nilMsg, ReadBoolBytes) + seq, tail = ReadArrayBytes(nilMsg, ReadBoolBytes) for range seq { t.Fatalf("%s: got entries on nil", tc.name) } @@ -1646,7 +1723,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[string](msg, ReadStringBytes) + seq, tail := ReadArrayBytes(msg, ReadStringBytes) var got []string for v := range seq { got = append(got, v) @@ -1748,7 +1825,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[complex64](msg, ReadComplex64Bytes) + seq, tail := ReadArrayBytes(msg, ReadComplex64Bytes) var got []complex64 for v := range seq { got = append(got, v) @@ -1773,7 +1850,7 @@ func TestRoundtripArrayBytes_AllTypes(t *testing.T) { for _, v := range tc.vals { msg = tc.append(msg, v) } - seq, tail := ReadArrayBytes[complex128](msg, ReadComplex128Bytes) + seq, tail := ReadArrayBytes(msg, ReadComplex128Bytes) var got []complex128 for v := range seq { got = append(got, v) @@ -1872,7 +1949,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint, uint](r, r.ReadUint, r.ReadUint) + seq, tail := ReadMap(r, r.ReadUint, r.ReadUint) got := make(map[uint]uint, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1890,7 +1967,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { } // Test nil r = NewReader(bytes.NewReader(nilMsg)) - seq, tail = ReadMap[uint, uint](r, r.ReadUint, r.ReadUint) + seq, tail = ReadMap(r, r.ReadUint, r.ReadUint) for k, v := range seq { t.Fatalf("nil %s: got key %v val %v", tc.name, k, v) } @@ -1915,7 +1992,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint8, uint8](r, r.ReadUint8, r.ReadUint8) + seq, tail := ReadMap(r, r.ReadUint8, r.ReadUint8) got := make(map[uint8]uint8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1949,7 +2026,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint16, uint16](r, r.ReadUint16, r.ReadUint16) + seq, tail := ReadMap(r, r.ReadUint16, r.ReadUint16) got := make(map[uint16]uint16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -1983,7 +2060,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint32, uint32](r, r.ReadUint32, r.ReadUint32) + seq, tail := ReadMap(r, r.ReadUint32, r.ReadUint32) got := make(map[uint32]uint32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2017,7 +2094,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[uint64, uint64](r, r.ReadUint64, r.ReadUint64) + seq, tail := ReadMap(r, r.ReadUint64, r.ReadUint64) got := make(map[uint64]uint64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2051,7 +2128,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int, int](r, r.ReadInt, r.ReadInt) + seq, tail := ReadMap(r, r.ReadInt, r.ReadInt) got := make(map[int]int, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2085,7 +2162,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int8, int8](r, r.ReadInt8, r.ReadInt8) + seq, tail := ReadMap(r, r.ReadInt8, r.ReadInt8) got := make(map[int8]int8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2119,7 +2196,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int16, int16](r, r.ReadInt16, r.ReadInt16) + seq, tail := ReadMap(r, r.ReadInt16, r.ReadInt16) got := make(map[int16]int16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2153,7 +2230,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int32, int32](r, r.ReadInt32, r.ReadInt32) + seq, tail := ReadMap(r, r.ReadInt32, r.ReadInt32) got := make(map[int32]int32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2187,7 +2264,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[int64, int64](r, r.ReadInt64, r.ReadInt64) + seq, tail := ReadMap(r, r.ReadInt64, r.ReadInt64) got := make(map[int64]int64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2221,7 +2298,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[float32, float32](r, r.ReadFloat32, r.ReadFloat32) + seq, tail := ReadMap(r, r.ReadFloat32, r.ReadFloat32) got := make(map[float32]float32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2255,7 +2332,7 @@ func TestReadMap_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[float64, float64](r, r.ReadFloat64, r.ReadFloat64) + seq, tail := ReadMap(r, r.ReadFloat64, r.ReadFloat64) got := make(map[float64]float64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2335,7 +2412,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[bool, bool](r, r.ReadBool, r.ReadBool) + seq, tail := ReadMap(r, r.ReadBool, r.ReadBool) got := make(map[bool]bool, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2352,7 +2429,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { } } r = NewReader(bytes.NewReader(nilMsg)) - seq, tail = ReadMap[bool, bool](r, r.ReadBool, r.ReadBool) + seq, tail = ReadMap(r, r.ReadBool, r.ReadBool) for k, v := range seq { t.Fatalf("%s:expected ni results, got %v:%v", tc.name, k, v) } @@ -2378,7 +2455,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[string, string](r, r.ReadString, r.ReadString) + seq, tail := ReadMap(r, r.ReadString, r.ReadString) got := make(map[string]string, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2420,7 +2497,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { } r := NewReader(&buf) - seq, tail := ReadMap[[]byte, []byte](r, func() ([]byte, error) { + seq, tail := ReadMap(r, func() ([]byte, error) { return r.ReadBytes(nil) }, func() ([]byte, error) { return r.ReadBytes(nil) @@ -2470,7 +2547,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[time.Time, time.Time](r, r.ReadTime, r.ReadTime) + seq, tail := ReadMap(r, r.ReadTime, r.ReadTime) got := make(map[time.Time]time.Time, len(tc.keys)) for k, v := range seq { got[k.UTC()] = v.UTC() @@ -2504,7 +2581,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[time.Duration, time.Duration](r, r.ReadDuration, r.ReadDuration) + seq, tail := ReadMap(r, r.ReadDuration, r.ReadDuration) got := make(map[time.Duration]time.Duration, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2539,7 +2616,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[complex64, complex64](r, r.ReadComplex64, r.ReadComplex64) + seq, tail := ReadMap(r, r.ReadComplex64, r.ReadComplex64) got := make(map[complex64]complex64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2573,7 +2650,7 @@ func TestReadMap_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("flush: %v", err) } r := NewReader(&buf) - seq, tail := ReadMap[complex128, complex128](r, r.ReadComplex128, r.ReadComplex128) + seq, tail := ReadMap(r, r.ReadComplex128, r.ReadComplex128) got := make(map[complex128]complex128, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2628,7 +2705,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint, uint](msg, ReadUintBytes, ReadUintBytes) + seq, tail := ReadMapBytes(msg, ReadUintBytes, ReadUintBytes) got := make(map[uint]uint, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2648,7 +2725,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } - seq, tail = ReadMapBytes[uint, uint](nilMsg, ReadUintBytes, ReadUintBytes) + seq, tail = ReadMapBytes(nilMsg, ReadUintBytes, ReadUintBytes) for k, v := range seq { t.Fatalf("%s: got %v:%v want nothing", tc.name, k, v) } @@ -2665,7 +2742,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint8, uint8](msg, ReadUint8Bytes, ReadUint8Bytes) + seq, tail := ReadMapBytes(msg, ReadUint8Bytes, ReadUint8Bytes) got := make(map[uint8]uint8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2691,7 +2768,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint16, uint16](msg, ReadUint16Bytes, ReadUint16Bytes) + seq, tail := ReadMapBytes(msg, ReadUint16Bytes, ReadUint16Bytes) got := make(map[uint16]uint16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2717,7 +2794,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint32, uint32](msg, ReadUint32Bytes, ReadUint32Bytes) + seq, tail := ReadMapBytes(msg, ReadUint32Bytes, ReadUint32Bytes) got := make(map[uint32]uint32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2743,7 +2820,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[uint64, uint64](msg, ReadUint64Bytes, ReadUint64Bytes) + seq, tail := ReadMapBytes(msg, ReadUint64Bytes, ReadUint64Bytes) got := make(map[uint64]uint64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2769,7 +2846,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int, int](msg, ReadIntBytes, ReadIntBytes) + seq, tail := ReadMapBytes(msg, ReadIntBytes, ReadIntBytes) got := make(map[int]int, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2795,7 +2872,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int8, int8](msg, ReadInt8Bytes, ReadInt8Bytes) + seq, tail := ReadMapBytes(msg, ReadInt8Bytes, ReadInt8Bytes) got := make(map[int8]int8, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2821,7 +2898,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int16, int16](msg, ReadInt16Bytes, ReadInt16Bytes) + seq, tail := ReadMapBytes(msg, ReadInt16Bytes, ReadInt16Bytes) got := make(map[int16]int16, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2847,7 +2924,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int32, int32](msg, ReadInt32Bytes, ReadInt32Bytes) + seq, tail := ReadMapBytes(msg, ReadInt32Bytes, ReadInt32Bytes) got := make(map[int32]int32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2873,7 +2950,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[int64, int64](msg, ReadInt64Bytes, ReadInt64Bytes) + seq, tail := ReadMapBytes(msg, ReadInt64Bytes, ReadInt64Bytes) got := make(map[int64]int64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2899,7 +2976,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[float32, float32](msg, ReadFloat32Bytes, ReadFloat32Bytes) + seq, tail := ReadMapBytes(msg, ReadFloat32Bytes, ReadFloat32Bytes) got := make(map[float32]float32, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2925,7 +3002,7 @@ func TestReadMapBytes_AllNumberTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[float64, float64](msg, ReadFloat64Bytes, ReadFloat64Bytes) + seq, tail := ReadMapBytes(msg, ReadFloat64Bytes, ReadFloat64Bytes) got := make(map[float64]float64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -2997,7 +3074,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[bool, bool](msg, ReadBoolBytes, ReadBoolBytes) + seq, tail := ReadMapBytes(msg, ReadBoolBytes, ReadBoolBytes) got := make(map[bool]bool, len(tc.keys)) for k, v := range seq { got[k] = v @@ -3017,7 +3094,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { t.Fatalf("%s key %v got %v want %v", tc.name, tc.keys[i], got[tc.keys[i]], tc.vals[i]) } } - seq, tail = ReadMapBytes[bool, bool](nilMsg, ReadBoolBytes, ReadBoolBytes) + seq, tail = ReadMapBytes(nilMsg, ReadBoolBytes, ReadBoolBytes) for k, v := range seq { t.Fatalf("%s key %v:%v want nothing", tc.name, k, v) } @@ -3035,7 +3112,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[string, string](msg, ReadStringBytes, ReadStringBytes) + seq, tail := ReadMapBytes(msg, ReadStringBytes, ReadStringBytes) got := make(map[string]string, len(tc.keys)) for k, v := range seq { got[k] = v @@ -3104,7 +3181,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[time.Time, time.Time](msg, ReadTimeBytes, ReadTimeBytes) + seq, tail := ReadMapBytes(msg, ReadTimeBytes, ReadTimeBytes) got := make(map[time.Time]time.Time, len(tc.keys)) for k, v := range seq { got[k.UTC()] = v.UTC() @@ -3130,7 +3207,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[time.Duration, time.Duration](msg, ReadDurationBytes, ReadDurationBytes) + seq, tail := ReadMapBytes(msg, ReadDurationBytes, ReadDurationBytes) got := make(map[time.Duration]time.Duration, len(tc.keys)) for k, v := range seq { got[k] = v @@ -3156,7 +3233,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[complex64, complex64](msg, ReadComplex64Bytes, ReadComplex64Bytes) + seq, tail := ReadMapBytes(msg, ReadComplex64Bytes, ReadComplex64Bytes) got := make(map[complex64]complex64, len(tc.keys)) for k, v := range seq { got[k] = v @@ -3182,7 +3259,7 @@ func TestReadMapBytes_AllRegularTypes_SameKeyValueTypes(t *testing.T) { msg = tc.append(msg, tc.keys[i]) msg = tc.append(msg, tc.vals[i]) } - seq, tail := ReadMapBytes[complex128, complex128](msg, ReadComplex128Bytes, ReadComplex128Bytes) + seq, tail := ReadMapBytes(msg, ReadComplex128Bytes, ReadComplex128Bytes) got := make(map[complex128]complex128, len(tc.keys)) for k, v := range seq { got[k] = v @@ -3216,7 +3293,7 @@ func TestReadMapBytes_TailErrorOnTruncated(t *testing.T) { full := AppendInt(msg, 20) trunc := full[:len(full)-2] // truncate some bytes from the last value - seq, tail := ReadMapBytes[int, int](trunc, ReadIntBytes, ReadIntBytes) + seq, tail := ReadMapBytes(trunc, ReadIntBytes, ReadIntBytes) got := make(map[int]int) for k, v := range seq { got[k] = v From 97372b0f89dca926004026c277ed478d058a08cc Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 3 Sep 2025 12:52:13 +0200 Subject: [PATCH 7/7] Cleanup docs more. --- msgp/iter.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/msgp/iter.go b/msgp/iter.go index f72df24e..3689649e 100644 --- a/msgp/iter.go +++ b/msgp/iter.go @@ -70,7 +70,8 @@ func WriteArray[T any](w *Writer, a []T, writeFn func(T) error) error { // The type parameters K and V specify the types of the keys and values in the map. // The returned iterator implements the iter.Seq2[K, V] interface, // allowing for sequential access to the map elements. -// The returned function can be used to read any error that occurred during iteration when iteration is done. +// The returned function can be used to read any error that +// occurred during iteration when iteration is done. func ReadMap[K, V any](m *Reader, readKey func() (K, error), readVal func() (V, error)) (iter.Seq2[K, V], func() error) { var err error return func(yield func(K, V) bool) { @@ -137,7 +138,7 @@ func WriteMap[K comparable, V any](w *Writer, m map[K]V, writeKey func(K) error, // WriteMapSorted writes a map to the provided Writer. // The keys of the map are sorted before writing. -// This provides deterministic output. +// This provides deterministic output, but will allocate to sort the keys. // The writeKey and writeVal parameters specify the functions // to use to write each key and value of the map. func WriteMapSorted[K cmp.Ordered, V any](w *Writer, m map[K]V, writeKey func(K) error, writeVal func(V) error) error { @@ -214,7 +215,7 @@ func AppendArray[T any](b []byte, a []T, writeFn func(b []byte, v T) []byte) []b // ReadMapBytes returns an iterator over key/value // pairs from a MessagePack map encoded in b. -// The iterator yields K,V pairs and this function also returns +// The iterator yields K,V pairs, and this function also returns // a closure to get the remaining bytes and any error. func ReadMapBytes[K any, V any](b []byte, readK func([]byte) (K, []byte, error), @@ -273,7 +274,8 @@ func AppendMap[K comparable, V any](b []byte, m map[K]V, } // AppendMapSorted writes a map to the provided buffer. -// Keys are sorted before writing. This provides deterministic output. +// Keys are sorted before writing. +// This provides deterministic output, but will allocate to sort the keys. // The writeK and writeV parameters specify the functions to use to write each key and value of the map. // The returned buffer contains the encoded map. // The function panics if the map is larger than math.MaxUint32 elements. @@ -303,6 +305,7 @@ type DecodePtr[T any] interface { // DecoderFrom allows augmenting any type with a DecodeMsg method into a method // that reads from Reader and returns a T. // Provide an instance of T. This value isn't used. +// See ReadArray/ReadMap "struct" examples for usage. func DecoderFrom[T any, PT DecodePtr[T]](r *Reader, _ T) func() (T, error) { return func() (T, error) { var t T @@ -315,13 +318,14 @@ func DecoderFrom[T any, PT DecodePtr[T]](r *Reader, _ T) func() (T, error) { // FlexibleEncoder is a constraint for types where either T or *T implements Encodable type FlexibleEncoder[T any] interface { Encodable - *T // Include *T in the interface + *T } -// EncoderTo allows augmenting any type with a EncodeMsg method into a method -// that writes to Writer on each call. -// Provide an instance of T. This value isn't used.' -func EncoderTo[T any, PT FlexibleEncoder[T]](w *Writer, _ T) func(T) error { +// EncoderTo allows augmenting any type with an EncodeMsg +// method into a method that writes to Writer on each call. +// Provide an instance of T. This value isn't used. +// See ReadArray or ReadMap "struct" examples for usage. +func EncoderTo[T any, _ FlexibleEncoder[T]](w *Writer, _ T) func(T) error { return func(t T) error { // Check if T implements Marshaler if marshaler, ok := any(t).(Encodable); ok { @@ -345,6 +349,7 @@ type UnmarshalPtr[T any] interface { // DecoderFromBytes allows augmenting any type with an UnmarshalMsg // method into a method that reads from []byte and returns a T. // Provide an instance of T. This value isn't used. +// See ReadArrayBytes or ReadMapBytes "struct" examples for usage. func DecoderFromBytes[T any, PT UnmarshalPtr[T]](_ T) func([]byte) (T, []byte, error) { return func(b []byte) (T, []byte, error) { var t T @@ -363,7 +368,8 @@ type FlexibleMarshaler[T any] interface { // EncoderToBytes allows augmenting any type with a MarshalMsg method into a method // that reads from T and returns a []byte. // Provide an instance of T. This value isn't used. -func EncoderToBytes[T any, PT FlexibleMarshaler[T]](_ T) func([]byte, T) []byte { +// See ReadArrayBytes or ReadMapBytes "struct" examples for usage. +func EncoderToBytes[T any, _ FlexibleMarshaler[T]](_ T) func([]byte, T) []byte { return func(b []byte, t T) []byte { // Check if T implements Marshaler if marshaler, ok := any(t).(Marshaler); ok {