diff --git a/.gitignore b/.gitignore index 22469b88f8d0..ac744f79a3a2 100644 --- a/.gitignore +++ b/.gitignore @@ -14,10 +14,18 @@ # Dependency directories (remove the comment below to include it) # vendor/ +# macOS specific files *.DS_Store -.idea + +# IDE specific files +.idea/ +.vscode/ + +# Archive files *.zip *.tar.gz + +# Binaries xray xray_softfloat mockgen @@ -26,8 +34,13 @@ vprotogen errorgen !common/errors/errorgen/ *.dat -.vscode + +# Build assets /build_assets # Output from dlv test **/debug.* + +# Certificates +*.crt +*.key diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index 8a23e0c302a4..6cdb8ee78fd2 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -357,7 +357,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, clientIP net errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) } else { ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse { + if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index 2b0b7b66e1e3..6ce5809b669f 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -300,7 +300,7 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP ne errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) } else { ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse { + if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index d84974b03d85..49854312f1ed 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -325,7 +325,7 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) } else { ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse { + if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 79df14edd0c2..23803efa24e1 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -282,7 +282,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name) } else { ips, ttl, err := s.findIPsForDomain(fqdn, option) - if err == nil || err == dns_feature.ErrEmptyResponse { + if err == nil || err == dns_feature.ErrEmptyResponse || dns_feature.RCodeFromError(err) == 3 { errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips) log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, ttl, err diff --git a/common/protocol/tls/cert/.gitignore b/common/protocol/tls/cert/.gitignore deleted file mode 100644 index b8987f0ba089..000000000000 --- a/common/protocol/tls/cert/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.crt -*.key \ No newline at end of file diff --git a/core/core.go b/core/core.go index 345f9dafd822..4e51211850da 100644 --- a/core/core.go +++ b/core/core.go @@ -19,7 +19,7 @@ import ( var ( Version_x byte = 25 Version_y byte = 3 - Version_z byte = 6 + Version_z byte = 31 ) var ( diff --git a/go.mod b/go.mod index 57bf8b993791..f62fca594cc3 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/golang/mock v1.7.0-rc.1 github.com/google/go-cmp v0.7.0 github.com/gorilla/websocket v1.5.3 - github.com/miekg/dns v1.1.64 + github.com/miekg/dns v1.1.65 github.com/pelletier/go-toml v1.9.5 github.com/pires/go-proxyproto v0.8.0 github.com/quic-go/quic-go v0.50.1 @@ -22,12 +22,12 @@ require ( github.com/vishvananda/netlink v1.3.0 github.com/xtls/reality v0.0.0-20240712055506-48f0b2d5ed6d go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.36.0 - golang.org/x/net v0.37.0 - golang.org/x/sync v0.12.0 - golang.org/x/sys v0.31.0 + golang.org/x/crypto v0.37.0 + golang.org/x/net v0.38.0 + golang.org/x/sync v0.13.0 + golang.org/x/sys v0.32.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 - google.golang.org/grpc v1.71.0 + google.golang.org/grpc v1.71.1 google.golang.org/protobuf v1.36.6 gvisor.dev/gvisor v0.0.0-20240320123526-dc6abceb7ff0 h12.io/socks v1.0.3 @@ -51,7 +51,7 @@ require ( go.uber.org/mock v0.5.0 // indirect golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc // indirect golang.org/x/mod v0.23.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.24.0 // indirect golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.30.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index aacd27cadbcd..7873d77959b4 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,8 @@ github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0N github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/miekg/dns v1.1.64 h1:wuZgD9wwCE6XMT05UU/mlSko71eRSXEAm2EbjQXLKnQ= -github.com/miekg/dns v1.1.64/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= +github.com/miekg/dns v1.1.65 h1:0+tIPHzUW0GCge7IiK3guGP57VAw7hoPDfApjkMD1Fc= +github.com/miekg/dns v1.1.65/go.mod h1:Dzw9769uoKVaLuODMDZz9M6ynFU6Em65csPuoi8G0ck= github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= @@ -97,8 +97,8 @@ go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBs go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc h1:O9NuF4s+E/PvMIy+9IUZB9znFwUIXEWSstNjek6VpVg= golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= @@ -107,12 +107,12 @@ golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -121,14 +121,14 @@ golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -145,8 +145,8 @@ golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uI golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/infra/conf/cfgcommon/duration/duration.go b/infra/conf/cfgcommon/duration/duration.go index aed8e613e153..f1bbd4d73d97 100644 --- a/infra/conf/cfgcommon/duration/duration.go +++ b/infra/conf/cfgcommon/duration/duration.go @@ -8,11 +8,13 @@ import ( type Duration int64 +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON func (d *Duration) MarshalJSON() ([]byte, error) { dr := time.Duration(*d) return json.Marshal(dr.String()) } +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (d *Duration) UnmarshalJSON(b []byte) error { var v interface{} if err := json.Unmarshal(b, &v); err != nil { diff --git a/infra/conf/common.go b/infra/conf/common.go index fa48edea283b..ab3cfba79ac4 100644 --- a/infra/conf/common.go +++ b/infra/conf/common.go @@ -23,6 +23,7 @@ func (v StringList) Len() int { return len(v) } +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (v *StringList) UnmarshalJSON(data []byte) error { var strarray []string if err := json.Unmarshal(data, &strarray); err == nil { @@ -43,10 +44,12 @@ type Address struct { net.Address } -func (v Address) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (v *Address) MarshalJSON() ([]byte, error) { return json.Marshal(v.Address.String()) } +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (v *Address) UnmarshalJSON(data []byte) error { var rawStr string if err := json.Unmarshal(data, &rawStr); err != nil { @@ -81,6 +84,7 @@ func (v Network) Build() net.Network { type NetworkList []Network +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (v *NetworkList) UnmarshalJSON(data []byte) error { var strarray []Network if err := json.Unmarshal(data, &strarray); err == nil { @@ -169,6 +173,19 @@ func (v *PortRange) Build() *net.PortRange { } } +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (v *PortRange) MarshalJSON() ([]byte, error) { + return json.Marshal(v.String()) +} + +func (port *PortRange) String() string { + if port.From == port.To { + return strconv.Itoa(int(port.From)) + } else { + return fmt.Sprintf("%d-%d", port.From, port.To) + } +} + // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (v *PortRange) UnmarshalJSON(data []byte) error { port, err := parseIntPort(data) @@ -203,20 +220,21 @@ func (list *PortList) Build() *net.PortList { return portList } -func (v PortList) MarshalJSON() ([]byte, error) { - return json.Marshal(v.String()) +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (v *PortList) MarshalJSON() ([]byte, error) { + portStr := v.String() + port, err := strconv.Atoi(portStr) + if err == nil { + return json.Marshal(port) + } else { + return json.Marshal(portStr) + } } func (v PortList) String() string { ports := []string{} for _, port := range v.Range { - if port.From == port.To { - p := strconv.Itoa(int(port.From)) - ports = append(ports, p) - } else { - p := fmt.Sprintf("%d-%d", port.From, port.To) - ports = append(ports, p) - } + ports = append(ports, port.String()) } return strings.Join(ports, ",") } @@ -277,7 +295,8 @@ type Int32Range struct { To int32 } -func (v Int32Range) MarshalJSON() ([]byte, error) { +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (v *Int32Range) MarshalJSON() ([]byte, error) { return json.Marshal(v.String()) } @@ -289,6 +308,7 @@ func (v Int32Range) String() string { } } +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (v *Int32Range) UnmarshalJSON(data []byte) error { defer v.ensureOrder() var str string diff --git a/infra/conf/dns.go b/infra/conf/dns.go index 607cbf07a4c6..7baeda873e39 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -25,6 +25,7 @@ type NameServerConfig struct { TimeoutMs uint64 `json:"timeoutMs"` } +// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (c *NameServerConfig) UnmarshalJSON(data []byte) error { var address Address if err := json.Unmarshal(data, &address); err == nil { @@ -163,6 +164,18 @@ type HostAddress struct { addrs []*Address } +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (h *HostAddress) MarshalJSON() ([]byte, error) { + if (h.addr != nil) != (h.addrs != nil) { + if h.addr != nil { + return json.Marshal(h.addr) + } else if h.addrs != nil { + return json.Marshal(h.addrs) + } + } + return nil, errors.New("unexpected config state") +} + // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (h *HostAddress) UnmarshalJSON(data []byte) error { addr := new(Address) @@ -208,6 +221,11 @@ func getHostMapping(ha *HostAddress) *dns.Config_HostMapping { } } +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (m *HostsWrapper) MarshalJSON() ([]byte, error) { + return json.Marshal(m.Hosts) +} + // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (m *HostsWrapper) UnmarshalJSON(data []byte) error { hosts := make(map[string]*HostAddress) diff --git a/infra/conf/fakedns.go b/infra/conf/fakedns.go index 862403668cbf..3aa20115cae1 100644 --- a/infra/conf/fakedns.go +++ b/infra/conf/fakedns.go @@ -20,6 +20,18 @@ type FakeDNSConfig struct { pools []*FakeDNSPoolElementConfig } +// MarshalJSON implements encoding/json.Marshaler.MarshalJSON +func (f *FakeDNSConfig) MarshalJSON() ([]byte, error) { + if (f.pool != nil) != (f.pools != nil) { + if f.pool != nil { + return json.Marshal(f.pool) + } else if f.pools != nil { + return json.Marshal(f.pools) + } + } + return nil, errors.New("unexpected config state") +} + // UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON func (f *FakeDNSConfig) UnmarshalJSON(data []byte) error { var pool FakeDNSPoolElementConfig diff --git a/infra/conf/wireguard.go b/infra/conf/wireguard.go index 9952101a0a5a..34ce7215b1d8 100644 --- a/infra/conf/wireguard.go +++ b/infra/conf/wireguard.go @@ -67,7 +67,7 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { var err error config.SecretKey, err = ParseWireGuardKey(c.SecretKey) if err != nil { - return nil, err + return nil, errors.New("invalid WireGuard secret key: %w", err) } if c.Address == nil { @@ -126,6 +126,10 @@ func (c *WireGuardConfig) Build() (proto.Message, error) { func ParseWireGuardKey(str string) (string, error) { var err error + if str == "" { + return "", errors.New("key must not be empty") + } + if len(str)%2 == 0 { _, err = hex.DecodeString(str) if err == nil { diff --git a/infra/conf/xray.go b/infra/conf/xray.go index a9cc88bcf98c..4b084b5638fd 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -241,14 +241,14 @@ func (c *InboundDetourConfig) Build() (*core.InboundHandlerConfig, error) { } rawConfig, err := inboundConfigLoader.LoadWithID(settings, c.Protocol) if err != nil { - return nil, errors.New("failed to load inbound detour config.").Base(err) + return nil, errors.New("failed to load inbound detour config for protocol ", c.Protocol).Base(err) } if dokodemoConfig, ok := rawConfig.(*DokodemoConfig); ok { receiverSettings.ReceiveOriginalDestination = dokodemoConfig.Redirect } ts, err := rawConfig.(Buildable).Build() if err != nil { - return nil, err + return nil, errors.New("failed to build inbound handler for protocol ", c.Protocol).Base(err) } return &core.InboundHandlerConfig{ @@ -303,7 +303,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.StreamSetting != nil { ss, err := c.StreamSetting.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build stream settings for outbound detour").Base(err) } senderSettings.StreamSettings = ss } @@ -311,7 +311,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.ProxySettings != nil { ps, err := c.ProxySettings.Build() if err != nil { - return nil, errors.New("invalid outbound detour proxy settings.").Base(err) + return nil, errors.New("invalid outbound detour proxy settings").Base(err) } if ps.TransportLayerProxy { if senderSettings.StreamSettings != nil { @@ -331,7 +331,7 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { if c.MuxSettings != nil { ms, err := c.MuxSettings.Build() if err != nil { - return nil, errors.New("failed to build Mux config.").Base(err) + return nil, errors.New("failed to build Mux config").Base(err) } senderSettings.MultiplexSettings = ms } @@ -342,11 +342,11 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) { } rawConfig, err := outboundConfigLoader.LoadWithID(settings, c.Protocol) if err != nil { - return nil, errors.New("failed to parse to outbound detour config.").Base(err) + return nil, errors.New("failed to load outbound detour config for protocol ", c.Protocol).Base(err) } ts, err := rawConfig.(Buildable).Build() if err != nil { - return nil, err + return nil, errors.New("failed to build outbound handler for protocol ", c.Protocol).Base(err) } return &core.OutboundHandlerConfig{ @@ -490,7 +490,7 @@ func (c *Config) Override(o *Config, fn string) { // Build implements Buildable. func (c *Config) Build() (*core.Config, error) { if err := PostProcessConfigureFile(c); err != nil { - return nil, err + return nil, errors.New("failed to post-process configuration file").Base(err) } config := &core.Config{ @@ -504,21 +504,21 @@ func (c *Config) Build() (*core.Config, error) { if c.API != nil { apiConf, err := c.API.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build API configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(apiConf)) } if c.Metrics != nil { metricsConf, err := c.Metrics.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build metrics configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(metricsConf)) } if c.Stats != nil { statsConf, err := c.Stats.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build stats configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(statsConf)) } @@ -536,7 +536,7 @@ func (c *Config) Build() (*core.Config, error) { if c.RouterConfig != nil { routerConfig, err := c.RouterConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build routing configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(routerConfig)) } @@ -544,7 +544,7 @@ func (c *Config) Build() (*core.Config, error) { if c.DNSConfig != nil { dnsApp, err := c.DNSConfig.Build() if err != nil { - return nil, errors.New("failed to parse DNS config").Base(err) + return nil, errors.New("failed to build DNS configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(dnsApp)) } @@ -552,7 +552,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Policy != nil { pc, err := c.Policy.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build policy configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(pc)) } @@ -560,7 +560,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Reverse != nil { r, err := c.Reverse.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build reverse configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -568,7 +568,7 @@ func (c *Config) Build() (*core.Config, error) { if c.FakeDNS != nil { r, err := c.FakeDNS.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build fake DNS configuration").Base(err) } config.App = append([]*serial.TypedMessage{serial.ToTypedMessage(r)}, config.App...) } @@ -576,7 +576,7 @@ func (c *Config) Build() (*core.Config, error) { if c.Observatory != nil { r, err := c.Observatory.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build observatory configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -584,7 +584,7 @@ func (c *Config) Build() (*core.Config, error) { if c.BurstObservatory != nil { r, err := c.BurstObservatory.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build burst observatory configuration").Base(err) } config.App = append(config.App, serial.ToTypedMessage(r)) } @@ -602,7 +602,7 @@ func (c *Config) Build() (*core.Config, error) { for _, rawInboundConfig := range inbounds { ic, err := rawInboundConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build inbound config with tag ", rawInboundConfig.Tag).Base(err) } config.Inbound = append(config.Inbound, ic) } @@ -616,7 +616,7 @@ func (c *Config) Build() (*core.Config, error) { for _, rawOutboundConfig := range outbounds { oc, err := rawOutboundConfig.Build() if err != nil { - return nil, err + return nil, errors.New("failed to build outbound config with tag ", rawOutboundConfig.Tag).Base(err) } config.Outbound = append(config.Outbound, oc) } diff --git a/proxy/wireguard/gvisortun/tun.go b/proxy/wireguard/gvisortun/tun.go index 65677c483e62..2f9aa33cda4f 100644 --- a/proxy/wireguard/gvisortun/tun.go +++ b/proxy/wireguard/gvisortun/tun.go @@ -10,6 +10,7 @@ import ( "fmt" "net/netip" "os" + "sync" "syscall" "golang.zx2c4.com/wireguard/tun" @@ -33,6 +34,7 @@ type netTun struct { incomingPacket chan *buffer.View mtu int hasV4, hasV6 bool + closeOnce sync.Once } type Net netTun @@ -174,18 +176,15 @@ func (tun *netTun) Flush() error { // Close implements tun.Device func (tun *netTun) Close() error { - tun.stack.RemoveNIC(1) + tun.closeOnce.Do(func() { + tun.stack.RemoveNIC(1) - if tun.events != nil { close(tun.events) - } - tun.ep.Close() + tun.ep.Close() - if tun.incomingPacket != nil { close(tun.incomingPacket) - } - + }) return nil } diff --git a/proxy/wireguard/server_test.go b/proxy/wireguard/server_test.go new file mode 100644 index 000000000000..057b508edb74 --- /dev/null +++ b/proxy/wireguard/server_test.go @@ -0,0 +1,52 @@ +package wireguard_test + +import ( + "context" + "github.com/stretchr/testify/assert" + "runtime/debug" + "testing" + + "github.com/xtls/xray-core/core" + "github.com/xtls/xray-core/proxy/wireguard" +) + +// TestWireGuardServerInitializationError verifies that an error during TUN initialization +// (triggered by an empty SecretKey) in the WireGuard server does not cause a panic and returns an error instead. +func TestWireGuardServerInitializationError(t *testing.T) { + // Create a minimal core instance with default features + config := &core.Config{} + instance, err := core.New(config) + if err != nil { + t.Fatalf("Failed to create core instance: %v", err) + } + // Set the Xray instance in the context + ctx := context.WithValue(context.Background(), core.XrayKey(1), instance) + + // Define the server configuration with an empty SecretKey to trigger error + conf := &wireguard.DeviceConfig{ + IsClient: false, + Endpoint: []string{"10.0.0.1/32"}, + Mtu: 1420, + SecretKey: "", // Empty SecretKey to trigger error + Peers: []*wireguard.PeerConfig{ + { + PublicKey: "some_public_key", + AllowedIps: []string{"10.0.0.2/32"}, + }, + }, + } + + // Use defer to catch any panic and fail the test explicitly + defer func() { + if r := recover(); r != nil { + t.Errorf("TUN initialization panicked: %v", r) + debug.PrintStack() + } + }() + + // Attempt to initialize the WireGuard server + _, err = wireguard.NewServer(ctx, conf) + + // Check that an error is returned + assert.ErrorContains(t, err, "failed to set private_key: hex string does not fit the slice") +} diff --git a/transport/internet/sockopt_darwin.go b/transport/internet/sockopt_darwin.go index 79e2133a0b19..f684de98ef05 100644 --- a/transport/internet/sockopt_darwin.go +++ b/transport/internet/sockopt_darwin.go @@ -1,7 +1,7 @@ package internet import ( - network "net" + gonet "net" "os" "syscall" "unsafe" @@ -108,14 +108,6 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf return err } } - if config.Interface != "" { - InterfaceIndex := getInterfaceIndexByName(config.Interface) - if InterfaceIndex != 0 { - if err := unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, InterfaceIndex); err != nil { - return errors.New("failed to set Interface").Base(err) - } - } - } if config.TcpKeepAliveIdle > 0 || config.TcpKeepAliveInterval > 0 { if config.TcpKeepAliveIdle > 0 { @@ -138,6 +130,23 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf } } + if config.Interface != "" { + iface, err := gonet.InterfaceByName(config.Interface) + + if err != nil { + return errors.New("failed to get interface ", config.Interface).Base(err) + } + if network == "tcp6" || network == "udp6" { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil { + return errors.New("failed to set IPV6_BOUND_IF").Base(err) + } + } else { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil { + return errors.New("failed to set IP_BOUND_IF").Base(err) + } + } + } + return nil } @@ -152,14 +161,6 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) return err } } - if config.Interface != "" { - InterfaceIndex := getInterfaceIndexByName(config.Interface) - if InterfaceIndex != 0 { - if err := unix.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_BOUND_IF, InterfaceIndex); err != nil { - return errors.New("failed to set Interface").Base(err) - } - } - } if config.TcpKeepAliveIdle > 0 || config.TcpKeepAliveInterval > 0 { if config.TcpKeepAliveIdle > 0 { @@ -182,6 +183,29 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) } } + if config.Interface != "" { + iface, err := gonet.InterfaceByName(config.Interface) + + if err != nil { + return errors.New("failed to get interface ", config.Interface).Base(err) + } + if network == "tcp6" || network == "udp6" { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil { + return errors.New("failed to set IPV6_BOUND_IF").Base(err) + } + } else { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil { + return errors.New("failed to set IP_BOUND_IF").Base(err) + } + } + } + + if config.V6Only { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1); err != nil { + return errors.New("failed to set IPV6_V6ONLY").Base(err) + } + } + return nil } @@ -224,24 +248,3 @@ func setReusePort(fd uintptr) error { } return nil } -func getInterfaceIndexByName(name string) int { - ifaces, err := network.Interfaces() - if err == nil { - for _, iface := range ifaces { - if (iface.Flags&network.FlagUp == network.FlagUp) && (iface.Flags&network.FlagLoopback != network.FlagLoopback) { - addrs, _ := iface.Addrs() - for _, addr := range addrs { - if ipnet, ok := addr.(*network.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - if iface.Name == name { - return iface.Index - } - } - } - } - } - - } - } - return 0 -} diff --git a/transport/internet/sockopt_windows.go b/transport/internet/sockopt_windows.go index fa45011d2f1d..cbd3b41e2fa3 100644 --- a/transport/internet/sockopt_windows.go +++ b/transport/internet/sockopt_windows.go @@ -13,6 +13,9 @@ const ( TCP_FASTOPEN = 15 IP_UNICAST_IF = 31 IPV6_UNICAST_IF = 31 + IP_MULTICAST_IF = 9 + IPV6_MULTICAST_IF = 9 + IPV6_V6ONLY = 27 ) func setTFO(fd syscall.Handle, tfo int) error { @@ -41,10 +44,16 @@ func applyOutboundSocketOptions(network string, address string, fd uintptr, conf if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, IP_UNICAST_IF, int(idx)); err != nil { return errors.New("failed to set IP_UNICAST_IF").Base(err) } + if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IP, IP_MULTICAST_IF, int(idx)); err != nil { + return errors.New("failed to set IP_MULTICAST_IF").Base(err) + } } else { if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IPV6, IPV6_UNICAST_IF, inf.Index); err != nil { return errors.New("failed to set IPV6_UNICAST_IF").Base(err) } + if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IPV6, IPV6_MULTICAST_IF, inf.Index); err != nil { + return errors.New("failed to set IPV6_MULTICAST_IF").Base(err) + } } } @@ -82,6 +91,12 @@ func applyInboundSocketOptions(network string, fd uintptr, config *SocketConfig) } } + if config.V6Only { + if err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.IPPROTO_IPV6, IPV6_V6ONLY, 1); err != nil { + return errors.New("failed to set IPV6_V6ONLY").Base(err) + } + } + return nil } diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index 0b3c2f10e7be..63365099cb20 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -59,7 +59,17 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne Port: 0, } } - packetConn, err := ListenSystemPacket(ctx, srcAddr, sockopt) + var lc net.ListenConfig + lc.Control = func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if sockopt != nil { + if err := applyOutboundSocketOptions(network, "", fd, sockopt); err != nil { + errors.LogInfo(ctx, err, "failed to apply socket options") + } + } + }) + } + packetConn, err := lc.ListenPacket(ctx, srcAddr.Network(), srcAddr.String()) if err != nil { return nil, err } @@ -67,23 +77,6 @@ func (d *DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest ne if err != nil { return nil, err } - if sockopt != nil { - sys, err := packetConn.(*net.UDPConn).SyscallConn() - if err != nil { - return nil, err - } - sys.Control(func(fd uintptr) { - var network string - if destAddr.IP.To4() != nil { - network = "udp4" - } else { - network = "udp6" - } - if err := applyOutboundSocketOptions(network, dest.NetAddr(), fd, sockopt); err != nil { - errors.LogInfo(ctx, err, "failed to apply socket options") - } - }) - } return &PacketConnWrapper{ Conn: packetConn, Dest: destAddr,