Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ReadFrom over-read #109

Merged
merged 2 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions bitset.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ Example use:

As an alternative to BitSets, one should check out the 'big' package,
which provides a (less set-theoretical) view of bitsets.

*/
package bitset

Expand Down Expand Up @@ -434,21 +433,20 @@ func (b *BitSet) NextSet(i uint) (uint, bool) {
// including possibly the current index and up to cap(buffer).
// If the returned slice has len zero, then no more set bits were found
//
// buffer := make([]uint, 256) // this should be reused
// j := uint(0)
// j, buffer = bitmap.NextSetMany(j, buffer)
// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) {
// for k := range buffer {
// do something with buffer[k]
// }
// j += 1
// }
//
// buffer := make([]uint, 256) // this should be reused
// j := uint(0)
// j, buffer = bitmap.NextSetMany(j, buffer)
// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) {
// for k := range buffer {
// do something with buffer[k]
// }
// j += 1
// }
//
// It is possible to retrieve all set bits as follows:
//
// indices := make([]uint, bitmap.Count())
// bitmap.NextSetMany(0, indices)
// indices := make([]uint, bitmap.Count())
// bitmap.NextSetMany(0, indices)
//
// However if bitmap.Count() is large, it might be preferable to
// use several calls to NextSetMany, for performance reasons.
Expand Down Expand Up @@ -932,6 +930,9 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
// Read length first
err := binary.Read(stream, binaryOrder, &length)
if err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset := New(uint(length))
Expand All @@ -940,17 +941,17 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) {
return 0, errors.New("unmarshalling error: type mismatch")
}

