diff --git a/packetio/buffer.go b/packetio/buffer.go index cae8219..9f91b04 100644 --- a/packetio/buffer.go +++ b/packetio/buffer.go @@ -173,14 +173,9 @@ func (b *Buffer) Write(packet []byte) (int, error) { } b.count++ - waiting := b.waiting - b.waiting = false - - if waiting { - select { - case b.notify <- struct{}{}: - default: - } + select { + case b.notify <- struct{}{}: + default: } b.mutex.Unlock() @@ -244,7 +239,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit } b.count-- - b.waiting = false b.mutex.Unlock() if copied < count { @@ -257,8 +251,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit b.mutex.Unlock() return 0, io.EOF } - - b.waiting = true b.mutex.Unlock() select { diff --git a/packetio/buffer_test.go b/packetio/buffer_test.go index 2b8cb03..131f066 100644 --- a/packetio/buffer_test.go +++ b/packetio/buffer_test.go @@ -8,9 +8,11 @@ import ( "fmt" "io" "net" + "sync/atomic" "testing" "time" + "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) @@ -611,6 +613,7 @@ func TestBufferConcurrentRead(t *testing.T) { errCh := make(chan error, 2) readIntoErr := func() { + packet := make([]byte, 4) _, readErr := buffer.Read(packet) errCh <- readErr } @@ -626,3 +629,38 @@ func TestBufferConcurrentRead(t *testing.T) { err = <-errCh assert.Equal(io.EOF, err) } + +func TestBufferConcurrentReadWrite(t *testing.T) { + defer test.TimeOut(time.Second * 5).Stop() + + assert := assert.New(t) + + buffer := NewBuffer() + + numPkts := 1000 + var numRead uint64 + allRead := make(chan struct{}) + readPkts := func(count int) { + packet := make([]byte, 4) + for i := 0; i < count; i++ { + _, readErr := buffer.Read(packet) + if readErr != nil { + return + } + if atomic.AddUint64(&numRead, 1) == uint64(numPkts) { + close(allRead) + return + } + } + } + go readPkts(numPkts) + go readPkts(numPkts / 100) + + for i := 0; i < numPkts; i++ { + _, writeErr := buffer.Write([]byte{2, 3, 4}) + assert.NoError(writeErr) + } + <-allRead + + assert.NoError(buffer.Close()) +}