diff --git a/proxy/tun/handler.go b/proxy/tun/handler.go index f830a4ea3dd0..2b73aad9474f 100644 --- a/proxy/tun/handler.go +++ b/proxy/tun/handler.go @@ -105,11 +105,12 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) { sid := session.NewID() ctx := c.ContextWithID(t.ctx, sid) + source := net.DestinationFromAddr(conn.RemoteAddr()) inbound := session.Inbound{ Name: "tun", Tag: t.tag, CanSpliceCopy: 3, - Source: net.DestinationFromAddr(conn.RemoteAddr()), + Source: source, User: &protocol.MemoryUser{ Level: t.config.UserLevel, }, @@ -127,7 +128,7 @@ func (t *Handler) HandleConnection(conn net.Conn, destination net.Destination) { Status: log.AccessAccepted, Reason: "", }) - errors.LogInfo(ctx, "processing TCP from ", conn.RemoteAddr(), " to ", destination) + errors.LogInfo(ctx, "processing from ", source, " to ", destination) link := &transport.Link{ Reader: &buf.TimeoutWrapperReader{Reader: buf.NewReader(conn)}, diff --git a/proxy/tun/stack_gvisor.go b/proxy/tun/stack_gvisor.go index d062c3d0dedd..952150a89913 100644 --- a/proxy/tun/stack_gvisor.go +++ b/proxy/tun/stack_gvisor.go @@ -9,6 +9,7 @@ import ( "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -102,21 +103,7 @@ func (t *stackGVisor) Start() error { ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) // Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support - udpForwarder := newUdpConnectionHandler(t.ctx, t.handler, func(p []byte) { - // extract network protocol from the packet - var networkProtocol tcpip.NetworkProtocolNumber - switch header.IPVersion(p) { - case header.IPv4Version: - networkProtocol = header.IPv4ProtocolNumber - case header.IPv6Version: - networkProtocol = header.IPv6ProtocolNumber - default: - // discard packet with unknown network version - return - } - - ipStack.WriteRawPacket(defaultNIC, networkProtocol, buffer.MakeWithData(p)) - }) + udpForwarder := newUdpConnectionHandler(t.handler.HandleConnection, t.writeRawUDPPacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { data := pkt.Data().AsRange().ToSlice() if len(data) == 0 { @@ -137,6 +124,69 @@ func (t *stackGVisor) Start() error { return nil } +func (t *stackGVisor) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error { + udpLen := header.UDPMinimumSize + len(payload) + srcIP := tcpip.AddrFromSlice(src.Address.IP()) + dstIP := tcpip.AddrFromSlice(dst.Address.IP()) + + // build packet with appropriate IP header size + isIPv4 := dst.Address.Family().IsIPv4() + ipHdrSize := header.IPv6MinimumSize + ipProtocol := header.IPv6ProtocolNumber + if isIPv4 { + ipHdrSize = header.IPv4MinimumSize + ipProtocol = header.IPv4ProtocolNumber + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize, + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + // Build UDP header + udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: uint16(src.Port), + DstPort: uint16(dst.Port), + Length: uint16(udpLen), + }) + + // Calculate and set UDP checksum + xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen)) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum))) + + // Build IP header + if isIPv4 { + ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + udpLen), + TTL: 64, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: srcIP, + DstAddr: dstIP, + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + } else { + ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ipHdr.Encode(&header.IPv6Fields{ + PayloadLength: uint16(udpLen), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 64, + SrcAddr: srcIP, + DstAddr: dstIP, + }) + } + + // dispatch the packet + err := t.stack.WriteRawPacket(defaultNIC, ipProtocol, buffer.MakeWithView(pkt.ToView())) + if err != nil { + return errors.New("failed to write raw udp packet back to stack", err) + } + + return nil +} + // Close is called by Handler to shut down the stack func (t *stackGVisor) Close() error { if t.stack == nil { diff --git a/proxy/tun/udp_fullcone.go b/proxy/tun/udp_fullcone.go index d2920b341416..df58ce4e7bfa 100644 --- a/proxy/tun/udp_fullcone.go +++ b/proxy/tun/udp_fullcone.go @@ -1,228 +1,134 @@ package tun import ( - "context" + "io" "sync" - "sync/atomic" - "time" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" - c "github.com/xtls/xray-core/common/ctx" - "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/common/protocol" - "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal/done" - "github.com/xtls/xray-core/common/task" - "github.com/xtls/xray-core/transport" - "github.com/xtls/xray-core/transport/pipe" - "gvisor.dev/gvisor/pkg/buffer" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checksum" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// udp connection abstraction -type udpConn struct { - lastActive atomic.Int64 - reader buf.Reader - writer buf.Writer - done *done.Instance - cancel context.CancelFunc -} - // sub-handler specifically for udp connections under main handler type udpConnectionHandler struct { sync.Mutex - ctx context.Context - handler *Handler - udpConns map[net.Destination]*udpConn - udpChecker *task.Periodic - writePacket func(p []byte) + + udpConns map[net.Destination]*udpConn + + handleConnection func(conn net.Conn, dest net.Destination) + writePacket func(data []byte, src net.Destination, dst net.Destination) error } -func newUdpConnectionHandler(ctx context.Context, h *Handler, writePacket func(p []byte)) *udpConnectionHandler { +func newUdpConnectionHandler(handleConnection func(conn net.Conn, dest net.Destination), writePacket func(data []byte, src net.Destination, dst net.Destination) error) *udpConnectionHandler { handler := &udpConnectionHandler{ - ctx: ctx, - handler: h, - udpConns: make(map[net.Destination]*udpConn), - writePacket: writePacket, + udpConns: make(map[net.Destination]*udpConn), + handleConnection: handleConnection, + writePacket: writePacket, } - handler.udpChecker = &task.Periodic{Interval: time.Minute, Execute: handler.cleanupUDP} - handler.udpChecker.Start() - return handler } -func (u *udpConnectionHandler) cleanupUDP() error { - u.Lock() - defer u.Unlock() - if len(u.udpConns) == 0 { - return errors.New("no connections") - } - now := time.Now().Unix() - for src, conn := range u.udpConns { - if now-conn.lastActive.Load() > 300 { - conn.cancel() - common.Must(conn.done.Close()) - common.Must(common.Close(conn.writer)) - delete(u.udpConns, src) - } - } - return nil -} - // HandlePacket handles UDP packets coming from tun, to forward to the dispatcher -// this custom handler support FullCone NAT of returning packets, binding connection only by the source port +// this custom handler support FullCone NAT of returning packets, binding connection only by the source addr:port func (u *udpConnectionHandler) HandlePacket(src net.Destination, dst net.Destination, data []byte) bool { u.Lock() conn, found := u.udpConns[src] if !found { - reader, writer := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024)) - conn = &udpConn{reader: reader, writer: writer, done: done.New()} + egress := make(chan []byte, 16) + conn = &udpConn{handler: u, egress: egress, src: src, dst: dst} u.udpConns[src] = conn - u.Unlock() - - go func() { - ctx, cancel := context.WithCancel(u.ctx) - conn.cancel = cancel - defer func() { - cancel() - u.Lock() - delete(u.udpConns, src) - u.Unlock() - common.Must(conn.done.Close()) - common.Must(common.Close(conn.writer)) - }() - - inbound := &session.Inbound{ - Name: "tun", - Tag: u.handler.tag, - Source: src, - CanSpliceCopy: 3, - User: &protocol.MemoryUser{Level: u.handler.config.UserLevel}, - } - ctx = session.ContextWithInbound(c.ContextWithID(ctx, session.NewID()), inbound) - ctx = session.ContextWithContent(ctx, &session.Content{ - SniffingRequest: u.handler.sniffingRequest, - }) - ctx = session.SubContextFromMuxInbound(ctx) - ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ - From: src, - To: dst, - Status: log.AccessAccepted, - Reason: "", - }) - errors.LogInfo(ctx, "processing UDP from ", src, " to ", dst) - link := &transport.Link{ - Reader: &buf.TimeoutWrapperReader{Reader: conn.reader}, - // reverse source and destination, indicating the packets to write are going in the other - // direction (written back to tun) and should have reversed addressing - Writer: &udpWriter{handler: u, src: dst, dst: src}, - } - _ = u.handler.dispatcher.DispatchLink(ctx, dst, link) - }() - } else { - conn.lastActive.Store(time.Now().Unix()) - u.Unlock() + + go u.handleConnection(conn, dst) } + u.Unlock() - b := buf.New() - b.Write(data) - b.UDP = &dst - conn.writer.WriteMultiBuffer(buf.MultiBuffer{b}) + // send packet data to the egress channel, if it has buffer, or discard + select { + case conn.egress <- data: + default: + } return true } -type udpWriter struct { +func (u *udpConnectionHandler) connectionFinished(src net.Destination) { + u.Lock() + conn, found := u.udpConns[src] + if found { + delete(u.udpConns, src) + close(conn.egress) + } + u.Unlock() +} + +// udp connection abstraction +type udpConn struct { + net.Conn + buf.Writer + handler *udpConnectionHandler - // address in the side of stack, where packet will be coming from - src net.Destination - // address on the side of tun, where packet will be destined to - dst net.Destination + + egress chan []byte + src net.Destination + dst net.Destination } -func (w *udpWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { - for _, b := range mb { - // use captured in the dispatched packet source address b.UDP as source, if available, - // otherwise use captured in the writer source w.src - srcAddr := w.src - if b.UDP != nil { - srcAddr = *b.UDP - } +// Read packets from the connection +func (c *udpConn) Read(p []byte) (int, error) { + data, ok := <-c.egress + if !ok { + return 0, io.EOF + } - // validate address family matches - if srcAddr.Address.Family() != w.src.Address.Family() { - errors.LogWarning(context.Background(), "UDP return packet address family mismatch: expected ", w.src.Address.Family(), ", got ", srcAddr.Address.Family()) - b.Release() - continue - } + n := copy(p, data) + return n, nil +} + +// Write returning packets back +func (c *udpConn) Write(p []byte) (int, error) { + // sending packets back mean sending payload with source/destination reversed + err := c.handler.writePacket(p, c.dst, c.src) + if err != nil { + return 0, nil + } + + return len(p), nil +} - payload := b.Bytes() - udpLen := header.UDPMinimumSize + len(payload) - srcIP := tcpip.AddrFromSlice(srcAddr.Address.IP()) - dstIP := tcpip.AddrFromSlice(w.dst.Address.IP()) +func (c *udpConn) Close() error { + c.handler.connectionFinished(c.src) - // build packet with appropriate IP header size - isIPv4 := srcAddr.Address.Family().IsIPv4() - ipHdrSize := header.IPv6MinimumSize - if isIPv4 { - ipHdrSize = header.IPv4MinimumSize + return nil +} + +func (c *udpConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: c.dst.Address.IP(), Port: int(c.dst.Port.Value())} +} + +func (c *udpConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: c.src.Address.IP(), Port: int(c.src.Port.Value())} +} + +// Write returning packets back +func (c *udpConn) WriteMultiBuffer(mb buf.MultiBuffer) error { + for _, b := range mb { + dst := c.dst + if b.UDP != nil { + dst = *b.UDP } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize, - Payload: buffer.MakeWithData(payload), - }) - - // Build UDP header - udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - udpHdr.Encode(&header.UDPFields{ - SrcPort: uint16(srcAddr.Port), - DstPort: uint16(w.dst.Port), - Length: uint16(udpLen), - }) - - // Calculate and set UDP checksum - xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen)) - udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum))) - - // Build IP header - if isIPv4 { - ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) - ipHdr.Encode(&header.IPv4Fields{ - TotalLength: uint16(header.IPv4MinimumSize + udpLen), - TTL: 64, - Protocol: uint8(header.UDPProtocolNumber), - SrcAddr: srcIP, - DstAddr: dstIP, - }) - ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) - } else { - ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ipHdr.Encode(&header.IPv6Fields{ - PayloadLength: uint16(udpLen), - TransportProtocol: header.UDPProtocolNumber, - HopLimit: 64, - SrcAddr: srcIP, - DstAddr: dstIP, - }) + // validate address family matches between buffer packet and the connection + if dst.Address.Family() != c.dst.Address.Family() { + continue } - // Write raw packet to network stack - views := pkt.AsSlices() - var data []byte - for _, view := range views { - data = append(data, view...) + // sending packets back mean sending payload with source/destination reversed + err := c.handler.writePacket(b.Bytes(), dst, c.src) + if err != nil { + // udp doesn't guarantee delivery, so in any failure we just continue to the next packet + continue } - w.handler.writePacket(data) - pkt.DecRef() - b.Release() } + return nil }