// Read remaining bytes as set
// current implementation bufio.Reader is more memory efficient than
// binary.Read for large set
reader := bufio.NewReader(stream)
var item = make([]byte, binary.Size(uint64(0))) // one uint64
nWords := uint64(wordsNeeded(uint(length)))
for i := uint64(0); i < nWords; i++ {
if _, err := reader.Read(item); err != nil {
var item [8]byte
nWords := wordsNeeded(uint(length))
reader := bufio.NewReader(io.LimitReader(stream, 8*int64(nWords)))
for i := 0; i < nWords; i++ {
if _, err := reader.Read(item[:]); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return 0, err
}
newset.set[i] = binaryOrder.Uint64(item)
newset.set[i] = binaryOrder.Uint64(item[:])
}

*b = *newset
Expand Down
154 changes: 153 additions & 1 deletion bitset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@
package bitset

import (
"bytes"
"compress/gzip"
"encoding"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"strconv"
"testing"
Expand Down Expand Up @@ -1222,6 +1228,10 @@ func TestMarshalUnmarshalBinary(t *testing.T) {

func TestMarshalUnmarshalBinaryByLittleEndian(t *testing.T) {
LittleEndian()
defer func() {
// Revert when done.
binaryOrder = binary.BigEndian
}()
a := New(1010).Set(10).Set(1001)
b := new(BitSet)

Expand Down Expand Up @@ -1279,7 +1289,7 @@ func TestMarshalUnmarshalJSONWithTrailingData(t *testing.T) {
}

// appending some noise
data = data[:len(data) - 3] // remove "
data = data[:len(data)-3] // remove "
data = append(data, []byte(`AAAAAAAAAA"`)...)

b := new(BitSet)
Expand Down Expand Up @@ -1644,3 +1654,145 @@ func TestDeleteWithBitSetInstance(t *testing.T) {

}
}

// TestWriteTo serializes a bitset with WriteTo, appends unrelated bytes to the
// same buffer, and then checks that ReadFrom reports the serialized size,
// restores the bits that were set, and leaves the appended bytes unread.
func TestWriteTo(t *testing.T) {
	const (
		length   = 9585
		oneEvery = 97
		// Flip to true to log a gzip+base64 encoding of the stream.
		regenerate = false
	)
	trailer := []byte(`12345678`)

	// Set every oneEvery-th bit.
	src := New(length)
	for pos := uint(0); pos < length; pos += oneEvery {
		src = src.Set(pos)
	}

	var stream bytes.Buffer
	written, err := src.WriteTo(&stream)
	if err != nil {
		t.Fatal(err)
	}
	// Number of bytes the serialized set occupies.
	wantSz := stream.Len()
	if int64(wantSz) != written {
		t.Errorf("want write size to be %d, got %d", wantSz, written)
	}
	// Data following the set on the stream; it must survive ReadFrom.
	stream.Write(trailer)

	if regenerate {
		var packed bytes.Buffer
		zw, err := gzip.NewWriterLevel(&packed, gzip.BestCompression)
		if err != nil {
			t.Fatal(err)
		}
		zw.Write(stream.Bytes())
		zw.Close()
		t.Log("Encoded:", base64.StdEncoding.EncodeToString(packed.Bytes()))
	}

	// Decode into a fresh set from the same buffer.
	dst := New(length)
	consumed, err := dst.ReadFrom(&stream)
	if err != nil {
		t.Fatal(err)
	}
	if int64(wantSz) != consumed {
		t.Errorf("want read size to be %d, got %d", wantSz, consumed)
	}
	for pos := uint(0); pos < length; pos += oneEvery {
		if dst.Test(pos) {
			continue
		}
		t.Errorf("bit %d was not set", pos)
	}

	// Whatever is left in the buffer must be exactly the trailer.
	rest, err := io.ReadAll(&stream)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(rest, trailer) {
		t.Fatalf("extra mismatch. got %v, want %v", rest, trailer)
	}
}

// TestReadFrom feeds pre-recorded streams to ReadFrom. Each fixture is
// base64-decoded and gunzipped before use. A case either expects wantErr, or
// expects every oneEvery-th bit (and no other) to be set and the bytes
// `12345678` to remain unread on the stream afterwards.
func TestReadFrom(t *testing.T) {
	trailer := []byte(`12345678`) // Bytes expected after the set.

	type fixture struct {
		length   uint
		oneEvery uint
		input    string // base64+gzipped
		wantErr  error
	}
	cases := []fixture{
		{
			length:   9585,
			oneEvery: 97,
			input:    "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwMSACVgYGBg4sIgLMDAwKGARd2BgYGjAFB41noDx6IAJajw64IAajw4UoMajg4ZR4/EaP5pQh1g+MDQyNjE1M7cABAAA//9W5OoOwAQAAA==",
		},
		{
			length:   1337,
			oneEvery: 42,
			input:    "H4sIAAAAAAAC/2IAA1ZLBgYWEIPRAUQKgJkMcCZYisEBzkSSYkSTYqCxAYZGxiamZuYWgAAAAP//D0wyWbgAAAA=",
		},
		{
			// Truncated input.
			length:   1337,
			oneEvery: 42,
			input:    "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwARmAQIAAP//vR3xdRkAAAA=",
			wantErr:  io.ErrUnexpectedEOF,
		},
		{
			// Empty input.
			length:   1337,
			oneEvery: 42,
			input:    "H4sIAAAAAAAC/wEAAP//AAAAAAAAAAA=",
			wantErr:  io.ErrUnexpectedEOF,
		},
	}

	for idx, tc := range cases {
		t.Run(fmt.Sprint(idx), func(t *testing.T) {
			// must aborts the subtest on any setup failure.
			must := func(err error) {
				t.Helper()
				if err != nil {
					t.Fatal(err)
				}
			}

			// Unpack the fixture into an in-memory stream.
			raw, err := base64.StdEncoding.DecodeString(tc.input)
			must(err)
			zr, err := gzip.NewReader(bytes.NewBuffer(raw))
			must(err)
			var stream bytes.Buffer
			_, err = io.Copy(&stream, zr)
			must(err)
			must(zr.Close())

			decoded := New(tc.length)
			_, err = decoded.ReadFrom(&stream)
			switch {
			case err != nil && errors.Is(err, tc.wantErr):
				// Correct, nothing more we can test.
				return
			case err != nil:
				t.Fatalf("did not get expected error %v, got %v", tc.wantErr, err)
			case tc.wantErr != nil:
				t.Fatalf("did not get expected error %v", tc.wantErr)
			}

			// Exactly the multiples of oneEvery must be set.
			for pos := uint(0); pos < tc.length; pos++ {
				got := decoded.Test(pos)
				want := pos%tc.oneEvery == 0
				if got != want {
					t.Errorf("bit %d was %v, should be %v", pos, got, want)
				}
			}

			// The trailer must still be on the stream, untouched.
			rest, err := io.ReadAll(&stream)
			must(err)
			if !bytes.Equal(rest, trailer) {
				t.Errorf("extra mismatch. got %v, want %v", rest, trailer)
			}
		})
	}
}