From 473d6b8326d73dc58b1dda4f761c7d2f088ce216 Mon Sep 17 00:00:00 2001 From: Thanh Pham Kieu Date: Thu, 8 Sep 2022 00:16:27 +0700 Subject: [PATCH] fix over-reading while de-serializing --- bitset.go | 5 +++-- bitset_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/bitset.go b/bitset.go index 3f274d7..3aab21b 100644 --- a/bitset.go +++ b/bitset.go @@ -945,10 +945,11 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) { // binary.Read for large set reader := bufio.NewReader(stream) var item = make([]byte, binary.Size(uint64(0))) // one uint64 - for i := uint64(0); i < length; i++ { + nWords := uint64(wordsNeeded(uint(length))) + for i := uint64(0); i < nWords; i++ { if _, err := reader.Read(item); err != nil { if err == io.EOF { - break // done + return 0, io.ErrUnexpectedEOF } return 0, err } diff --git a/bitset_test.go b/bitset_test.go index b8eee43..771d78e 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -1270,6 +1270,33 @@ func TestMarshalUnmarshalJSON(t *testing.T) { } } +func TestMarshalUnmarshalJSONWithTrailingData(t *testing.T) { + a := New(1010).Set(10).Set(1001) + data, err := json.Marshal(a) + if err != nil { + t.Errorf(err.Error()) + return + } + + // appending some noise + data = data[:len(data) - 3] // remove " + data = append(data, []byte(`AAAAAAAAAA"`)...) + + b := new(BitSet) + err = json.Unmarshal(data, b) + if err != nil { + t.Errorf(err.Error()) + return + } + + // Bitsets must be equal after marshalling and unmarshalling + // Do not over-reading when unmarshalling + if !a.Equal(b) { + t.Error("Bitsets are not equal:\n\t", a.DumpAsBits(), "\n\t", b.DumpAsBits()) + return + } +} + func TestMarshalUnmarshalJSONByStdEncoding(t *testing.T) { Base64StdEncoding() a := New(1010).Set(10).Set(1001)