Skip to content

Commit

Permalink
add support for sending error codes on stream reset
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Aug 27, 2024
1 parent 8adb9a8 commit 945e586
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 7 deletions.
21 changes: 21 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool {
return false
}

// A StreamError is used for errors returned from Read and Write calls after the stream is Reset
type StreamError struct {
ErrorCode uint32
Remote bool
}

func (s *StreamError) Error() string {
if s.Remote {
return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode)

Check warning on line 68 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L66-L68

Added lines #L66 - L68 were not covered by tests
}
return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode)

Check warning on line 70 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L70

Added line #L70 was not covered by tests
}

func (s *StreamError) Is(target error) bool {
if target == ErrStreamReset {
return true
}
e, ok := target.(*StreamError)
return ok && *e == *s

Check warning on line 78 in const.go

View check run for this annotation

Codecov / codecov/patch

const.go#L77-L78

Added lines #L77 - L78 were not covered by tests
}

var (
// ErrInvalidVersion means we received a frame with an
// invalid version
Expand Down
53 changes: 52 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) {
wc.Wait()
}

func TestStreamResetWithError(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

wc := new(sync.WaitGroup)
wc.Add(2)
go func() {
defer wc.Done()
stream, err := server.AcceptStream()
if err != nil {
t.Error(err)
}

se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: true, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

stream, err := client.OpenStream(context.Background())
if err != nil {
t.Error(err)
}

go func() {
defer wc.Done()

se := &StreamError{}
_, err = io.ReadAll(stream)
if !errors.As(err, &se) {
t.Errorf("exptected StreamError, got type:%T, err: %s", err, err)
return
}
expected := &StreamError{Remote: false, ErrorCode: 42}
assert.Equal(t, se, expected)
}()

time.Sleep(1 * time.Second)
err = stream.ResetWithError(42)
if err != nil {
t.Fatal(err)
}
wc.Wait()
}

func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
config := testConf()
config.EnableKeepAlive = false
Expand Down Expand Up @@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) {
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.EqualError(t, err, "stream reset")
require.ErrorIs(t, err, ErrStreamReset)

// Now close one of the streams.
// This should then allow the client to open a new stream.
Expand Down
26 changes: 20 additions & 6 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Stream struct {
state streamState
writeState, readState halfStreamState
stateLock sync.Mutex
resetErr *StreamError

recvBuf segmentedBuffer

Expand Down Expand Up @@ -101,7 +102,7 @@ START:
}
// Closed, but we have data pending -> read.
case halfReset:
return 0, ErrStreamReset
return 0, s.resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -155,7 +156,7 @@ START:
case halfClosed:
return 0, ErrStreamClosed
case halfReset:
return 0, ErrStreamReset
return 0, s.resetErr
default:
panic("unknown state")
}
Expand Down Expand Up @@ -251,12 +252,20 @@ func (s *Stream) sendClose() error {

// sendReset is used to send a RST
func (s *Stream) sendReset() error {
hdr := encode(typeWindowUpdate, flagRST, s.id, 0)
code := uint32(0)
if s.resetErr != nil {
code = s.resetErr.ErrorCode
}
hdr := encode(typeWindowUpdate, flagRST, s.id, code)
return s.session.sendMsg(hdr, nil, nil)
}

// Reset resets the stream (forcibly closes the stream)
func (s *Stream) Reset() error {
return s.ResetWithError(0)
}

func (s *Stream) ResetWithError(errCode uint32) error {
sendReset := false
s.stateLock.Lock()
switch s.state {
Expand All @@ -281,6 +290,7 @@ func (s *Stream) Reset() error {
s.readState = halfReset
}
s.state = streamFinished
s.resetErr = &StreamError{Remote: false, ErrorCode: errCode}
s.notifyWaiting()
s.stateLock.Unlock()
if sendReset {
Expand Down Expand Up @@ -382,7 +392,7 @@ func (s *Stream) cleanup() {

// processFlags is used to update the state of the stream
// based on set flags, if any. Lock must be held
func (s *Stream) processFlags(flags uint16) {
func (s *Stream) processFlags(flags uint16, hdr header) {
// Close the stream without holding the state lock
var closeStream bool
defer func() {
Expand Down Expand Up @@ -425,6 +435,10 @@ func (s *Stream) processFlags(flags uint16) {
s.writeState = halfReset
}
s.state = streamFinished
// Length in a window update frame with RST flag encodes an error code.
if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil {
s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()}
}
s.stateLock.Unlock()
closeStream = true
s.notifyWaiting()
Expand All @@ -439,15 +453,15 @@ func (s *Stream) notifyWaiting() {

// incrSendWindow updates the size of our send window
func (s *Stream) incrSendWindow(hdr header, flags uint16) {
s.processFlags(flags)
s.processFlags(flags, hdr)
// Increase window, unblock a sender
atomic.AddUint32(&s.sendWindow, hdr.Length())
asyncNotify(s.sendNotifyCh)
}

// readData is used to handle a data frame
func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
s.processFlags(flags)
s.processFlags(flags, hdr)

// Check that our recv window is not exceeded
length := hdr.Length()
Expand Down

0 comments on commit 945e586

Please sign in to comment.