Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: Add configurable Decoder window size #394

Merged
merged 1 commit into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions zstd/blockdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {

// Read block data.
if cap(b.dataStorage) < cSize {
if b.lowMem {
if b.lowMem || cSize > maxCompressedBlockSize {
b.dataStorage = make([]byte, 0, cSize)
} else {
b.dataStorage = make([]byte, 0, maxBlockSize)
b.dataStorage = make([]byte, 0, maxCompressedBlockSize)
}
}
if cap(b.dst) <= maxSize {
Expand Down
25 changes: 22 additions & 3 deletions zstd/decoder_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ type decoderOptions struct {
lowMem bool
concurrent int
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
}

func (o *decoderOptions) setDefault() {
*o = decoderOptions{
// use less ram: true for now, but may change.
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
maxWindowSize: MaxWindowSize,
}
o.maxDecodedSize = 1 << 63
}
Expand Down Expand Up @@ -52,7 +54,6 @@ func WithDecoderConcurrency(n int) DOption {
// WithDecoderMaxMemory allows to set a maximum decoded size for in-memory
// non-streaming operations or maximum window size for streaming operations.
// This can be used to control memory usage of potentially hostile content.
// For streaming operations, the maximum window size is capped at 1<<30 bytes.
// Maximum and default is 1 << 63 bytes.
func WithDecoderMaxMemory(n uint64) DOption {
return func(o *decoderOptions) error {
Expand Down Expand Up @@ -81,3 +82,21 @@ func WithDecoderDicts(dicts ...[]byte) DOption {
return nil
}
}

// WithDecoderMaxWindow allows to set a maximum window size for decodes.
// This allows rejecting packets that will cause big memory usage.
// The Decoder will likely allocate more memory based on the WithDecoderLowmem setting.
// If WithDecoderMaxMemory is set to a lower value, that will be used.
// Default is 512MB, Maximum is ~3.75 TB as per zstandard spec.
func WithDecoderMaxWindow(size uint64) DOption {
return func(o *decoderOptions) error {
if size < MinWindowSize {
return errors.New("WithMaxWindowSize must be at least 1KB, 1024 bytes")
}
if size > (1<<41)+7*(1<<38) {
return errors.New("WithMaxWindowSize must be less than (1<<41) + 7*(1<<38) ~ 3.75TB")
}
o.maxWindowSize = size
return nil
}
}
122 changes: 117 additions & 5 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestNewDecoder(t *testing.T) {
func TestNewDecoderMemory(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(64<<10), WithSingleSegment(false))
enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
if err != nil {
t.Fatal(err)
}
Expand All @@ -200,6 +200,9 @@ func TestNewDecoderMemory(t *testing.T) {
n = 200
}

// 16K buffer
var tmp [16 << 10]byte

var before, after runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&before)
Expand All @@ -214,8 +217,6 @@ func TestNewDecoderMemory(t *testing.T) {
}
}

// 32K buffer
var tmp [128 << 10]byte
for i := range decs {
_, err := io.ReadFull(decs[i], tmp[:])
if err != nil {
Expand All @@ -226,17 +227,128 @@ func TestNewDecoderMemory(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&after)
size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024

const expect = 124
t.Log(size, "KiB per decoder")
// This is not exact science, but fail if we suddenly get more than 2x what we expect.
if size > 221*2 && !testing.Short() {
t.Errorf("expected < 221KB per decoder, got %d", size)
if size > expect*2 && !testing.Short() {
t.Errorf("expected < %dKB per decoder, got %d", expect, size)
}

for _, dec := range decs {
dec.Close()
}
}

func TestNewDecoderMemoryHighMem(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
if err != nil {
t.Fatal(err)
}
// Write 256KB
for i := 0; i < 256; i++ {
tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
_, err := enc.Write([]byte(tmp))
if err != nil {
t.Fatal(err)
}
}
err = enc.Close()
if err != nil {
t.Fatal(err)
}

var n = 50
if testing.Short() {
n = 10
}

// 16K buffer
var tmp [16 << 10]byte

var before, after runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&before)

var decs = make([]*Decoder, n)
for i := range decs {
// Wrap in NopCloser to avoid shortcut.
input := ioutil.NopCloser(bytes.NewBuffer(testdata.Bytes()))
decs[i], err = NewReader(input, WithDecoderConcurrency(1), WithDecoderLowmem(false))
if err != nil {
t.Fatal(err)
}
}

for i := range decs {
_, err := io.ReadFull(decs[i], tmp[:])
if err != nil {
t.Fatal(err)
}
}

runtime.GC()
runtime.ReadMemStats(&after)
size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024

const expect = 3915
t.Log(size, "KiB per decoder")
// This is not exact science, but fail if we suddenly get more than 2x what we expect.
if size > expect*2 && !testing.Short() {
t.Errorf("expected < %dKB per decoder, got %d", expect, size)
}

for _, dec := range decs {
dec.Close()
}
}

func TestNewDecoderFrameSize(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(64<<10))
if err != nil {
t.Fatal(err)
}
// Write 256KB
for i := 0; i < 256; i++ {
tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
_, err := enc.Write([]byte(tmp))
if err != nil {
t.Fatal(err)
}
}
err = enc.Close()
if err != nil {
t.Fatal(err)
}
// Must fail
dec, err := NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(32<<10))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(ioutil.Discard, dec)
if err == nil {
dec.Close()
t.Fatal("Wanted error, got none")
}
dec.Close()

// Must succeed.
dec, err = NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(64<<10))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(ioutil.Discard, dec)
if err != nil {
dec.Close()
t.Fatalf("Wanted no error, got %+v", err)
}
dec.Close()
}

func TestNewDecoderGood(t *testing.T) {
defer timeout(30 * time.Second)()
testDecoderFile(t, "testdata/good.zip")
Expand Down
32 changes: 17 additions & 15 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ type frameDec struct {

WindowSize uint64

// maxWindowSize is the maximum windows size to support.
// should never be bigger than max-int.
maxWindowSize uint64

// In order queue of blocks being decoded.
decoding chan *blockDec

Expand All @@ -50,8 +46,11 @@ type frameDec struct {
}

const (
// The minimum Window_Size is 1 KB.
// MinWindowSize is the minimum Window Size, which is 1 KB.
MinWindowSize = 1 << 10

// MaxWindowSize is the maximum encoder window size
// and the default decoder maximum window size.
MaxWindowSize = 1 << 29
)

Expand All @@ -61,12 +60,11 @@ var (
)

func newFrameDec(o decoderOptions) *frameDec {
d := frameDec{
o: o,
maxWindowSize: MaxWindowSize,
if o.maxWindowSize > o.maxDecodedSize {
o.maxWindowSize = o.maxDecodedSize
}
if d.maxWindowSize > o.maxDecodedSize {
d.maxWindowSize = o.maxDecodedSize
d := frameDec{
o: o,
}
return &d
}
Expand Down Expand Up @@ -251,13 +249,17 @@ func (d *frameDec) reset(br byteBuffer) error {
}
}

if d.WindowSize > d.maxWindowSize {
printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
if d.WindowSize > uint64(d.o.maxWindowSize) {
if debugDecoder {
printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
}
return ErrWindowSizeExceeded
}
// The minimum Window_Size is 1 KB.
if d.WindowSize < MinWindowSize {
println("got window size: ", d.WindowSize)
if debugDecoder {
println("got window size: ", d.WindowSize)
}
return ErrWindowSizeTooSmall
}
d.history.windowSize = int(d.WindowSize)
Expand Down Expand Up @@ -352,8 +354,8 @@ func (d *frameDec) checkCRC() error {

func (d *frameDec) initAsync() {
if !d.o.lowMem && !d.SingleSegment {
// set max extra size history to 10MB.
d.history.maxSize = d.history.windowSize + maxBlockSize*5
// set max extra size history to 2MB.
d.history.maxSize = d.history.windowSize + maxBlockSize
}
// re-alloc if more than one extra block size.
if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
Expand Down