diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 716faae6ceaa..b8a437d0566c 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -26,6 +26,7 @@ import ( "github.com/xtls/xray-core/transport/internet/finalmask/mkcp/original" "github.com/xtls/xray-core/transport/internet/finalmask/salamander" "github.com/xtls/xray-core/transport/internet/finalmask/xdns" + "github.com/xtls/xray-core/transport/internet/finalmask/xicmp" "github.com/xtls/xray-core/transport/internet/httpupgrade" "github.com/xtls/xray-core/transport/internet/hysteria" "github.com/xtls/xray-core/transport/internet/kcp" @@ -1239,6 +1240,7 @@ var ( "mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) }, "salamander": func() interface{} { return new(Salamander) }, "xdns": func() interface{} { return new(Xdns) }, + "xicmp": func() interface{} { return new(Xicmp) }, }, "type", "settings") ) @@ -1327,6 +1329,24 @@ func (c *Xdns) Build() (proto.Message, error) { }, nil } +type Xicmp struct { + ListenIp string `json:"listenIp"` + Id uint16 `json:"id"` +} + +func (c *Xicmp) Build() (proto.Message, error) { + config := &xicmp.Config{ + Ip: c.ListenIp, + Id: int32(c.Id), + } + + if config.Ip == "" { + config.Ip = "0.0.0.0" + } + + return config, nil +} + type Mask struct { Type string `json:"type"` Settings *json.RawMessage `json:"settings"` diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go index c8b815d37bf8..9d80bc225762 100644 --- a/transport/internet/finalmask/xdns/client.go +++ b/transport/internet/finalmask/xdns/client.go @@ -209,7 +209,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { encoded, err := encode(p, c.clientID, c.domain) if err != nil { - errors.LogDebug(context.Background(), "xdns encode err", err) + errors.LogDebug(context.Background(), "xdns encode err ", err) return 0, errors.New("xdns encode").Base(err) } diff --git a/transport/internet/finalmask/xicmp/client.go b/transport/internet/finalmask/xicmp/client.go new file mode 100644 index 000000000000..fa7c6a4fb4c1 --- /dev/null +++ b/transport/internet/finalmask/xicmp/client.go @@ -0,0 +1,348 @@ +package xicmp + +import ( + "context" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/xtls/xray-core/common/crypto" + "github.com/xtls/xray-core/common/errors" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +const ( + initPollDelay = 500 * time.Millisecond + maxPollDelay = 10 * time.Second + pollDelayMultiplier = 2.0 + pollLimit = 16 +) + +type packet struct { + p []byte + addr net.Addr +} + +type seqStatus struct { + needSeqByte bool + seqByte byte + received bool +} + +type xicmpConnClient struct { + conn net.PacketConn + icmpConn *icmp.PacketConn + + typ icmp.Type + id int + seq int + proto int + seqStatus map[int]*seqStatus + + pollChan chan struct{} + readQueue chan *packet + writeQueue chan *packet + + closed bool + mutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { + if !end { + return nil, errors.New("xicmp requires being at the outermost level") + } + + network := "ip4:icmp" + typ := icmp.Type(ipv4.ICMPTypeEcho) + proto := 1 + if strings.Contains(c.Ip, ":") { + network = "ip6:ipv6-icmp" + typ = ipv6.ICMPTypeEchoRequest + proto = 58 + } + + icmpConn, err := icmp.ListenPacket(network, c.Ip) + if err != nil { + return nil, errors.New("xicmp listen err").Base(err) + } + + if c.Id == 0 { + c.Id = int32(crypto.RandBetween(0, 65535)) + } + + conn := &xicmpConnClient{ + conn: raw, + icmpConn: icmpConn, + + typ: typ, + id: int(c.Id), + seq: 1, + proto: proto, + seqStatus: make(map[int]*seqStatus), + + pollChan: make(chan struct{}, pollLimit), + readQueue: make(chan *packet, 128), + writeQueue: make(chan *packet, 128), + } + + go conn.recvLoop() + go conn.sendLoop() + + return conn, nil +} + +func (c *xicmpConnClient) encode(p []byte) ([]byte, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + needSeqByte := false + var seqByte byte + data := p + if len(p) > 0 { + needSeqByte = true + seqByte = p[0] + } + + msg := icmp.Message{ + Type: c.typ, + Code: 0, + Body: &icmp.Echo{ + ID: c.id, + Seq: c.seq, + Data: data, + }, + } + + buf, err := msg.Marshal(nil) + if err != nil { + return nil, err + } + + if len(buf) > 8192 { + return nil, errors.New("xicmp len(buf) > 8192") + } + + c.seqStatus[c.seq] = &seqStatus{ + needSeqByte: needSeqByte, + seqByte: seqByte, + received: false, + } + + c.seq++ + + if c.seq == 65536 { + c.seq = 1 + } + + return buf, nil +} + +func (c *xicmpConnClient) recvLoop() { + for { + if c.closed { + break + } + + var buf [8192]byte + + n, addr, err := c.icmpConn.ReadFrom(buf[:]) + if err != nil { + continue + } + + msg, err := icmp.ParseMessage(c.proto, buf[:n]) + if err != nil { + continue + } + + if msg.Type != ipv4.ICMPTypeEchoReply && msg.Type != ipv6.ICMPTypeEchoReply { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) + if !ok { + continue + } + + seqStatus, ok := c.seqStatus[echo.Seq] + + if !ok { + continue + } + + if seqStatus.received { + continue + } + + if seqStatus.needSeqByte { + if len(echo.Data) <= 1 { + continue + } + if echo.Data[0] == seqStatus.seqByte { + continue + } + echo.Data = echo.Data[1:] + } + + if len(echo.Data) > 0 { + seqStatus.received = true + + buf := make([]byte, len(echo.Data)) + copy(buf, echo.Data) + select { + case c.readQueue <- &packet{ + p: buf, + addr: &net.UDPAddr{IP: addr.(*net.IPAddr).IP}, + }: + default: + } + + select { + case c.pollChan <- struct{}{}: + default: + } + } + } + + close(c.pollChan) + close(c.readQueue) +} + +func (c *xicmpConnClient) sendLoop() { + var addr net.Addr + + pollDelay := initPollDelay + pollTimer := time.NewTimer(pollDelay) + for { + var p *packet + pollTimerExpired := false + + select { + case p = <-c.writeQueue: + default: + select { + case p = <-c.writeQueue: + case <-c.pollChan: + case <-pollTimer.C: + pollTimerExpired = true + } + } + + if p != nil { + addr = p.addr + + select { + case <-c.pollChan: + default: + } + } else if addr != nil { + encoded, _ := c.encode(nil) + p = &packet{ + p: encoded, + addr: addr, + } + } + + if pollTimerExpired { + pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier) + if pollDelay > maxPollDelay { + pollDelay = maxPollDelay + } + } else { + if !pollTimer.Stop() { + <-pollTimer.C + } + pollDelay = initPollDelay + } + pollTimer.Reset(pollDelay) + + if c.closed { + return + } + + if p != nil { + _, err := c.icmpConn.WriteTo(p.p, p.addr) + if err != nil { + errors.LogDebug(context.Background(), "xicmp writeto err ", err) + } + } + } +} + +func (c *xicmpConnClient) Size() int32 { + return 0 +} + +func (c *xicmpConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { + encoded, err := c.encode(p) + if err != nil { + return 0, errors.New("xicmp encode").Base(err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xicmp closed") + } + + select { + case c.writeQueue <- &packet{ + p: encoded, + addr: &net.IPAddr{IP: addr.(*net.UDPAddr).IP}, + }: + return len(p), nil + default: + return 0, errors.New("xicmp queue full") + } +} + +func (c *xicmpConnClient) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + close(c.writeQueue) + + _ = c.icmpConn.Close() + return c.conn.Close() +} + +func (c *xicmpConnClient) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP, + Port: c.id, + } +} + +func (c *xicmpConnClient) SetDeadline(t time.Time) error { + return c.icmpConn.SetDeadline(t) +} + +func (c *xicmpConnClient) SetReadDeadline(t time.Time) error { + return c.icmpConn.SetReadDeadline(t) +} + +func (c *xicmpConnClient) SetWriteDeadline(t time.Time) error { + return c.icmpConn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/xicmp/config.go b/transport/internet/finalmask/xicmp/config.go new file mode 100644 index 000000000000..81a483af8539 --- /dev/null +++ b/transport/internet/finalmask/xicmp/config.go @@ -0,0 +1,16 @@ +package xicmp + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, end) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, end) +} diff --git a/transport/internet/finalmask/xicmp/config.pb.go b/transport/internet/finalmask/xicmp/config.pb.go new file mode 100644 index 000000000000..290b508525e4 --- /dev/null +++ b/transport/internet/finalmask/xicmp/config.pb.go @@ -0,0 +1,132 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc v6.33.1 +// source: transport/internet/finalmask/xicmp/config.proto + +package xicmp + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` + Id int32 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *Config) GetId() int32 { + if x != nil { + return x.Id + } + return 0 +} + +var File_transport_internet_finalmask_xicmp_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_xicmp_config_proto_rawDesc = "" + + "\n" + + "/transport/internet/finalmask/xicmp/config.proto\x12'xray.transport.internet.finalmask.xicmp\"(\n" + + "\x06Config\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\tR\x02ip\x12\x0e\n" + + "\x02id\x18\x02 \x01(\x05R\x02idB\x97\x01\n" + + "+com.xray.transport.internet.finalmask.xicmpP\x01Z= idleTimeout { + close(q.queue) + delete(c.writeQueueMap, key) + } + } + + return false + } + + for { + time.Sleep(idleTimeout / 2) + if f() { + return + } + } +} + +func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + q, ok := c.writeQueueMap[addr.String()] + if !ok { + q = &queue{ + queue: make(chan []byte, 128), + } + c.writeQueueMap[addr.String()] = q + } + q.lash = time.Now() + + return q +} + +func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, seqByte byte) ([]byte, error) { + data := p + if needSeqByte { + b2 := c.randUntil(seqByte) + data = append([]byte{b2}, p...) + } + + msg := icmp.Message{ + Type: c.typ, + Code: 0, + Body: &icmp.Echo{ + ID: id, + Seq: seq, + Data: data, + }, + } + + buf, err := msg.Marshal(nil) + if err != nil { + return nil, err + } + + if len(buf) > 8192 { + return nil, errors.New("xicmp len(buf) > 8192") + } + + return buf, nil +} + +func (c *xicmpConnServer) randUntil(b1 byte) byte { + b2 := byte(crypto.RandBetween(0, 255)) + for { + if b2 != b1 { + return b2 + } + b2 = byte(crypto.RandBetween(0, 255)) + } +} + +func (c *xicmpConnServer) recvLoop() { + for { + if c.closed { + break + } + + var buf [8192]byte + + n, addr, err := c.icmpConn.ReadFrom(buf[:]) + if err != nil { + continue + } + + msg, err := icmp.ParseMessage(c.proto, buf[:n]) + if err != nil { + continue + } + + if msg.Type != ipv4.ICMPTypeEcho && msg.Type != ipv6.ICMPTypeEchoRequest { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) + if !ok { + continue + } + + if c.config.Id != 0 && echo.ID != int(c.config.Id) { + continue + } + + needSeqByte := false + var seqByte byte + + if len(echo.Data) > 0 { + needSeqByte = true + seqByte = echo.Data[0] + + buf := make([]byte, len(echo.Data)) + copy(buf, echo.Data) + select { + case c.readQueue <- &packet{ + p: buf, + addr: &net.UDPAddr{ + IP: addr.(*net.IPAddr).IP, + Port: echo.ID, + }, + }: + default: + } + } + + select { + case c.ch <- &record{ + id: echo.ID, + seq: echo.Seq, + needSeqByte: needSeqByte, + seqByte: seqByte, + addr: &net.UDPAddr{ + IP: addr.(*net.IPAddr).IP, + Port: echo.ID, + }, + }: + default: + } + } + + close(c.ch) + close(c.readQueue) +} + +func (c *xicmpConnServer) sendLoop() { + var nextRec *record + for { + rec := nextRec + nextRec = nil + + if rec == nil { + var ok bool + rec, ok = <-c.ch + if !ok { + break + } + } + + queue := c.ensureQueue(rec.addr) + if queue == nil { + return + } + + var p []byte + + timer := time.NewTimer(maxResponseDelay) + + select { + case p = <-queue.queue: + default: + select { + case p = <-queue.queue: + case <-timer.C: + case nextRec = <-c.ch: + } + } + + timer.Stop() + + if len(p) == 0 { + continue + } + + buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte) + if err != nil { + continue + } + + if c.closed { + return + } + + _, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP}) + if err != nil { + errors.LogDebug(context.Background(), "xicmp writeto err ", err) + } + } +} + +func (c *xicmpConnServer) Size() int32 { + return 0 +} + +func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { + q := c.ensureQueue(addr) + if q == nil { + return 0, errors.New("xicmp closed") + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xicmp closed") + } + + buf := make([]byte, len(p)) + copy(buf, p) + + select { + case q.queue <- buf: + return len(p), nil + default: + return 0, errors.New("xicmp queue full") + } +} + +func (c *xicmpConnServer) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + for key, q := range c.writeQueueMap { + close(q.queue) + delete(c.writeQueueMap, key) + } + + _ = c.icmpConn.Close() + return c.conn.Close() +} + +func (c *xicmpConnServer) LocalAddr() net.Addr { + return &net.UDPAddr{IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP} +} + +func (c *xicmpConnServer) SetDeadline(t time.Time) error { + return c.icmpConn.SetDeadline(t) +} + +func (c *xicmpConnServer) SetReadDeadline(t time.Time) error { + return c.icmpConn.SetReadDeadline(t) +} + +func (c *xicmpConnServer) SetWriteDeadline(t time.Time) error { + return c.icmpConn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/xicmp/xicmp_test.go b/transport/internet/finalmask/xicmp/xicmp_test.go new file mode 100644 index 000000000000..1ac921819bea --- /dev/null +++ b/transport/internet/finalmask/xicmp/xicmp_test.go @@ -0,0 +1,74 @@ +package xicmp_test + +import ( + "bytes" + "fmt" + "testing" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func TestICMPEchoMarshal(t *testing.T) { + msg := icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEcho, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEcho", len(ICMPTypeEcho), ICMPTypeEcho) + + msg = icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEchoReply, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEchoReply", len(ICMPTypeEchoReply), ICMPTypeEchoReply) + + msg = icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEchoRequest, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEchoRequest", len(ICMPTypeEchoRequest), ICMPTypeEchoRequest) + + msg = icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + V6ICMPTypeEchoReply, _ := msg.Marshal(nil) + fmt.Println("V6ICMPTypeEchoReply", len(V6ICMPTypeEchoReply), V6ICMPTypeEchoReply) + + if !bytes.Equal(ICMPTypeEcho[0:2], []byte{8, 0}) || !bytes.Equal(ICMPTypeEcho[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEcho Type/Code or ID/Seq mismatch: %v", ICMPTypeEcho) + } + if !bytes.Equal(ICMPTypeEchoReply[0:2], []byte{0, 0}) || !bytes.Equal(ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoReply) + } + if !bytes.Equal(ICMPTypeEchoRequest[0:2], []byte{128, 0}) || !bytes.Equal(ICMPTypeEchoRequest[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEchoRequest Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoRequest) + } + if !bytes.Equal(V6ICMPTypeEchoReply[0:2], []byte{129, 0}) || !bytes.Equal(V6ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("V6ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", V6ICMPTypeEchoReply) + } +}