Skip to content

Commit 57b9352

Browse files
mdenushevMax Denushev
and
Max Denushev
authored
fix: propagate body stream error to close function (#1743) (#1757)
* fix: propagate body stream error to close function (#1743) * fix: http test * fix: close body stream with error in encoding functions * fix: lint --------- Co-authored-by: Max Denushev <[email protected]>
1 parent e88bd48 commit 57b9352

File tree

3 files changed

+51
-31
lines changed

3 files changed

+51
-31
lines changed

Diff for: client.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -2975,12 +2975,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
29752975
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
29762976
if customStreamBody && resp.bodyStream != nil {
29772977
rbs := resp.bodyStream
2978-
resp.bodyStream = newCloseReader(rbs, func() error {
2978+
resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
29792979
hc.releaseReader(br)
29802980
if r, ok := rbs.(*requestStream); ok {
29812981
releaseRequestStream(r)
29822982
}
2983-
if closeConn || resp.ConnectionClose() {
2983+
if closeConn || resp.ConnectionClose() || wErr != nil {
29842984
hc.closeConn(cc)
29852985
} else {
29862986
hc.releaseConn(cc)

Diff for: http.go

+48-28
Original file line numberDiff line numberDiff line change
@@ -321,26 +321,31 @@ func (resp *Response) BodyStream() io.Reader {
321321
}
322322

323323
func (resp *Response) CloseBodyStream() error {
324-
return resp.closeBodyStream()
324+
return resp.closeBodyStream(nil)
325+
}
326+
327+
type ReadCloserWithError interface {
328+
io.Reader
329+
CloseWithError(err error) error
325330
}
326331

327332
type closeReader struct {
328333
io.Reader
329-
closeFunc func() error
334+
closeFunc func(err error) error
330335
}
331336

332-
func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
337+
func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError {
333338
if r == nil {
334339
panic(`BUG: reader is nil`)
335340
}
336341
return &closeReader{Reader: r, closeFunc: closeFunc}
337342
}
338343

339-
func (c *closeReader) Close() error {
344+
func (c *closeReader) CloseWithError(err error) error {
340345
if c.closeFunc == nil {
341346
return nil
342347
}
343-
return c.closeFunc()
348+
return c.closeFunc(err)
344349
}
345350

346351
// BodyWriter returns writer for populating request body.
@@ -394,7 +399,7 @@ func (resp *Response) Body() []byte {
394399
bodyBuf := resp.bodyBuffer()
395400
bodyBuf.Reset()
396401
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
397-
resp.closeBodyStream() //nolint:errcheck
402+
resp.closeBodyStream(err) //nolint:errcheck
398403
if err != nil {
399404
bodyBuf.SetString(err.Error())
400405
}
@@ -618,7 +623,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error {
618623
func (resp *Response) BodyWriteTo(w io.Writer) error {
619624
if resp.bodyStream != nil {
620625
_, err := copyZeroAlloc(w, resp.bodyStream)
621-
resp.closeBodyStream() //nolint:errcheck
626+
resp.closeBodyStream(err) //nolint:errcheck
622627
return err
623628
}
624629
_, err := w.Write(resp.bodyBytes())
@@ -629,29 +634,29 @@ func (resp *Response) BodyWriteTo(w io.Writer) error {
629634
//
630635
// It is safe re-using p after the function returns.
631636
func (resp *Response) AppendBody(p []byte) {
632-
resp.closeBodyStream() //nolint:errcheck
637+
resp.closeBodyStream(nil) //nolint:errcheck
633638
resp.bodyBuffer().Write(p) //nolint:errcheck
634639
}
635640

636641
// AppendBodyString appends s to response body.
637642
func (resp *Response) AppendBodyString(s string) {
638-
resp.closeBodyStream() //nolint:errcheck
643+
resp.closeBodyStream(nil) //nolint:errcheck
639644
resp.bodyBuffer().WriteString(s) //nolint:errcheck
640645
}
641646

642647
// SetBody sets response body.
643648
//
644649
// It is safe re-using body argument after the function returns.
645650
func (resp *Response) SetBody(body []byte) {
646-
resp.closeBodyStream() //nolint:errcheck
651+
resp.closeBodyStream(nil) //nolint:errcheck
647652
bodyBuf := resp.bodyBuffer()
648653
bodyBuf.Reset()
649654
bodyBuf.Write(body) //nolint:errcheck
650655
}
651656

652657
// SetBodyString sets response body.
653658
func (resp *Response) SetBodyString(body string) {
654-
resp.closeBodyStream() //nolint:errcheck
659+
resp.closeBodyStream(nil) //nolint:errcheck
655660
bodyBuf := resp.bodyBuffer()
656661
bodyBuf.Reset()
657662
bodyBuf.WriteString(body) //nolint:errcheck
@@ -660,7 +665,7 @@ func (resp *Response) SetBodyString(body string) {
660665
// ResetBody resets response body.
661666
func (resp *Response) ResetBody() {
662667
resp.bodyRaw = nil
663-
resp.closeBodyStream() //nolint:errcheck
668+
resp.closeBodyStream(nil) //nolint:errcheck
664669
if resp.body != nil {
665670
if resp.keepBodyBuffer {
666671
resp.body.Reset()
@@ -700,7 +705,7 @@ func (resp *Response) ReleaseBody(size int) {
700705
return
701706
}
702707
if cap(resp.body.B) > size {
703-
resp.closeBodyStream() //nolint:errcheck
708+
resp.closeBodyStream(nil) //nolint:errcheck
704709
resp.body = nil
705710
}
706711
}
@@ -734,7 +739,7 @@ func (resp *Response) SwapBody(body []byte) []byte {
734739
if resp.bodyStream != nil {
735740
bb.Reset()
736741
_, err := copyZeroAlloc(bb, resp.bodyStream)
737-
resp.closeBodyStream() //nolint:errcheck
742+
resp.closeBodyStream(err) //nolint:errcheck
738743
if err != nil {
739744
bb.Reset()
740745
bb.SetString(err.Error())
@@ -1725,10 +1730,13 @@ func (resp *Response) brotliBody(level int) {
17251730
wf: zw,
17261731
bw: sw,
17271732
}
1728-
copyZeroAlloc(fw, bs) //nolint:errcheck
1733+
_, wErr := copyZeroAlloc(fw, bs)
17291734
releaseStacklessBrotliWriter(zw, level)
1730-
if bsc, ok := bs.(io.Closer); ok {
1731-
bsc.Close()
1735+
switch v := bs.(type) {
1736+
case io.Closer:
1737+
v.Close()
1738+
case ReadCloserWithError:
1739+
v.CloseWithError(wErr) //nolint:errcheck
17321740
}
17331741
})
17341742
} else {
@@ -1780,10 +1788,13 @@ func (resp *Response) gzipBody(level int) {
17801788
wf: zw,
17811789
bw: sw,
17821790
}
1783-
copyZeroAlloc(fw, bs) //nolint:errcheck
1791+
_, wErr := copyZeroAlloc(fw, bs)
17841792
releaseStacklessGzipWriter(zw, level)
1785-
if bsc, ok := bs.(io.Closer); ok {
1786-
bsc.Close()
1793+
switch v := bs.(type) {
1794+
case io.Closer:
1795+
v.Close()
1796+
case ReadCloserWithError:
1797+
v.CloseWithError(wErr) //nolint:errcheck
17871798
}
17881799
})
17891800
} else {
@@ -1835,10 +1846,13 @@ func (resp *Response) deflateBody(level int) {
18351846
wf: zw,
18361847
bw: sw,
18371848
}
1838-
copyZeroAlloc(fw, bs) //nolint:errcheck
1849+
_, wErr := copyZeroAlloc(fw, bs)
18391850
releaseStacklessDeflateWriter(zw, level)
1840-
if bsc, ok := bs.(io.Closer); ok {
1841-
bsc.Close()
1851+
switch v := bs.(type) {
1852+
case io.Closer:
1853+
v.Close()
1854+
case ReadCloserWithError:
1855+
v.CloseWithError(wErr) //nolint:errcheck
18421856
}
18431857
})
18441858
} else {
@@ -1887,10 +1901,13 @@ func (resp *Response) zstdBody(level int) {
18871901
wf: zw,
18881902
bw: sw,
18891903
}
1890-
copyZeroAlloc(fw, bs) //nolint:errcheck
1904+
_, wErr := copyZeroAlloc(fw, bs)
18911905
releaseStacklessZstdWriter(zw, level)
1892-
if bsc, ok := bs.(io.Closer); ok {
1893-
bsc.Close()
1906+
switch v := bs.(type) {
1907+
case io.Closer:
1908+
v.Close()
1909+
case ReadCloserWithError:
1910+
v.CloseWithError(wErr) //nolint:errcheck
18941911
}
18951912
})
18961913
} else {
@@ -2053,7 +2070,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error
20532070
}
20542071
}
20552072
}
2056-
errc := resp.closeBodyStream()
2073+
errc := resp.closeBodyStream(err)
20572074
if err == nil {
20582075
err = errc
20592076
}
@@ -2075,14 +2092,17 @@ func (req *Request) closeBodyStream() error {
20752092
return err
20762093
}
20772094

2078-
func (resp *Response) closeBodyStream() error {
2095+
func (resp *Response) closeBodyStream(wErr error) error {
20792096
if resp.bodyStream == nil {
20802097
return nil
20812098
}
20822099
var err error
20832100
if bsc, ok := resp.bodyStream.(io.Closer); ok {
20842101
err = bsc.Close()
20852102
}
2103+
if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok {
2104+
err = bsc.CloseWithError(wErr)
2105+
}
20862106
if bsr, ok := resp.bodyStream.(*requestStream); ok {
20872107
releaseRequestStream(bsr)
20882108
}

Diff for: http_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2943,7 +2943,7 @@ func TestResponseBodyStream(t *testing.T) {
29432943
t.Fatalf("parse response find err: %v", err)
29442944
}
29452945
defer func() {
2946-
if err := response.closeBodyStream(); err != nil {
2946+
if err := response.closeBodyStream(nil); err != nil {
29472947
t.Fatalf("close body stream err: %v", err)
29482948
}
29492949
}()

0 commit comments

Comments
 (0)