diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index afd9aa6d02f..72243018179 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -2,6 +2,7 @@ package trojan import ( "context" + sync "sync" core "github.com/v2fly/v2ray-core/v5" "github.com/v2fly/v2ray-core/v5/common" @@ -119,7 +120,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) packetReader := &PacketReader{Reader: conn} - splitReader := &PacketSplitReader{Reader: packetReader} + splitReader := &PacketConnectionReader{ + readerAccess: &sync.Mutex{}, + reader: packetReader, + } return udp.CopyPacketConn(packetConn, splitReader, udp.UpdateActivity(timer)) } diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 0dddd298435..654f77db8da 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" gonet "net" + sync "sync" "github.com/v2fly/v2ray-core/v5/common/buf" "github.com/v2fly/v2ray-core/v5/common/net" @@ -287,25 +288,29 @@ func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) { return &PacketPayload{Target: dest, Buffer: mb}, nil } -type PacketSplitReader struct { - Reader *PacketReader - Payload *PacketPayload +type PacketConnectionReader struct { + readerAccess *sync.Mutex + reader *PacketReader + payload *PacketPayload } -func (r *PacketSplitReader) ReadFrom(p []byte) (n int, addr gonet.Addr, err error) { - if r.Payload == nil || r.Payload.Buffer.IsEmpty() { - r.Payload, err = r.Reader.ReadMultiBufferWithMetadata() +func (r *PacketConnectionReader) ReadFrom(p []byte) (n int, addr gonet.Addr, err error) { + r.readerAccess.Lock() + defer r.readerAccess.Unlock() + + if r.payload == nil || r.payload.Buffer.IsEmpty() { + r.payload, err = r.reader.ReadMultiBufferWithMetadata() if err != nil { return } } addr = &gonet.UDPAddr{ - IP: r.Payload.Target.Address.IP(), - Port: int(r.Payload.Target.Port), + IP: r.payload.Target.Address.IP(), + Port: int(r.payload.Target.Port), } - r.Payload.Buffer, n = buf.SplitBytes(r.Payload.Buffer, p) + r.payload.Buffer, n = buf.SplitFirstBytes(r.payload.Buffer, p) return }