Skip to content

Commit

Permalink
Add batch read/write for udp listener
Browse files Browse the repository at this point in the history
Add batch read/write for udp listener
  • Loading branch information
cnderrauber committed Aug 31, 2023
1 parent 6890c79 commit 46a4b2c
Show file tree
Hide file tree
Showing 3 changed files with 314 additions and 8 deletions.
162 changes: 162 additions & 0 deletions udp/batchconn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package udp

import (
"net"
"runtime"
"sync"
"sync/atomic"
"time"

"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)

// BatchWriter represents conn can write messages in batch
type BatchWriter interface {
WriteBatch(ms []ipv4.Message, flags int) (int, error)
}

// BatchReader represents conn can read messages in batch
type BatchReader interface {
ReadBatch(msg []ipv4.Message, flags int) (int, error)
}

// BatchPacketConn represents conn can read/write messages in batch
type BatchPacketConn interface {
BatchWriter
BatchReader
}

// BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch,
// only available in linux. In other platform, it will use single Write/Read as same as net.Conn.
type BatchConn struct {
net.PacketConn

batchConn BatchPacketConn

batchWriteMutex sync.Mutex
batchWriteMessages []ipv4.Message
batchWritePos int
batchWriteLast time.Time

batchWriteSize int
batchWriteInterval time.Duration

closed atomic.Bool
}

// NewBatchConn creates a *BatchCon from net.PacketConn with batch configs.
func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn {
bc := &BatchConn{
PacketConn: conn,
batchWriteLast: time.Now(),
batchWriteInterval: batchWriteInterval,
batchWriteSize: batchWriteSize,
batchWriteMessages: make([]ipv4.Message, batchWriteSize),
}
for i := range bc.batchWriteMessages {
bc.batchWriteMessages[i].Buffers = [][]byte{make([]byte, sendMTU)}
}

// batch write only supports linux
if runtime.GOOS == "linux" {
if pc4 := ipv4.NewPacketConn(conn); pc4 != nil {
bc.batchConn = pc4
} else if pc6 := ipv6.NewPacketConn(conn); pc6 != nil {
bc.batchConn = pc6
}
}

if bc.batchConn != nil {
go func() {
writeTicker := time.NewTicker(batchWriteInterval / 2)
defer writeTicker.Stop()

for !bc.closed.Load() {
<-writeTicker.C
bc.batchWriteMutex.Lock()
if bc.batchWritePos > 0 && time.Since(bc.batchWriteLast) >= bc.batchWriteInterval {
_ = bc.flush()
}
bc.batchWriteMutex.Unlock()
}
}()
}

return bc
}

// Close batchConn and the underlying PacketConn
func (c *BatchConn) Close() error {
c.closed.Store(true)
return c.PacketConn.Close()
}

// WriteTo write message to an UDPAddr, addr should be nil if it is a connected socket.
func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if c.batchConn == nil {
return c.PacketConn.WriteTo(b, addr)
}
return c.writeBatch(b, addr)
}

func (c *BatchConn) writeBatch(buf []byte, raddr net.Addr) (int, error) {
var err error
c.batchWriteMutex.Lock()
defer c.batchWriteMutex.Unlock()

// c.writeCounter++
msg := &c.batchWriteMessages[c.batchWritePos]
// reset buffers
msg.Buffers = msg.Buffers[:1]
msg.Buffers[0] = msg.Buffers[0][:cap(msg.Buffers[0])]

c.batchWritePos++
if raddr != nil {
msg.Addr = raddr
}
if n := copy(msg.Buffers[0], buf); n < len(buf) {
extraBuffer := make([]byte, len(buf)-n)
copy(extraBuffer, buf[n:])
msg.Buffers = append(msg.Buffers, extraBuffer)
} else {
msg.Buffers[0] = msg.Buffers[0][:n]
}
if c.batchWritePos == c.batchWriteSize {
err = c.flush()
}
return len(buf), err
}

// ReadBatch reads messages in batch, return length of message readed and error.
func (c *BatchConn) ReadBatch(msgs []ipv4.Message, flags int) (int, error) {
if c.batchConn == nil {
n, addr, err := c.PacketConn.ReadFrom(msgs[0].Buffers[0])
if err == nil {
msgs[0].N = n
msgs[0].Addr = addr
return 1, nil
}
return 0, err
}
return c.batchConn.ReadBatch(msgs, flags)
}

func (c *BatchConn) flush() error {
var writeErr error
var txN int
for txN < c.batchWritePos {
n, err := c.batchConn.WriteBatch(c.batchWriteMessages[txN:c.batchWritePos], 0)
if err != nil {
writeErr = err
break
}
txN += n
}
c.batchWritePos = 0
c.batchWriteLast = time.Now()
return writeErr
}
90 changes: 82 additions & 8 deletions udp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@ import (

"github.com/pion/transport/v2/deadline"
"github.com/pion/transport/v2/packetio"
"golang.org/x/net/ipv4"
)

