Skip to content

Commit

Permalink
Remove buffer waiting condition
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Jul 23, 2024
1 parent 3a1ddc0 commit e3c5398
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
14 changes: 3 additions & 11 deletions packetio/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,9 @@ func (b *Buffer) Write(packet []byte) (int, error) {
}
b.count++

waiting := b.waiting
b.waiting = false

if waiting {
select {
case b.notify <- struct{}{}:
default:
}
select {
case b.notify <- struct{}{}:
default:
}
b.mutex.Unlock()

Expand Down Expand Up @@ -244,7 +239,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
}

b.count--
b.waiting = false
b.mutex.Unlock()

if copied < count {
Expand All @@ -257,8 +251,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
b.mutex.Unlock()
return 0, io.EOF
}

b.waiting = true
b.mutex.Unlock()

select {
Expand Down
38 changes: 38 additions & 0 deletions packetio/buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"

"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -611,6 +613,7 @@ func TestBufferConcurrentRead(t *testing.T) {

errCh := make(chan error, 2)
readIntoErr := func() {
packet := make([]byte, 4)
_, readErr := buffer.Read(packet)
errCh <- readErr
}
Expand All @@ -626,3 +629,38 @@ func TestBufferConcurrentRead(t *testing.T) {
err = <-errCh
assert.Equal(io.EOF, err)
}

func TestBufferConcurrentReadWrite(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

assert := assert.New(t)

buffer := NewBuffer()

numPkts := 1000
var numRead uint64
allRead := make(chan struct{})
readPkts := func(count int) {
packet := make([]byte, 4)
for i := 0; i < count; i++ {
_, readErr := buffer.Read(packet)
if readErr != nil {
return
}
if atomic.AddUint64(&numRead, 1) == uint64(numPkts) {
close(allRead)
return
}
}
}
go readPkts(numPkts)
go readPkts(numPkts / 100)

for i := 0; i < numPkts; i++ {
_, writeErr := buffer.Write([]byte{2, 3, 4})
assert.NoError(writeErr)
}
<-allRead

assert.NoError(buffer.Close())
}

0 comments on commit e3c5398

Please sign in to comment.