diff --git a/udp/batchconn.go b/udp/batchconn.go index 18380fd..b01be25 100644 --- a/udp/batchconn.go +++ b/udp/batchconn.go @@ -4,6 +4,7 @@ package udp import ( + "io" "net" "runtime" "sync" @@ -28,6 +29,7 @@ type BatchReader interface { type BatchPacketConn interface { BatchWriter BatchReader + io.Closer } // BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch, @@ -48,7 +50,7 @@ type BatchConn struct { closed atomic.Bool } -// NewBatchConn creates a *BatchCon from net.PacketConn with batch configs. +// NewBatchConn creates a *BatchConn from net.PacketConn with batch configs. func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn { bc := &BatchConn{ PacketConn: conn, @@ -92,6 +94,14 @@ func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval ti // Close batchConn and the underlying PacketConn func (c *BatchConn) Close() error { c.closed.Store(true) + c.batchWriteMutex.Lock() + if c.batchWritePos > 0 { + _ = c.flush() + } + c.batchWriteMutex.Unlock() + if c.batchConn != nil { + return c.batchConn.Close() + } return c.PacketConn.Close() } @@ -100,15 +110,14 @@ func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) { if c.batchConn == nil { return c.PacketConn.WriteTo(b, addr) } - return c.writeBatch(b, addr) + return c.enqueueMessage(b, addr) } -func (c *BatchConn) writeBatch(buf []byte, raddr net.Addr) (int, error) { +func (c *BatchConn) enqueueMessage(buf []byte, raddr net.Addr) (int, error) { var err error c.batchWriteMutex.Lock() defer c.batchWriteMutex.Unlock() - // c.writeCounter++ msg := &c.batchWriteMessages[c.batchWritePos] // reset buffers msg.Buffers = msg.Buffers[:1] diff --git a/udp/conn.go b/udp/conn.go index 3cb02df..e2378f8 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -118,7 +118,8 @@ func (l *listener) Addr() net.Addr { // it will use ReadBatch/WriteBatch to improve throughput for UDP. type BatchIOConfig struct { Enable bool - // ReadBatchSize indicates the maximum number of packets to be read in one batch + // ReadBatchSize indicates the maximum number of packets to be read in one batch, a batch size less than 2 means + // disable read batch. ReadBatchSize int // WriteBatchSize indicates the maximum number of packets to be written in one batch WriteBatchSize int @@ -158,7 +159,7 @@ func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener lc.Backlog = defaultListenBacklog } - if lc.Batch.Enable && (lc.Batch.ReadBatchSize <= 0 || lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) { + if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) { return nil, ErrInvalidBatchConfig } @@ -218,7 +219,7 @@ func (l *listener) readLoop() { defer l.readWG.Done() defer close(l.readDoneCh) - if br, ok := l.pConn.(BatchReader); ok { + if br, ok := l.pConn.(BatchReader); ok && l.readBatchSize > 1 { l.readBatch(br) } else { l.read() diff --git a/udp/conn_test.go b/udp/conn_test.go index 50bf399..1904d55 100644 --- a/udp/conn_test.go +++ b/udp/conn_test.go @@ -13,6 +13,7 @@ import ( "io" "net" "sync" + "sync/atomic" "testing" "time" @@ -488,6 +489,8 @@ func TestBatchIO(t *testing.T) { WriteBatchSize: 3, WriteBatchInterval: 5 * time.Millisecond, }, + ReadBufferSize: 64 * 1024, + WriteBufferSize: 64 * 1024, } laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678} @@ -496,23 +499,36 @@ func TestBatchIO(t *testing.T) { t.Fatal(err) } - acceptQc := make(chan struct{}) + var serverConnWg sync.WaitGroup + serverConnWg.Add(1) go func() { - defer close(acceptQc) + var exit atomic.Bool + defer func() { + defer serverConnWg.Done() + exit.Store(true) + }() for { buf := make([]byte, 1400) - conn, err := listener.Accept() - if errors.Is(err, ErrClosedListener) { + conn, lerr := listener.Accept() + if errors.Is(lerr, ErrClosedListener) { break } - assert.NoError(t, err) + assert.NoError(t, lerr) + serverConnWg.Add(1) go func() { - defer func() { _ = conn.Close() }() - for { - n, err := conn.Read(buf) - assert.NoError(t, err) - _, err = conn.Write(buf[:n]) - assert.NoError(t, err) + defer func() { + _ = conn.Close() + serverConnWg.Done() + }() + for !exit.Load() { + _ = conn.SetReadDeadline(time.Now().Add(time.Second)) + n, rerr := conn.Read(buf) + if rerr != nil { + assert.ErrorContains(t, rerr, "timeout") + } else { + _, rerr = conn.Write(buf[:n]) + assert.NoError(t, rerr) + } } }() } @@ -520,6 +536,17 @@ func TestBatchIO(t *testing.T) { raddr, _ := listener.Addr().(*net.UDPAddr) + // test flush by WriteBatchInterval expired + readBuf := make([]byte, 1400) + cli, err := net.DialUDP("udp", nil, raddr) + assert.NoError(t, err) + flushStr := "flushbytimer" + _, err = cli.Write([]byte("flushbytimer")) + assert.NoError(t, err) + n, err := cli.Read(readBuf) + assert.NoError(t, err) + assert.Equal(t, flushStr, string(readBuf[:n])) + wgs := sync.WaitGroup{} cc := 3 wgs.Add(cc) @@ -532,7 +559,7 @@ func TestBatchIO(t *testing.T) { client, err := net.DialUDP("udp", nil, raddr) assert.NoError(t, err) defer func() { _ = client.Close() }() - for i := 0; i < 100; i++ { + for i := 0; i < 1; i++ { _, err := client.Write([]byte(sendStr)) assert.NoError(t, err) err = client.SetReadDeadline(time.Now().Add(time.Second)) @@ -546,5 +573,5 @@ func TestBatchIO(t *testing.T) { wgs.Wait() _ = listener.Close() - <-acceptQc + serverConnWg.Wait() }