Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions app/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import (
"sync"
"time"

"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns"
Expand Down Expand Up @@ -83,9 +85,31 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
return nil, errors.New("unexpected query strategy ", config.QueryStrategy)
}

hosts, err := NewStaticHosts(config.StaticHosts)
if err != nil {
return nil, errors.New("failed to create hosts").Base(err)
var hosts *StaticHosts
mphLoaded := false
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
if domainMatcherPath != "" {
if f, err := os.Open(domainMatcherPath); err == nil {
defer f.Close()
if m, err := router.LoadGeoSiteMatcher(f, "HOSTS"); err == nil {
f.Seek(0, 0)
if hostIPs, err := router.LoadGeoSiteHosts(f); err == nil {
if sh, err := NewStaticHostsFromCache(m, hostIPs); err == nil {
hosts = sh
mphLoaded = true
errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for DNS hosts, size: ", sh.matchers.Size())
}
}
}
}
}

if !mphLoaded {
sh, err := NewStaticHosts(config.StaticHosts)
if err != nil {
return nil, errors.New("failed to create hosts").Base(err)
}
hosts = sh
}

var clients []*Client
Expand Down
49 changes: 48 additions & 1 deletion app/dns/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
// StaticHosts represents static domain-ip mapping in DNS server.
type StaticHosts struct {
ips [][]net.Address
matchers *strmatcher.MatcherGroup
matchers strmatcher.IndexMatcher
}

// NewStaticHosts creates a new StaticHosts instance.
Expand Down Expand Up @@ -124,3 +124,50 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) (
func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
return h.lookup(domain, option, 5)
}
func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) {
sh := &StaticHosts{
ips: make([][]net.Address, matcher.Size()+1),
matchers: matcher,
}

order := hostIPs["_ORDER"]
var offset uint32

img, ok := matcher.(*strmatcher.IndexMatcherGroup)
if !ok {
// Single matcher (e.g. only manual or only one geosite)
if len(order) > 0 {
pattern := order[0]
ips := parseIPs(hostIPs[pattern])
for i := uint32(1); i <= matcher.Size(); i++ {
sh.ips[i] = ips
}
}
return sh, nil
}

for i, m := range img.Matchers {
if i < len(order) {
pattern := order[i]
ips := parseIPs(hostIPs[pattern])
for j := uint32(1); j <= m.Size(); j++ {
sh.ips[offset+j] = ips
}
offset += m.Size()
}
}
return sh, nil
}

func parseIPs(raw []string) []net.Address {
addrs := make([]net.Address, 0, len(raw))
for _, s := range raw {
if len(s) > 1 && s[0] == '#' {
rcode, _ := strconv.Atoi(s[1:])
addrs = append(addrs, dns.RCodeError(rcode))
} else {
addrs = append(addrs, net.ParseAddress(s))
}
}
return addrs
}
56 changes: 56 additions & 0 deletions app/dns/hosts_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package dns_test

import (
"bytes"
"testing"

"github.com/google/go-cmp/cmp"
. "github.com/xtls/xray-core/app/dns"
"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/dns"
Expand Down Expand Up @@ -130,3 +132,57 @@ func TestStaticHosts(t *testing.T) {
}
}
}
func TestStaticHostsFromCache(t *testing.T) {
sites := []*router.GeoSite{
{
CountryCode: "cloudflare-dns.com",
Domain: []*router.Domain{
{Type: router.Domain_Full, Value: "example.com"},
},
},
{
CountryCode: "geosite:cn",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "baidu.cn"},
},
},
}
deps := map[string][]string{
"HOSTS": {"cloudflare-dns.com", "geosite:cn"},
}
hostIPs := map[string][]string{
"cloudflare-dns.com": {"1.1.1.1"},
"geosite:cn": {"2.2.2.2"},
"_ORDER": {"cloudflare-dns.com", "geosite:cn"},
}

var buf bytes.Buffer
err := router.SerializeGeoSiteList(sites, deps, hostIPs, &buf)
common.Must(err)

// Load matcher
m, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "HOSTS")
common.Must(err)

// Load hostIPs
f := bytes.NewReader(buf.Bytes())
hips, err := router.LoadGeoSiteHosts(f)
common.Must(err)

hosts, err := NewStaticHostsFromCache(m, hips)
common.Must(err)

