Skip to content

Commit

Permalink
Merge pull request #3241 from lucas-clemente/fix-stream-cancel-read-race
Browse files Browse the repository at this point in the history
fix race when stream.Read and CancelRead are called concurrently
  • Loading branch information
marten-seemann authored Aug 5, 2021
2 parents b54cc07 + fbc30cd commit be68f7f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 5 deletions.
54 changes: 50 additions & 4 deletions integrationtests/self/cancelation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ var _ = Describe("Stream Cancelations", func() {

// The server accepts a single session, and then opens numStreams unidirectional streams.
// On each of these streams, it (tries to) write PRData.
runServer := func() <-chan int32 {
// When done, it sends the number of canceled streams on the channel.
runServer := func(data []byte) <-chan int32 {
numCanceledStreamsChan := make(chan int32)
var err error
server, err = quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
Expand All @@ -44,7 +45,7 @@ var _ = Describe("Stream Cancelations", func() {
defer wg.Done()
str, err := sess.OpenUniStreamSync(context.Background())
Expect(err).ToNot(HaveOccurred())
if _, err := str.Write(PRData); err != nil {
if _, err := str.Write(data); err != nil {
Expect(err).To(MatchError(&quic.StreamError{
StreamID: str.StreamID(),
ErrorCode: quic.StreamErrorCode(str.StreamID()),
Expand All @@ -70,7 +71,7 @@ var _ = Describe("Stream Cancelations", func() {
})

It("downloads when the client immediately cancels most streams", func() {
serverCanceledCounterChan := runServer()
serverCanceledCounterChan := runServer(PRData)
sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
Expand Down Expand Up @@ -113,7 +114,7 @@ var _ = Describe("Stream Cancelations", func() {
})

It("downloads when the client cancels streams after reading from them for a bit", func() {
serverCanceledCounterChan := runServer()
serverCanceledCounterChan := runServer(PRData)

sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
Expand Down Expand Up @@ -159,6 +160,51 @@ var _ = Describe("Stream Cancelations", func() {
Expect(clientCanceledCounter).To(BeNumerically(">", numStreams/10))
Expect(numStreams - clientCanceledCounter).To(BeNumerically(">", numStreams/10))
})

It("allows concurrent Read and CancelRead calls", func() {
// This test is especially valuable when run with race detector,
// see https://github.com/lucas-clemente/quic-go/issues/3239.
serverCanceledCounterChan := runServer(make([]byte, 100)) // make sure the FIN is sent with the STREAM frame

sess, err := quic.DialAddr(
fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
getTLSClientConfig(),
getQuicConfig(&quic.Config{MaxIncomingUniStreams: numStreams / 2}),
)
Expect(err).ToNot(HaveOccurred())

var wg sync.WaitGroup
wg.Add(numStreams)
var counter int32
for i := 0; i < numStreams; i++ {
go func() {
defer GinkgoRecover()
defer wg.Done()
str, err := sess.AcceptUniStream(context.Background())
Expect(err).ToNot(HaveOccurred())

done := make(chan struct{})
go func() {
defer GinkgoRecover()
defer close(done)
b := make([]byte, 32)
if _, err := str.Read(b); err != nil {
atomic.AddInt32(&counter, 1)
Expect(err.Error()).To(ContainSubstring("canceled with error code 1234"))
return
}
}()
go str.CancelRead(1234)
Eventually(done).Should(BeClosed())
}()
}
wg.Wait()
Expect(sess.CloseWithError(0, "")).To(Succeed())
numCanceled := atomic.LoadInt32(&counter)
fmt.Fprintf(GinkgoWriter, "canceled %d out of %d streams", numCanceled, numStreams)
Expect(numCanceled).ToNot(BeZero())
Eventually(serverCanceledCounterChan).Should(Receive())
})
})

Context("canceling the write side", func() {
Expand Down
5 changes: 4 additions & 1 deletion internal/flowcontrol/stream_flow_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
}

func (c *streamFlowController) Abandon() {
if unread := c.highestReceived - c.bytesRead; unread > 0 {
c.mutex.Lock()
unread := c.highestReceived - c.bytesRead
c.mutex.Unlock()
if unread > 0 {
c.connection.AddBytesRead(unread)
}
}
Expand Down

0 comments on commit be68f7f

Please sign in to comment.