const (
receiveMTU = 8192
sendMTU = 1500
defaultListenBacklog = 128 // same as Linux default
)

// Typed errors
var (
ErrClosedListener = errors.New("udp: listener closed")
ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
ErrInvalidBatchConfig = errors.New("udp: invalid batch config")
)

// listener augments a connection-oriented Listener over a UDP PacketConn
type listener struct {
pConn *net.UDPConn
pConn net.PacketConn

readBatchSize int

accepting atomic.Value // bool
acceptCh chan *Conn
Expand Down Expand Up @@ -109,6 +114,19 @@ func (l *listener) Addr() net.Addr {
return l.pConn.LocalAddr()
}

// BatchIOConfig indicates config to batch read/write packets,
// it will use ReadBatch/WriteBatch to improve throughput for UDP.
type BatchIOConfig struct {
Enable bool
// ReadBatchSize indicates the maximum number of packets to be read in one batch
ReadBatchSize int
// WriteBatchSize indicates the maximum number of packets to be written in one batch
WriteBatchSize int
// WriteBatchInterval indicates the maximum interval to wait before writing packets in one batch
// small interval will reduce latency/jitter, but increase the io count.
WriteBatchInterval time.Duration
}

// ListenConfig stores options for listening to an address.
type ListenConfig struct {
// Backlog defines the maximum length of the queue of pending
Expand All @@ -122,6 +140,16 @@ type ListenConfig struct {
// AcceptFilter determines whether the new conn should be made for
// the incoming packet. If not set, any packet creates new conn.
AcceptFilter func([]byte) bool

// ReadBufferSize sets the size of the operating system's
// receive buffer associated with the listener.
ReadBufferSize int

// WriteBufferSize sets the size of the operating system's
// send buffer associated with the connection.
WriteBufferSize int

Batch BatchIOConfig
}

// Listen creates a new listener based on the ListenConfig.
Expand All @@ -130,11 +158,22 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener
lc.Backlog = defaultListenBacklog
}

if lc.Batch.Enable && (lc.Batch.ReadBatchSize <= 0 || lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) {
return nil, ErrInvalidBatchConfig
}

conn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, err
}

if lc.ReadBufferSize > 0 {
_ = conn.SetReadBuffer(lc.ReadBufferSize)
}
if lc.WriteBufferSize > 0 {
_ = conn.SetWriteBuffer(lc.WriteBufferSize)
}

l := &listener{
pConn: conn,
acceptCh: make(chan *Conn, lc.Backlog),
Expand All @@ -145,6 +184,11 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener
readDoneCh: make(chan struct{}),
}

if lc.Batch.Enable {
l.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval)
l.readBatchSize = lc.Batch.ReadBatchSize
}

l.accepting.Store(true)
l.connWG.Add(1)
l.readWG.Add(2) // wait readLoop and Close execution routine
Expand Down Expand Up @@ -174,21 +218,51 @@ func (l *listener) readLoop() {
defer l.readWG.Done()
defer close(l.readDoneCh)

buf := make([]byte, receiveMTU)
if br, ok := l.pConn.(BatchReader); ok {
l.readBatch(br)
} else {
l.read()
}
}

func (l *listener) readBatch(br BatchReader) {
msgs := make([]ipv4.Message, l.readBatchSize)
for i := range msgs {
msg := &msgs[i]
msg.Buffers = [][]byte{make([]byte, receiveMTU)}
msg.OOB = make([]byte, 40)
}
for {
n, raddr, err := l.pConn.ReadFrom(buf)
n, err := br.ReadBatch(msgs, 0)
if err != nil {
l.errRead.Store(err)
return
}
conn, ok, err := l.getConn(raddr, buf[:n])
if err != nil {
continue
for i := 0; i < n; i++ {
l.dispatchMsg(msgs[i].Addr, msgs[i].Buffers[0][:msgs[i].N])
}
if ok {
_, _ = conn.buffer.Write(buf[:n])
}
}

func (l *listener) read() {
buf := make([]byte, receiveMTU)
for {
n, raddr, err := l.pConn.ReadFrom(buf)
if err != nil {
l.errRead.Store(err)
return
}
l.dispatchMsg(raddr, buf[:n])
}
}

func (l *listener) dispatchMsg(addr net.Addr, buf []byte) {
conn, ok, err := l.getConn(addr, buf)
if err != nil {
return
}
if ok {
_, _ = conn.buffer.Write(buf)
}
}

Expand Down
Loading

0 comments on commit 46a4b2c

Please sign in to comment.