{
ips, _ := hosts.Lookup("example.com", dns.IPOption{IPv4Enable: true})
if len(ips) != 1 || ips[0].String() != "1.1.1.1" {
t.Error("failed to lookup example.com from cache")
}
}

{
ips, _ := hosts.Lookup("baidu.cn", dns.IPOption{IPv4Enable: true})
if len(ips) != 1 || ips[0].String() != "2.2.2.2" {
t.Error("failed to lookup baidu.cn from cache deps")
}
}
}
71 changes: 53 additions & 18 deletions app/dns/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,27 @@ import (
"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
"github.com/xtls/xray-core/features/routing"
)

type mphMatcherWrapper struct {
m strmatcher.IndexMatcher
}

func (w *mphMatcherWrapper) Match(s string) bool {
return w.m.Match(s) != nil
}

func (w *mphMatcherWrapper) String() string {
return "mph-matcher"
}

// Server is the interface for Name Server.
type Server interface {
// Name of the Client.
Expand Down Expand Up @@ -132,29 +146,50 @@ func NewClient(
var rules []string
ruleCurr := 0
ruleIter := 0
for i, domain := range ns.PrioritizedDomain {
ns.PrioritizedDomain[i] = nil
domainRule, err := toStrMatcher(domain.Type, domain.Domain)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule")

// Check if domain matcher cache is provided via environment
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
var mphLoaded bool

if domainMatcherPath != "" && ns.Tag != "" {
f, err := filesystem.NewFileReader(domainMatcherPath)
if err == nil {
defer f.Close()
g, err := router.LoadGeoSiteMatcher(f, ns.Tag)
if err == nil {
errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for ", ns.Tag, " dns tag)")
updateDomainRule(&mphMatcherWrapper{m: g}, 0, *matcherInfos)
rules = append(rules, "[MPH Cache]")
mphLoaded = true
}
}
originalRuleIdx := ruleCurr
if ruleCurr < len(ns.OriginalRules) {
rule := ns.OriginalRules[ruleCurr]
if ruleCurr >= len(rules) {
rules = append(rules, rule.Rule)
}

if !mphLoaded {
for i, domain := range ns.PrioritizedDomain {
ns.PrioritizedDomain[i] = nil
domainRule, err := toStrMatcher(domain.Type, domain.Domain)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule")
}
ruleIter++
if ruleIter >= int(rule.Size) {
ruleIter = 0
originalRuleIdx := ruleCurr
if ruleCurr < len(ns.OriginalRules) {
rule := ns.OriginalRules[ruleCurr]
if ruleCurr >= len(rules) {
rules = append(rules, rule.Rule)
}
ruleIter++
if ruleIter >= int(rule.Size) {
ruleIter = 0
ruleCurr++
}
} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
rules = append(rules, domainRule.String())
ruleCurr++
}
} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
rules = append(rules, domainRule.String())
ruleCurr++
updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
}
updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
}
ns.PrioritizedDomain = nil
runtime.GC()
Expand Down
34 changes: 31 additions & 3 deletions app/router/condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package router

import (
"context"
"io"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -52,7 +53,34 @@ var matcherTypeMap = map[Domain_Type]strmatcher.Type{
}

type DomainMatcher struct {
matchers strmatcher.IndexMatcher
Matchers strmatcher.IndexMatcher
}

func SerializeDomainMatcher(domains []*Domain, w io.Writer) error {

g := strmatcher.NewMphMatcherGroup()
for _, d := range domains {
matcherType, f := matcherTypeMap[d.Type]
if !f {
continue
}

_, err := g.AddPattern(d.Value, matcherType)
if err != nil {
return err
}
}
g.Build()
// serialize
return g.Serialize(w)
}

func NewDomainMatcherFromBuffer(data []byte) (*strmatcher.MphMatcherGroup, error) {
matcher, err := strmatcher.NewMphMatcherGroupFromBuffer(data)
if err != nil {
return nil, err
}
return matcher, nil
}

func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
Expand All @@ -72,12 +100,12 @@ func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
}
g.Build()
return &DomainMatcher{
matchers: g,
Matchers: g,
}, nil
}

func (m *DomainMatcher) ApplyDomain(domain string) bool {
return len(m.matchers.Match(strings.ToLower(domain))) > 0
return len(m.Matchers.Match(strings.ToLower(domain))) > 0
}

// Apply implements Condition.
Expand Down
Loading