diff --git a/udp/batchconn.go b/udp/batchconn.go new file mode 100644 index 0000000..18380fd --- /dev/null +++ b/udp/batchconn.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package udp + +import ( + "net" + "runtime" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// BatchWriter represents conn can write messages in batch +type BatchWriter interface { + WriteBatch(ms []ipv4.Message, flags int) (int, error) +} + +// BatchReader represents conn can read messages in batch +type BatchReader interface { + ReadBatch(msg []ipv4.Message, flags int) (int, error) +} + +// BatchPacketConn represents conn can read/write messages in batch +type BatchPacketConn interface { + BatchWriter + BatchReader +} + +// BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch, +// only available in linux. In other platform, it will use single Write/Read as same as net.Conn. +type BatchConn struct { + net.PacketConn + + batchConn BatchPacketConn + + batchWriteMutex sync.Mutex + batchWriteMessages []ipv4.Message + batchWritePos int + batchWriteLast time.Time + + batchWriteSize int + batchWriteInterval time.Duration + + closed atomic.Bool +} + +// NewBatchConn creates a *BatchCon from net.PacketConn with batch configs. +func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn { + bc := &BatchConn{ + PacketConn: conn, + batchWriteLast: time.Now(), + batchWriteInterval: batchWriteInterval, + batchWriteSize: batchWriteSize, + batchWriteMessages: make([]ipv4.Message, batchWriteSize), + } + for i := range bc.batchWriteMessages { + bc.batchWriteMessages[i].Buffers = [][]byte{make([]byte, sendMTU)} + } + + // batch write only supports linux + if runtime.GOOS == "linux" { + if pc4 := ipv4.NewPacketConn(conn); pc4 != nil { + bc.batchConn = pc4 + } else if pc6 := ipv6.NewPacketConn(conn); pc6 != nil { + bc.batchConn = pc6 + } + } + + if bc.batchConn != nil { + go func() { + writeTicker := time.NewTicker(batchWriteInterval / 2) + defer writeTicker.Stop() + + for !bc.closed.Load() { + <-writeTicker.C + bc.batchWriteMutex.Lock() + if bc.batchWritePos > 0 && time.Since(bc.batchWriteLast) >= bc.batchWriteInterval { + _ = bc.flush() + } + bc.batchWriteMutex.Unlock() + } + }() + } + + return bc +} + +// Close batchConn and the underlying PacketConn +func (c *BatchConn) Close() error { + c.closed.Store(true) + return c.PacketConn.Close() +} + +// WriteTo write message to an UDPAddr, addr should be nil if it is a connected socket. +func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) { + if c.batchConn == nil { + return c.PacketConn.WriteTo(b, addr) + } + return c.writeBatch(b, addr) +} + +func (c *BatchConn) writeBatch(buf []byte, raddr net.Addr) (int, error) { + var err error + c.batchWriteMutex.Lock() + defer c.batchWriteMutex.Unlock() + + // c.writeCounter++ + msg := &c.batchWriteMessages[c.batchWritePos] + // reset buffers + msg.Buffers = msg.Buffers[:1] + msg.Buffers[0] = msg.Buffers[0][:cap(msg.Buffers[0])] + + c.batchWritePos++ + if raddr != nil { + msg.Addr = raddr + } + if n := copy(msg.Buffers[0], buf); n < len(buf) { + extraBuffer := make([]byte, len(buf)-n) + copy(extraBuffer, buf[n:]) + msg.Buffers = append(msg.Buffers, extraBuffer) + } else { + msg.Buffers[0] = msg.Buffers[0][:n] + } + if c.batchWritePos == c.batchWriteSize { + err = c.flush() + } + return len(buf), err +} + +// ReadBatch reads messages in batch, return length of message readed and error. +func (c *BatchConn) ReadBatch(msgs []ipv4.Message, flags int) (int, error) { + if c.batchConn == nil { + n, addr, err := c.PacketConn.ReadFrom(msgs[0].Buffers[0]) + if err == nil { + msgs[0].N = n + msgs[0].Addr = addr + return 1, nil + } + return 0, err + } + return c.batchConn.ReadBatch(msgs, flags) +} + +func (c *BatchConn) flush() error { + var writeErr error + var txN int + for txN < c.batchWritePos { + n, err := c.batchConn.WriteBatch(c.batchWriteMessages[txN:c.batchWritePos], 0) + if err != nil { + writeErr = err + break + } + txN += n + } + c.batchWritePos = 0 + c.batchWriteLast = time.Now() + return writeErr +} diff --git a/udp/conn.go b/udp/conn.go index d05af97..3cb02df 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -14,10 +14,12 @@ import ( "github.com/pion/transport/v2/deadline" "github.com/pion/transport/v2/packetio" + "golang.org/x/net/ipv4" ) const ( receiveMTU = 8192 + sendMTU = 1500 defaultListenBacklog = 128 // same as Linux default ) @@ -25,11 +27,14 @@ const ( var ( ErrClosedListener = errors.New("udp: listener closed") ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") + ErrInvalidBatchConfig = errors.New("udp: invalid batch config") ) // listener augments a connection-oriented Listener over a UDP PacketConn type listener struct { - pConn *net.UDPConn + pConn net.PacketConn + + readBatchSize int accepting atomic.Value // bool acceptCh chan *Conn @@ -109,6 +114,19 @@ func (l *listener) Addr() net.Addr { return l.pConn.LocalAddr() } +// BatchIOConfig indicates config to batch read/write packets, +// it will use ReadBatch/WriteBatch to improve throughput for UDP. +type BatchIOConfig struct { + Enable bool + // ReadBatchSize indicates the maximum number of packets to be read in one batch + ReadBatchSize int + // WriteBatchSize indicates the maximum number of packets to be written in one batch + WriteBatchSize int + // WriteBatchInterval indicates the maximum interval to wait before writing packets in one batch + // small interval will reduce latency/jitter, but increase the io count. + WriteBatchInterval time.Duration +} + // ListenConfig stores options for listening to an address. type ListenConfig struct { // Backlog defines the maximum length of the queue of pending @@ -122,6 +140,16 @@ type ListenConfig struct { // AcceptFilter determines whether the new conn should be made for // the incoming packet. If not set, any packet creates new conn. AcceptFilter func([]byte) bool + + // ReadBufferSize sets the size of the operating system's + // receive buffer associated with the listener. + ReadBufferSize int + + // WriteBufferSize sets the size of the operating system's + // send buffer associated with the connection. + WriteBufferSize int + + Batch BatchIOConfig } // Listen creates a new listener based on the ListenConfig. @@ -130,11 +158,22 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener lc.Backlog = defaultListenBacklog } + if lc.Batch.Enable && (lc.Batch.ReadBatchSize <= 0 || lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) { + return nil, ErrInvalidBatchConfig + } + conn, err := net.ListenUDP(network, laddr) if err != nil { return nil, err } + if lc.ReadBufferSize > 0 { + _ = conn.SetReadBuffer(lc.ReadBufferSize) + } + if lc.WriteBufferSize > 0 { + _ = conn.SetWriteBuffer(lc.WriteBufferSize) + } + l := &listener{ pConn: conn, acceptCh: make(chan *Conn, lc.Backlog), @@ -145,6 +184,11 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener readDoneCh: make(chan struct{}), } + if lc.Batch.Enable { + l.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval) + l.readBatchSize = lc.Batch.ReadBatchSize + } + l.accepting.Store(true) l.connWG.Add(1) l.readWG.Add(2) // wait readLoop and Close execution routine @@ -174,21 +218,51 @@ func (l *listener) readLoop() { defer l.readWG.Done() defer close(l.readDoneCh) - buf := make([]byte, receiveMTU) + if br, ok := l.pConn.(BatchReader); ok { + l.readBatch(br) + } else { + l.read() + } +} +func (l *listener) readBatch(br BatchReader) { + msgs := make([]ipv4.Message, l.readBatchSize) + for i := range msgs { + msg := &msgs[i] + msg.Buffers = [][]byte{make([]byte, receiveMTU)} + msg.OOB = make([]byte, 40) + } for { - n, raddr, err := l.pConn.ReadFrom(buf) + n, err := br.ReadBatch(msgs, 0) if err != nil { l.errRead.Store(err) return } - conn, ok, err := l.getConn(raddr, buf[:n]) - if err != nil { - continue + for i := 0; i < n; i++ { + l.dispatchMsg(msgs[i].Addr, msgs[i].Buffers[0][:msgs[i].N]) } - if ok { - _, _ = conn.buffer.Write(buf[:n]) + } +} + +func (l *listener) read() { + buf := make([]byte, receiveMTU) + for { + n, raddr, err := l.pConn.ReadFrom(buf) + if err != nil { + l.errRead.Store(err) + return } + l.dispatchMsg(raddr, buf[:n]) + } +} + +func (l *listener) dispatchMsg(addr net.Addr, buf []byte) { + conn, ok, err := l.getConn(addr, buf) + if err != nil { + return + } + if ok { + _, _ = conn.buffer.Write(buf) } } diff --git a/udp/conn_test.go b/udp/conn_test.go index 722538d..50bf399 100644 --- a/udp/conn_test.go +++ b/udp/conn_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/pion/transport/v2/test" + "github.com/stretchr/testify/assert" ) var errHandshakeFailed = errors.New("handshake failed") @@ -478,3 +479,72 @@ func TestConnClose(t *testing.T) { } }) } + +func TestBatchIO(t *testing.T) { + lc := ListenConfig{ + Batch: BatchIOConfig{ + Enable: true, + ReadBatchSize: 10, + WriteBatchSize: 3, + WriteBatchInterval: 5 * time.Millisecond, + }, + } + + laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678} + listener, err := lc.Listen("udp", laddr) + if err != nil { + t.Fatal(err) + } + + acceptQc := make(chan struct{}) + go func() { + defer close(acceptQc) + for { + buf := make([]byte, 1400) + conn, err := listener.Accept() + if errors.Is(err, ErrClosedListener) { + break + } + assert.NoError(t, err) + go func() { + defer func() { _ = conn.Close() }() + for { + n, err := conn.Read(buf) + assert.NoError(t, err) + _, err = conn.Write(buf[:n]) + assert.NoError(t, err) + } + }() + } + }() + + raddr, _ := listener.Addr().(*net.UDPAddr) + + wgs := sync.WaitGroup{} + cc := 3 + wgs.Add(cc) + + for i := 0; i < cc; i++ { + sendStr := fmt.Sprintf("hello %d", i) + go func() { + defer wgs.Done() + buf := make([]byte, 1400) + client, err := net.DialUDP("udp", nil, raddr) + assert.NoError(t, err) + defer func() { _ = client.Close() }() + for i := 0; i < 100; i++ { + _, err := client.Write([]byte(sendStr)) + assert.NoError(t, err) + err = client.SetReadDeadline(time.Now().Add(time.Second)) + assert.NoError(t, err) + n, err := client.Read(buf) + assert.NoError(t, err) + assert.Equal(t, sendStr, string(buf[:n]), i) + } + }() + } + wgs.Wait() + + _ = listener.Close() + <-acceptQc +}