From 2719b1a9a8d0d8f3eb78162da9b23a9eaaabda1b Mon Sep 17 00:00:00 2001 From: Gregory Man Date: Thu, 20 Sep 2018 11:48:17 +0300 Subject: [PATCH] compress/flate: return error on closed stream write Previously flate.Writer allowed writes after Close, and this behavior could lead to stream corruption. Fixes #27741 Change-Id: Iee1ac69f8199232f693dba77b275f7078257b582 Reviewed-on: https://go-review.googlesource.com/c/go/+/136475 Run-TryBot: Ian Lance Taylor TryBot-Result: Gopher Robot Auto-Submit: Ian Lance Taylor Reviewed-by: Joseph Tsai Reviewed-by: Carlos Amedee Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor --- src/compress/flate/deflate.go | 36 +++++++++++- src/compress/flate/deflate_test.go | 88 +++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 4 deletions(-) diff --git a/src/compress/flate/deflate.go b/src/compress/flate/deflate.go index 550032176d76c..ccf03d74eb211 100644 --- a/src/compress/flate/deflate.go +++ b/src/compress/flate/deflate.go @@ -5,6 +5,7 @@ package flate import ( + "errors" "fmt" "io" "math" @@ -699,17 +700,27 @@ func (w *dictWriter) Write(b []byte) (n int, err error) { return w.w.Write(b) } +var errWriteAfterClose = errors.New("compress/flate: write after close") + // A Writer takes data written to it and writes the compressed // form of that data to an underlying writer (see NewWriter). type Writer struct { d compressor dict []byte + err error } // Write writes data to w, which will eventually write the // compressed form of data to its underlying writer. func (w *Writer) Write(data []byte) (n int, err error) { - return w.d.write(data) + if w.err != nil { + return 0, w.err + } + n, err = w.d.write(data) + if err != nil { + w.err = err + } + return n, err } // Flush flushes any pending data to the underlying writer. @@ -724,18 +735,37 @@ func (w *Writer) Write(data []byte) (n int, err error) { func (w *Writer) Flush() error { // For more about flushing: // https://www.bolet.org/~pornin/deflate-flush.html - return w.d.syncFlush() + if w.err != nil { + return w.err + } + if err := w.d.syncFlush(); err != nil { + w.err = err + return err + } + return nil } // Close flushes and closes the writer. func (w *Writer) Close() error { - return w.d.close() + if w.err == errWriteAfterClose { + return nil + } + if w.err != nil { + return w.err + } + if err := w.d.close(); err != nil { + w.err = err + return err + } + w.err = errWriteAfterClose + return nil } // Reset discards the writer's state and makes it equivalent to // the result of NewWriter or NewWriterDict called with dst // and w's level and dictionary. func (w *Writer) Reset(dst io.Writer) { + w.err = nil if dw, ok := w.d.w.writer.(*dictWriter); ok { // w was created with NewWriterDict dw.w = dst diff --git a/src/compress/flate/deflate_test.go b/src/compress/flate/deflate_test.go index ff567121236d2..8c9ee72feb47e 100644 --- a/src/compress/flate/deflate_test.go +++ b/src/compress/flate/deflate_test.go @@ -125,6 +125,40 @@ func TestDeflate(t *testing.T) { } } +func TestWriterClose(t *testing.T) { + b := new(bytes.Buffer) + zw, err := NewWriter(b, 6) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + + if c, err := zw.Write([]byte("Test")); err != nil || c != 4 { + t.Fatalf("Write to not closed writer: %s, %d", err, c) + } + + if err := zw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + afterClose := b.Len() + + if c, err := zw.Write([]byte("Test")); err == nil || c != 0 { + t.Fatalf("Write to closed writer: %s, %d", err, c) + } + + if err := zw.Flush(); err == nil { + t.Fatalf("Flush to closed writer: %s", err) + } + + if err := zw.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + if afterClose != b.Len() { + t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len()) + } +} + // A sparseReader returns a stream consisting of 0s followed by 1<<16 1s. // This tests missing hash references in a very large input. type sparseReader struct { @@ -683,7 +717,7 @@ func (w *failWriter) Write(b []byte) (int, error) { return len(b), nil } -func TestWriterPersistentError(t *testing.T) { +func TestWriterPersistentWriteError(t *testing.T) { t.Parallel() d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt") if err != nil { @@ -706,12 +740,16 @@ func TestWriterPersistentError(t *testing.T) { _, werr := zw.Write(d) cerr := zw.Close() + ferr := zw.Flush() if werr != errIO && werr != nil { t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO) } if cerr != errIO && fw.n < 0 { t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO) } + if ferr != errIO && fw.n < 0 { + t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO) + } if fw.n >= 0 { // At this point, the failure threshold was sufficiently high enough // that we wrote the whole stream without any errors. @@ -719,6 +757,54 @@ func TestWriterPersistentError(t *testing.T) { } } } +func TestWriterPersistentFlushError(t *testing.T) { + zw, err := NewWriter(&failWriter{0}, DefaultCompression) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + flushErr := zw.Flush() + closeErr := zw.Close() + _, writeErr := zw.Write([]byte("Test")) + checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t) +} + +func TestWriterPersistentCloseError(t *testing.T) { + // If underlying writer return error on closing stream we should persistent this error across all writer calls. + zw, err := NewWriter(&failWriter{0}, DefaultCompression) + if err != nil { + t.Fatalf("NewWriter: %v", err) + } + closeErr := zw.Close() + flushErr := zw.Flush() + _, writeErr := zw.Write([]byte("Test")) + checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t) + + // After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil + // on next Close calls. + var b bytes.Buffer + zw.Reset(&b) + err = zw.Close() + if err != nil { + t.Fatalf("First call to close returned error: %s", err) + } + err = zw.Close() + if err != nil { + t.Fatalf("Second call to close returned error: %s", err) + } + + flushErr = zw.Flush() + _, writeErr = zw.Write([]byte("Test")) + checkErrors([]error{flushErr, writeErr}, errWriteAfterClose, t) +} + +func checkErrors(got []error, want error, t *testing.T) { + t.Helper() + for _, err := range got { + if err != want { + t.Errorf("Errors dosn't match\nWant: %s\nGot: %s", want, got) + } + } +} func TestBestSpeedMatch(t *testing.T) { t.Parallel()