From 1ee0185dad51692b26b4ee6f9e111349f0cdb581 Mon Sep 17 00:00:00 2001 From: Spencer Nelson Date: Thu, 5 Oct 2017 15:35:48 -0400 Subject: [PATCH] Check expected invariants while unmarshaling We expect some things of our digests: - The centroids should be in order of increasing mean. - The digest's countTotal field should be the sum of the counts of the centroids (overflowing is forbidden). - Centroids should have non-negative counts. Fuzzing detected cases where these are not true. --- fuzz.go | 29 ++++++++++++++++++++++++-- fuzz_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++----- serde.go | 19 ++++++++++++++--- 3 files changed, 96 insertions(+), 10 deletions(-) diff --git a/fuzz.go b/fuzz.go index df96b8c..59dc5e0 100644 --- a/fuzz.go +++ b/fuzz.go @@ -2,16 +2,41 @@ package tdigest +import ( + "bytes" + "fmt" + "log" + + "github.com/davecgh/go-spew/spew" +) + func Fuzz(data []byte) int { v := new(TDigest) err := v.UnmarshalBinary(data) if err != nil { return 0 } - _, err = v.MarshalBinary() + + remarshaled, err := v.MarshalBinary() if err != nil { panic(err) } - return 1 + if !bytes.HasPrefix(data, remarshaled) { + panic(fmt.Sprintf("not equal: \n%v\nvs\n%v", data, remarshaled)) + } + + for q := float64(0.1); q <= 1.0; q += 0.05 { + prev, this := v.Quantile(q-0.1), v.Quantile(q) + if prev-this > 1e-100 { // Floating point math makes this slightly imprecise. + log.Printf("v: %s", spew.Sprint(v)) + log.Printf("q: %v", q) + log.Printf("prev: %v", prev) + log.Printf("this: %v", this) + panic("quantiles should only increase") + } + } + + v.Add(1, 1) + return 1 } diff --git a/fuzz_test.go b/fuzz_test.go index 9c92f05..45ddf9b 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -1,22 +1,44 @@ package tdigest -import "testing" +import ( + "bytes" + "testing" + + "github.com/davecgh/go-spew/spew" +) func TestFuzzPanicRegressions(t *testing.T) { // This test contains a list of byte sequences discovered by // github.com/dvyukov/go-fuzz which, at one time, caused tdigest to panic. The // test just makes sure that they no longer cause a panic. testcase := func(crasher []byte) func(*testing.T) { - return func(*testing.T) { + return func(t *testing.T) { v := new(TDigest) err := v.UnmarshalBinary(crasher) if err != nil { return } - _, err = v.MarshalBinary() + remarshaled, err := v.MarshalBinary() if err != nil { - panic(err) + t.Fatalf("marshal error: %v", err) + } + + if !bytes.HasPrefix(crasher, remarshaled) { + t.Fatalf("not equal: \n%v\nvs\n%v", crasher, remarshaled) + } + + for q := float64(0.1); q <= 1.0; q += 0.05 { + prev, this := v.Quantile(q-0.1), v.Quantile(q) + if prev-this > 1e-100 { // Floating point math makes this slightly imprecise. + t.Logf("v: %s", spew.Sprint(v)) + t.Logf("q: %v", q) + t.Logf("prev: %v", prev) + t.Logf("this: %v", this) + t.Fatal("quantiles should only increase") + } } + + v.Add(1, 1) } } t.Run("fuzz1", testcase([]byte{ @@ -32,5 +54,31 @@ func TestFuzzPanicRegressions(t *testing.T) { 0x37, 0x35, 0x37, 0x36, 0x37, 0x37, 0x37, 0x38, 0x37, 0x39, 0x28, })) - + t.Run("fuzz3", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0xbf, + })) + t.Run("fuzz4", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x63, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x4e, + })) + t.Run("fuzz5", testcase([]byte{ + 0x80, 0x0c, 0x01, 0x00, 0x00, 0x00, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x02, 0x00, + 0x00, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x30, 0x00, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, + 0x92, 0x00, + })) } diff --git a/serde.go b/serde.go index 3b32679..a8f3d8e 100644 --- a/serde.go +++ b/serde.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "io" + "math" ) const ( @@ -39,11 +40,11 @@ func unmarshalBinary(d *TDigest, p []byte) error { r := &binaryReader{r: bytes.NewReader(p)} r.readValue(&mv) if mv != magic { - return fmt.Errorf("invalid header magic value, data might be corrupted: %d", mv) + return fmt.Errorf("data corruption detected: invalid header magic value %d", mv) } r.readValue(&ev) if ev != encodingVersion { - return fmt.Errorf("invalid encoding version: %d", ev) + return fmt.Errorf("data corruption detected: invalid encoding version %d", ev) } r.readValue(&d.compression) r.readValue(&n) @@ -51,7 +52,7 @@ func unmarshalBinary(d *TDigest, p []byte) error { return r.err } if n < 0 { - return fmt.Errorf("invalid n, cannot be negative: %v", n) + return fmt.Errorf("data corruption detected: number of centroids cannot be negative, have %v", n) } if n > 1<<20 { return fmt.Errorf("invalid n, cannot be greater than 2^20: %v", n) @@ -64,7 +65,19 @@ func unmarshalBinary(d *TDigest, p []byte) error { if r.err != nil { return r.err } + if c.count < 0 { + return fmt.Errorf("data corruption detected: negative count: %d", c.count) + } + if i > 0 { + prev := d.centroids[i-1] + if c.mean < prev.mean { + return fmt.Errorf("data corruption detected: centroid %d has lower mean (%v) than preceding centroid %d (%v)", i, c.mean, i-1, prev.mean) + } + } d.centroids[i] = c + if c.count > math.MaxInt64-d.countTotal { + return fmt.Errorf("data corruption detected: centroid total size overflow") + } d.countTotal += c.count }