diff --git a/bitset.go b/bitset.go index 8883df8..3829a2b 100644 --- a/bitset.go +++ b/bitset.go @@ -33,7 +33,6 @@ Example use: As an alternative to BitSets, one should check out the 'big' package, which provides a (less set-theoretical) view of bitsets. - */ package bitset @@ -434,21 +433,20 @@ func (b *BitSet) NextSet(i uint) (uint, bool) { // including possibly the current index and up to cap(buffer). // If the returned slice has len zero, then no more set bits were found // -// buffer := make([]uint, 256) // this should be reused -// j := uint(0) -// j, buffer = bitmap.NextSetMany(j, buffer) -// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) { -// for k := range buffer { -// do something with buffer[k] -// } -// j += 1 -// } -// +// buffer := make([]uint, 256) // this should be reused +// j := uint(0) +// j, buffer = bitmap.NextSetMany(j, buffer) +// for ; len(buffer) > 0; j, buffer = bitmap.NextSetMany(j,buffer) { +// for k := range buffer { +// do something with buffer[k] +// } +// j += 1 +// } // // It is possible to retrieve all set bits as follow: // -// indices := make([]uint, bitmap.Count()) -// bitmap.NextSetMany(0, indices) +// indices := make([]uint, bitmap.Count()) +// bitmap.NextSetMany(0, indices) // // However if bitmap.Count() is large, it might be preferable to // use several calls to NextSetMany, for performance reasons. @@ -932,6 +930,9 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) { // Read length first err := binary.Read(stream, binaryOrder, &length) if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return 0, err } newset := New(uint(length)) @@ -940,17 +941,17 @@ func (b *BitSet) ReadFrom(stream io.Reader) (int64, error) { return 0, errors.New("unmarshalling error: type mismatch") } - // Read remaining bytes as set - // current implementation bufio.Reader is more memory efficient than - // binary.Read for large set - reader := bufio.NewReader(stream) - var item = make([]byte, binary.Size(uint64(0))) // one uint64 - nWords := uint64(wordsNeeded(uint(length))) - for i := uint64(0); i < nWords; i++ { - if _, err := reader.Read(item); err != nil { + var item [8]byte + nWords := wordsNeeded(uint(length)) + reader := bufio.NewReader(io.LimitReader(stream, 8*int64(nWords))) + for i := 0; i < nWords; i++ { + if _, err := reader.Read(item[:]); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } return 0, err } - newset.set[i] = binaryOrder.Uint64(item) + newset.set[i] = binaryOrder.Uint64(item[:]) } *b = *newset diff --git a/bitset_test.go b/bitset_test.go index 771d78e..8d86be1 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -7,9 +7,15 @@ package bitset import ( + "bytes" + "compress/gzip" "encoding" + "encoding/base64" + "encoding/binary" "encoding/json" + "errors" "fmt" + "io" "math" "strconv" "testing" @@ -1222,6 +1228,10 @@ func TestMarshalUnmarshalBinary(t *testing.T) { func TestMarshalUnmarshalBinaryByLittleEndian(t *testing.T) { LittleEndian() + defer func() { + // Revert when done. + binaryOrder = binary.BigEndian + }() a := New(1010).Set(10).Set(1001) b := new(BitSet) @@ -1279,7 +1289,7 @@ func TestMarshalUnmarshalJSONWithTrailingData(t *testing.T) { } // appending some noise - data = data[:len(data) - 3] // remove " + data = data[:len(data)-3] // remove " data = append(data, []byte(`AAAAAAAAAA"`)...) b := new(BitSet) @@ -1644,3 +1654,145 @@ func TestDeleteWithBitSetInstance(t *testing.T) { } } + +func TestWriteTo(t *testing.T) { + const length = 9585 + const oneEvery = 97 + addBuf := []byte(`12345678`) + bs := New(length) + // Add some bits + for i := uint(0); i < length; i += oneEvery { + bs = bs.Set(i) + } + + var buf bytes.Buffer + n, err := bs.WriteTo(&buf) + if err != nil { + t.Fatal(err) + } + wantSz := buf.Len() // Size of the serialized data in bytes. + if n != int64(wantSz) { + t.Errorf("want write size to be %d, got %d", wantSz, n) + } + buf.Write(addBuf) // Add additional data on stream. + + // Generate test input for regression tests: + if false { + gzout := bytes.NewBuffer(nil) + gz, err := gzip.NewWriterLevel(gzout, 9) + if err != nil { + t.Fatal(err) + } + gz.Write(buf.Bytes()) + gz.Close() + t.Log("Encoded:", base64.StdEncoding.EncodeToString(gzout.Bytes())) + } + + // Read back. + bs = New(length) + n, err = bs.ReadFrom(&buf) + if err != nil { + t.Fatal(err) + } + if n != int64(wantSz) { + t.Errorf("want read size to be %d, got %d", wantSz, n) + } + // Check bits + for i := uint(0); i < length; i += oneEvery { + if !bs.Test(i) { + t.Errorf("bit %d was not set", i) + } + } + + more, err := io.ReadAll(&buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(more, addBuf) { + t.Fatalf("extra mismatch. got %v, want %v", more, addBuf) + } +} + +func TestReadFrom(t *testing.T) { + addBuf := []byte(`12345678`) // Bytes after stream + tests := []struct { + length uint + oneEvery uint + input string // base64+gzipped + wantErr error + }{ + { + length: 9585, + oneEvery: 97, + input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwMSACVgYGBg4sIgLMDAwKGARd2BgYGjAFB41noDx6IAJajw64IAajw4UoMajg4ZR4/EaP5pQh1g+MDQyNjE1M7cABAAA//9W5OoOwAQAAA==", + }, + { + length: 1337, + oneEvery: 42, + input: "H4sIAAAAAAAC/2IAA1ZLBgYWEIPRAUQKgJkMcCZYisEBzkSSYkSTYqCxAYZGxiamZuYWgAAAAP//D0wyWbgAAAA=", + }, + { + length: 1337, // Truncated input. + oneEvery: 42, + input: "H4sIAAAAAAAC/2IAA9VCCM3AyMDAwARmAQIAAP//vR3xdRkAAAA=", + wantErr: io.ErrUnexpectedEOF, + }, + { + length: 1337, // Empty input. + oneEvery: 42, + input: "H4sIAAAAAAAC/wEAAP//AAAAAAAAAAA=", + wantErr: io.ErrUnexpectedEOF, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + fatalErr := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + var buf bytes.Buffer + b, err := base64.StdEncoding.DecodeString(test.input) + fatalErr(err) + gz, err := gzip.NewReader(bytes.NewBuffer(b)) + fatalErr(err) + _, err = io.Copy(&buf, gz) + fatalErr(err) + fatalErr(gz.Close()) + + bs := New(test.length) + _, err = bs.ReadFrom(&buf) + if err != nil { + if errors.Is(err, test.wantErr) { + // Correct, nothing more we can test. + return + } + t.Fatalf("did not get expected error %v, got %v", test.wantErr, err) + } else { + if test.wantErr != nil { + t.Fatalf("did not get expected error %v", test.wantErr) + } + } + fatalErr(err) + + // Test if correct bits are set. + for i := uint(0); i < test.length; i++ { + want := i%test.oneEvery == 0 + got := bs.Test(i) + if want != got { + t.Errorf("bit %d was %v, should be %v", i, got, want) + } + } + + more, err := io.ReadAll(&buf) + fatalErr(err) + + if !bytes.Equal(more, addBuf) { + t.Errorf("extra mismatch. got %v, want %v", more, addBuf) + } + }) + } +}