diff --git a/stream.go b/stream.go index 7ac6d7d6..39768029 100644 --- a/stream.go +++ b/stream.go @@ -171,14 +171,20 @@ func NewEncoder(w io.Writer) *Encoder { // Encode writes the CBOR encoding of v. func (enc *Encoder) Encode(v any) error { - if len(enc.indefTypes) > 0 && v != nil { - indefType := enc.indefTypes[len(enc.indefTypes)-1] - if indefType == cborTypeTextString { + if len(enc.indefTypes) > 0 { + switch enc.indefTypes[len(enc.indefTypes)-1] { + case cborTypeTextString: + if v == nil { + return errors.New("cbor: cannot encode nil for indefinite-length text string") + } k := reflect.TypeOf(v).Kind() if k != reflect.String { return errors.New("cbor: cannot encode item type " + k.String() + " for indefinite-length text string") } - } else if indefType == cborTypeByteString { + case cborTypeByteString: + if v == nil { + return errors.New("cbor: cannot encode nil for indefinite-length byte string") + } t := reflect.TypeOf(v) k := t.Kind() if (k != reflect.Array && k != reflect.Slice) || t.Elem().Kind() != reflect.Uint8 { @@ -219,7 +225,7 @@ func (enc *Encoder) StartIndefiniteArray() error { return enc.startIndefinite(cborTypeArray) } -// StartIndefiniteMap starts array encoding of indefinite length. +// StartIndefiniteMap starts map encoding of indefinite length. // Subsequent calls of (*Encoder).Encode() encodes elements of the map // until EndIndefinite is called. func (enc *Encoder) StartIndefiniteMap() error { diff --git a/stream_test.go b/stream_test.go index 7c9b8c35..ddfc5fcb 100644 --- a/stream_test.go +++ b/stream_test.go @@ -883,6 +883,19 @@ func TestIndefiniteTextString(t *testing.T) { } } +func TestIndefiniteByteStringNilError(t *testing.T) { + var w bytes.Buffer + encoder := NewEncoder(&w) + if err := encoder.StartIndefiniteByteString(); err != nil { + t.Fatalf("StartIndefiniteByteString() returned error %v", err) + } + if err := encoder.Encode(nil); err == nil { + t.Errorf("Encode() didn't return an error") + } else if err.Error() != "cbor: cannot encode nil for indefinite-length byte string" { + t.Errorf("Encode() returned error %q, want %q", err.Error(), "cbor: cannot encode nil for indefinite-length byte string") + } +} + func TestIndefiniteTextStringError(t *testing.T) { var w bytes.Buffer encoder := NewEncoder(&w) @@ -894,6 +907,24 @@ func TestIndefiniteTextStringError(t *testing.T) { } else if err.Error() != "cbor: cannot encode item type slice for indefinite-length text string" { t.Errorf("Encode() returned error %q, want %q", err.Error(), "cbor: cannot encode item type slice for indefinite-length text string") } + if err := encoder.Encode(123); err == nil { + t.Errorf("Encode() didn't return an error") + } else if err.Error() != "cbor: cannot encode item type int for indefinite-length text string" { + t.Errorf("Encode() returned error %q, want %q", err.Error(), "cbor: cannot encode item type int for indefinite-length text string") + } +} + +func TestIndefiniteTextStringNilError(t *testing.T) { + var w bytes.Buffer + encoder := NewEncoder(&w) + if err := encoder.StartIndefiniteTextString(); err != nil { + t.Fatalf("StartIndefiniteTextString() returned error %v", err) + } + if err := encoder.Encode(nil); err == nil { + t.Errorf("Encode() didn't return an error") + } else if err.Error() != "cbor: cannot encode nil for indefinite-length text string" { + t.Errorf("Encode() returned error %q, want %q", err.Error(), "cbor: cannot encode nil for indefinite-length text string") + } } func TestIndefiniteArray(t *testing.T) { @@ -929,6 +960,41 @@ func TestIndefiniteArray(t *testing.T) { } } +func TestIndefiniteArrayWithNilElement(t *testing.T) { + want := mustHexDecode("9f01f6ff") // [1, null] + var w bytes.Buffer + encoder := NewEncoder(&w) + if err := encoder.StartIndefiniteArray(); err != nil { + t.Fatalf("StartIndefiniteArray() returned error %v", err) + } + if err := encoder.Encode(1); err != nil { + t.Fatalf("Encode() returned error %v", err) + } + if err := encoder.Encode(nil); err != nil { + t.Fatalf("Encode() returned error %v", err) + } + if err := encoder.EndIndefinite(); err != nil { + t.Fatalf("EndIndefinite() returned error %v", err) + } + if !bytes.Equal(w.Bytes(), want) { + t.Errorf("Encoding mismatch: got %v, want %v", w.Bytes(), want) + } + + var decoded []any + if err := Unmarshal(w.Bytes(), &decoded); err != nil { + t.Fatalf("Unmarshal() returned error %v", err) + } + if len(decoded) != 2 { + t.Fatalf("Unmarshal() returned %d elements, want 2", len(decoded)) + } + if decoded[0] != uint64(1) { + t.Errorf("decoded[0] = %v (%T), want uint64(1)", decoded[0], decoded[0]) + } + if decoded[1] != nil { + t.Errorf("decoded[1] = %v (%T), want nil", decoded[1], decoded[1]) + } +} + func TestIndefiniteMap(t *testing.T) { want := mustHexDecode("bf61610161629f0203ffff") var w bytes.Buffer @@ -969,6 +1035,42 @@ func TestIndefiniteMap(t *testing.T) { } } +func TestIndefiniteMapWithNilElement(t *testing.T) { + want := mustHexDecode("bf6161f6ff") // {"a": null} + var w bytes.Buffer + encoder := NewEncoder(&w) + if err := encoder.StartIndefiniteMap(); err != nil { + t.Fatalf("StartIndefiniteMap() returned error %v", err) + } + if err := encoder.Encode("a"); err != nil { + t.Fatalf("Encode() returned error %v", err) + } + if err := encoder.Encode(nil); err != nil { + t.Fatalf("Encode() returned error %v", err) + } + if err := encoder.EndIndefinite(); err != nil { + t.Fatalf("EndIndefinite() returned error %v", err) + } + if !bytes.Equal(w.Bytes(), want) { + t.Errorf("Encoding mismatch: got %v, want %v", w.Bytes(), want) + } + + var decoded map[string]any + if err := Unmarshal(w.Bytes(), &decoded); err != nil { + t.Fatalf("Unmarshal() returned error %v", err) + } + if len(decoded) != 1 { + t.Fatalf("Unmarshal() returned %d entries, want 1", len(decoded)) + } + v, ok := decoded["a"] + if !ok { + t.Fatalf("decoded map missing key \"a\"") + } + if v != nil { + t.Errorf("decoded[\"a\"] = %v (%T), want nil", v, v) + } +} + func TestIndefiniteLengthError(t *testing.T) { var w bytes.Buffer encoder := NewEncoder(&w)