diff --git a/app/dns/dns.go b/app/dns/dns.go index 603640f1549f..c108208362ea 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -12,9 +12,11 @@ import ( "sync" "time" + "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/dns" @@ -83,9 +85,31 @@ func New(ctx context.Context, config *Config) (*DNS, error) { return nil, errors.New("unexpected query strategy ", config.QueryStrategy) } - hosts, err := NewStaticHosts(config.StaticHosts) - if err != nil { - return nil, errors.New("failed to create hosts").Base(err) + var hosts *StaticHosts + mphLoaded := false + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + if f, err := os.Open(domainMatcherPath); err == nil { + defer f.Close() + if m, err := router.LoadGeoSiteMatcher(f, "HOSTS"); err == nil { + f.Seek(0, 0) + if hostIPs, err := router.LoadGeoSiteHosts(f); err == nil { + if sh, err := NewStaticHostsFromCache(m, hostIPs); err == nil { + hosts = sh + mphLoaded = true + errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for DNS hosts, size: ", sh.matchers.Size()) + } + } + } + } + } + + if !mphLoaded { + sh, err := NewStaticHosts(config.StaticHosts) + if err != nil { + return nil, errors.New("failed to create hosts").Base(err) + } + hosts = sh } var clients []*Client diff --git a/app/dns/hosts.go b/app/dns/hosts.go index 7c9cdee37b71..fab08d54c9e1 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -14,7 +14,7 @@ import ( // StaticHosts represents static domain-ip mapping in DNS server. type StaticHosts struct { ips [][]net.Address - matchers *strmatcher.MatcherGroup + matchers strmatcher.IndexMatcher } // NewStaticHosts creates a new StaticHosts instance. 
@@ -124,3 +124,50 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ( func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) { return h.lookup(domain, option, 5) } +func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) { + sh := &StaticHosts{ + ips: make([][]net.Address, matcher.Size()+1), + matchers: matcher, + } + + order := hostIPs["_ORDER"] + var offset uint32 + + img, ok := matcher.(*strmatcher.IndexMatcherGroup) + if !ok { + // Single matcher (e.g. only manual or only one geosite) + if len(order) > 0 { + pattern := order[0] + ips := parseIPs(hostIPs[pattern]) + for i := uint32(1); i <= matcher.Size(); i++ { + sh.ips[i] = ips + } + } + return sh, nil + } + + for i, m := range img.Matchers { + if i < len(order) { + pattern := order[i] + ips := parseIPs(hostIPs[pattern]) + for j := uint32(1); j <= m.Size(); j++ { + sh.ips[offset+j] = ips + } + offset += m.Size() + } + } + return sh, nil +} + +func parseIPs(raw []string) []net.Address { + addrs := make([]net.Address, 0, len(raw)) + for _, s := range raw { + if len(s) > 1 && s[0] == '#' { + rcode, _ := strconv.Atoi(s[1:]) + addrs = append(addrs, dns.RCodeError(rcode)) + } else { + addrs = append(addrs, net.ParseAddress(s)) + } + } + return addrs +} diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 2c7f8b69ce76..2b9c24d8422b 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -1,10 +1,12 @@ package dns_test import ( + "bytes" "testing" "github.com/google/go-cmp/cmp" . 
"github.com/xtls/xray-core/app/dns" + "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/features/dns" @@ -130,3 +132,57 @@ func TestStaticHosts(t *testing.T) { } } } +func TestStaticHostsFromCache(t *testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "cloudflare-dns.com", + Domain: []*router.Domain{ + {Type: router.Domain_Full, Value: "example.com"}, + }, + }, + { + CountryCode: "geosite:cn", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "baidu.cn"}, + }, + }, + } + deps := map[string][]string{ + "HOSTS": {"cloudflare-dns.com", "geosite:cn"}, + } + hostIPs := map[string][]string{ + "cloudflare-dns.com": {"1.1.1.1"}, + "geosite:cn": {"2.2.2.2"}, + "_ORDER": {"cloudflare-dns.com", "geosite:cn"}, + } + + var buf bytes.Buffer + err := router.SerializeGeoSiteList(sites, deps, hostIPs, &buf) + common.Must(err) + + // Load matcher + m, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "HOSTS") + common.Must(err) + + // Load hostIPs + f := bytes.NewReader(buf.Bytes()) + hips, err := router.LoadGeoSiteHosts(f) + common.Must(err) + + hosts, err := NewStaticHostsFromCache(m, hips) + common.Must(err) + + { + ips, _ := hosts.Lookup("example.com", dns.IPOption{IPv4Enable: true}) + if len(ips) != 1 || ips[0].String() != "1.1.1.1" { + t.Error("failed to lookup example.com from cache") + } + } + + { + ips, _ := hosts.Lookup("baidu.cn", dns.IPOption{IPv4Enable: true}) + if len(ips) != 1 || ips[0].String() != "2.2.2.2" { + t.Error("failed to lookup baidu.cn from cache deps") + } + } +} diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index dbab5e8aba21..00d435b59218 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -10,6 +10,8 @@ import ( "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform" + 
"github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/core" @@ -17,6 +19,18 @@ import ( "github.com/xtls/xray-core/features/routing" ) +type mphMatcherWrapper struct { + m strmatcher.IndexMatcher +} + +func (w *mphMatcherWrapper) Match(s string) bool { + return w.m.Match(s) != nil +} + +func (w *mphMatcherWrapper) String() string { + return "mph-matcher" +} + // Server is the interface for Name Server. type Server interface { // Name of the Client. @@ -132,29 +146,50 @@ func NewClient( var rules []string ruleCurr := 0 ruleIter := 0 - for i, domain := range ns.PrioritizedDomain { - ns.PrioritizedDomain[i] = nil - domainRule, err := toStrMatcher(domain.Type, domain.Domain) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]") - domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule") + + // Check if domain matcher cache is provided via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + var mphLoaded bool + + if domainMatcherPath != "" && ns.Tag != "" { + f, err := filesystem.NewFileReader(domainMatcherPath) + if err == nil { + defer f.Close() + g, err := router.LoadGeoSiteMatcher(f, ns.Tag) + if err == nil { + errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for ", ns.Tag, " dns tag)") + updateDomainRule(&mphMatcherWrapper{m: g}, 0, *matcherInfos) + rules = append(rules, "[MPH Cache]") + mphLoaded = true + } } - originalRuleIdx := ruleCurr - if ruleCurr < len(ns.OriginalRules) { - rule := ns.OriginalRules[ruleCurr] - if ruleCurr >= len(rules) { - rules = append(rules, rule.Rule) + } + + if !mphLoaded { + for i, domain := range ns.PrioritizedDomain { + ns.PrioritizedDomain[i] = nil + domainRule, err := toStrMatcher(domain.Type, 
domain.Domain) + if err != nil { + errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]") + domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule") } - ruleIter++ - if ruleIter >= int(rule.Size) { - ruleIter = 0 + originalRuleIdx := ruleCurr + if ruleCurr < len(ns.OriginalRules) { + rule := ns.OriginalRules[ruleCurr] + if ruleCurr >= len(rules) { + rules = append(rules, rule.Rule) + } + ruleIter++ + if ruleIter >= int(rule.Size) { + ruleIter = 0 + ruleCurr++ + } + } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) + rules = append(rules, domainRule.String()) ruleCurr++ } - } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) - rules = append(rules, domainRule.String()) - ruleCurr++ + updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) } - updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) } ns.PrioritizedDomain = nil runtime.GC() diff --git a/app/router/condition.go b/app/router/condition.go index 873f121d6461..54af816544c6 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -2,6 +2,7 @@ package router import ( "context" + "io" "os" "path/filepath" "regexp" @@ -52,7 +53,34 @@ var matcherTypeMap = map[Domain_Type]strmatcher.Type{ } type DomainMatcher struct { - matchers strmatcher.IndexMatcher + Matchers strmatcher.IndexMatcher +} + +func SerializeDomainMatcher(domains []*Domain, w io.Writer) error { + + g := strmatcher.NewMphMatcherGroup() + for _, d := range domains { + matcherType, f := matcherTypeMap[d.Type] + if !f { + continue + } + + _, err := g.AddPattern(d.Value, matcherType) + if err != nil { + return err + } + } + g.Build() + // serialize + return g.Serialize(w) +} + +func NewDomainMatcherFromBuffer(data []byte) (*strmatcher.MphMatcherGroup, error) { + matcher, err 
:= strmatcher.NewMphMatcherGroupFromBuffer(data) + if err != nil { + return nil, err + } + return matcher, nil } func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { @@ -72,12 +100,12 @@ func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { } g.Build() return &DomainMatcher{ - matchers: g, + Matchers: g, }, nil } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return len(m.matchers.Match(strings.ToLower(domain))) > 0 + return len(m.Matchers.Match(strings.ToLower(domain))) > 0 } // Apply implements Condition. diff --git a/app/router/condition_serialize_test.go b/app/router/condition_serialize_test.go new file mode 100644 index 000000000000..4c6ff46467f6 --- /dev/null +++ b/app/router/condition_serialize_test.go @@ -0,0 +1,167 @@ +package router_test + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/xtls/xray-core/app/router" + "github.com/xtls/xray-core/common/platform/filesystem" +) + +func TestDomainMatcherSerialization(t *testing.T) { + + domains := []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + {Type: router.Domain_Domain, Value: "v2ray.com"}, + {Type: router.Domain_Full, Value: "full.example.com"}, + } + + var buf bytes.Buffer + if err := router.SerializeDomainMatcher(domains, &buf); err != nil { + t.Fatalf("Serialize failed: %v", err) + } + + matcher, err := router.NewDomainMatcherFromBuffer(buf.Bytes()) + if err != nil { + t.Fatalf("Deserialize failed: %v", err) + } + + dMatcher := &router.DomainMatcher{ + Matchers: matcher, + } + testCases := []struct { + Input string + Match bool + }{ + {"google.com", true}, + {"maps.google.com", true}, + {"v2ray.com", true}, + {"full.example.com", true}, + + {"example.com", false}, + } + + for _, tc := range testCases { + if res := dMatcher.ApplyDomain(tc.Input); res != tc.Match { + t.Errorf("Match(%s) = %v, want %v", tc.Input, res, tc.Match) + } + } +} + +func TestGeoSiteSerialization(t 
*testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "CN", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "baidu.cn"}, + {Type: router.Domain_Domain, Value: "qq.com"}, + }, + }, + { + CountryCode: "US", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + {Type: router.Domain_Domain, Value: "facebook.com"}, + }, + }, + } + + var buf bytes.Buffer + if err := router.SerializeGeoSiteList(sites, nil, nil, &buf); err != nil { + t.Fatalf("SerializeGeoSiteList failed: %v", err) + } + + tmp := t.TempDir() + path := filepath.Join(tmp, "matcher.cache") + + f, err := os.Create(path) + require.NoError(t, err) + _, err = f.Write(buf.Bytes()) + require.NoError(t, err) + f.Close() + + f, err = os.Open(path) + require.NoError(t, err) + defer f.Close() + + require.NoError(t, err) + data, _ := filesystem.ReadFile(path) + + // cn + gp, err := router.LoadGeoSiteMatcher(bytes.NewReader(data), "CN") + if err != nil { + t.Fatalf("LoadGeoSiteMatcher(CN) failed: %v", err) + } + + cnMatcher := &router.DomainMatcher{ + Matchers: gp, + } + + if !cnMatcher.ApplyDomain("baidu.cn") { + t.Error("CN matcher should match baidu.cn") + } + if cnMatcher.ApplyDomain("google.com") { + t.Error("CN matcher should NOT match google.com") + } + + // us + gp, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "US") + if err != nil { + t.Fatalf("LoadGeoSiteMatcher(US) failed: %v", err) + } + + usMatcher := &router.DomainMatcher{ + Matchers: gp, + } + if !usMatcher.ApplyDomain("google.com") { + t.Error("US matcher should match google.com") + } + if usMatcher.ApplyDomain("baidu.cn") { + t.Error("US matcher should NOT match baidu.cn") + } + + // unknown + _, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "unknown") + if err == nil { + t.Error("LoadGeoSiteMatcher(unknown) should fail") + } +} +func TestGeoSiteSerializationWithDeps(t *testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "geosite:cn", + Domain: []*router.Domain{ + {Type: 
router.Domain_Domain, Value: "baidu.cn"}, + }, + }, + { + CountryCode: "geosite:google@cn", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.cn"}, + }, + }, + { + CountryCode: "rule-1", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + }, + }, + } + deps := map[string][]string{ + "rule-1": {"geosite:cn", "geosite:google@cn"}, + } + + var buf bytes.Buffer + err := router.SerializeGeoSiteList(sites, deps, nil, &buf) + require.NoError(t, err) + + matcher, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "rule-1") + require.NoError(t, err) + + require.True(t, matcher.Match("google.com") != nil) + require.True(t, matcher.Match("baidu.cn") != nil) + require.True(t, matcher.Match("google.cn") != nil) +} diff --git a/app/router/config.go b/app/router/config.go index c41e6cfc9625..4288f2af302e 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/platform" + "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/routing" ) @@ -105,11 +107,25 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if len(rr.Domain) > 0 { - matcher, err := NewMphMatcherGroup(rr.Domain) - if err != nil { - return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) + var matcher *DomainMatcher + var err error + // Check if domain matcher cache is provided via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + + if domainMatcherPath != "" { + matcher, err = GetDomainMatcherWithRuleTag(domainMatcherPath, rr.RuleTag) + if err != nil { + return nil, errors.New("failed to build domain condition from cached MphDomainMatcher").Base(err) + } + errors.LogDebug(context.Background(), "MphDomainMatcher loaded from cache for ", 
rr.RuleTag, " rule tag)") + + } else { + matcher, err = NewMphMatcherGroup(rr.Domain) + if err != nil { + return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) + } + errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") } - errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") conds.Add(matcher) rr.Domain = nil runtime.GC() @@ -172,3 +188,20 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch return nil, errors.New("unrecognized balancer type") } } + +func GetDomainMatcherWithRuleTag(domainMatcherPath string, ruleTag string) (*DomainMatcher, error) { + f, err := filesystem.NewFileReader(domainMatcherPath) + if err != nil { + return nil, errors.New("failed to load file: ", domainMatcherPath).Base(err) + } + defer f.Close() + + g, err := LoadGeoSiteMatcher(f, ruleTag) + if err != nil { + return nil, errors.New("failed to load file:", domainMatcherPath).Base(err) + } + return &DomainMatcher{ + Matchers: g, + }, nil + +} diff --git a/app/router/geosite_compact.go b/app/router/geosite_compact.go new file mode 100644 index 000000000000..50fee83fce06 --- /dev/null +++ b/app/router/geosite_compact.go @@ -0,0 +1,100 @@ +package router + +import ( + "encoding/gob" + "errors" + "io" + "runtime" + + "github.com/xtls/xray-core/common/strmatcher" +) + +type geoSiteListGob struct { + Sites map[string][]byte + Deps map[string][]string + Hosts map[string][]string +} + +func SerializeGeoSiteList(sites []*GeoSite, deps map[string][]string, hosts map[string][]string, w io.Writer) error { + data := geoSiteListGob{ + Sites: make(map[string][]byte), + Deps: deps, + Hosts: hosts, + } + + for _, site := range sites { + if site == nil { + continue + } + var buf bytesWriter + if err := SerializeDomainMatcher(site.Domain, &buf); err != nil { + return err + } + data.Sites[site.CountryCode] = buf.Bytes() + } + + return 
gob.NewEncoder(w).Encode(data) +} + +type bytesWriter struct { + data []byte +} + +func (w *bytesWriter) Write(p []byte) (n int, err error) { + w.data = append(w.data, p...) + return len(p), nil +} + +func (w *bytesWriter) Bytes() []byte { + return w.data +} + +func LoadGeoSiteMatcher(r io.Reader, countryCode string) (strmatcher.IndexMatcher, error) { + var data geoSiteListGob + if err := gob.NewDecoder(r).Decode(&data); err != nil { + return nil, err + } + + return loadWithDeps(&data, countryCode, make(map[string]bool)) +} + +func loadWithDeps(data *geoSiteListGob, code string, visited map[string]bool) (strmatcher.IndexMatcher, error) { + if visited[code] { + return nil, errors.New("cyclic dependency") + } + visited[code] = true + + var matchers []strmatcher.IndexMatcher + + if siteData, ok := data.Sites[code]; ok { + m, err := NewDomainMatcherFromBuffer(siteData) + if err == nil { + matchers = append(matchers, m) + } + } + + if deps, ok := data.Deps[code]; ok { + for _, dep := range deps { + m, err := loadWithDeps(data, dep, visited) + if err == nil { + matchers = append(matchers, m) + } + } + } + + if len(matchers) == 0 { + return nil, errors.New("matcher not found for: " + code) + } + if len(matchers) == 1 { + return matchers[0], nil + } + runtime.GC() + return &strmatcher.IndexMatcherGroup{Matchers: matchers}, nil +} +func LoadGeoSiteHosts(r io.Reader) (map[string][]string, error) { + var data geoSiteListGob + if err := gob.NewDecoder(r).Decode(&data); err != nil { + return nil, err + } + return data.Hosts, nil +} diff --git a/common/platform/platform.go b/common/platform/platform.go index 80e62874d6e4..6446873be7a9 100644 --- a/common/platform/platform.go +++ b/common/platform/platform.go @@ -24,6 +24,8 @@ const ( XUDPBaseKey = "xray.xudp.basekey" TunFdKey = "xray.tun.fd" + + MphCachePath = "xray.mph.cache" ) type EnvFlag struct { diff --git a/common/strmatcher/ac_automaton_matcher.go b/common/strmatcher/ac_automaton_matcher.go index 
24be9dac9193..7844333d1b87 100644 --- a/common/strmatcher/ac_automaton_matcher.go +++ b/common/strmatcher/ac_automaton_matcher.go @@ -7,8 +7,8 @@ import ( const validCharCount = 53 type MatchType struct { - matchType Type - exist bool + Type Type + Exist bool } const ( @@ -17,23 +17,23 @@ const ( ) type Edge struct { - edgeType bool - nextNode int + Type bool + NextNode int } type ACAutomaton struct { - trie [][validCharCount]Edge - fail []int - exists []MatchType - count int + Trie [][validCharCount]Edge + Fail []int + Exists []MatchType + Count int } func newNode() [validCharCount]Edge { var s [validCharCount]Edge for i := range s { s[i] = Edge{ - edgeType: FailEdge, - nextNode: 0, + Type: FailEdge, + NextNode: 0, } } return s @@ -123,11 +123,11 @@ var char2Index = []int{ func NewACAutomaton() *ACAutomaton { ac := new(ACAutomaton) - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) return ac } @@ -136,53 +136,53 @@ func (ac *ACAutomaton) Add(domain string, t Type) { node := 0 for i := len(domain) - 1; i >= 0; i-- { idx := char2Index[domain[i]] - if ac.trie[node][idx].nextNode == 0 { - ac.count++ - if len(ac.trie) < ac.count+1 { - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + if ac.Trie[node][idx].NextNode == 0 { + ac.Count++ + if len(ac.Trie) < ac.Count+1 { + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) } - ac.trie[node][idx] = Edge{ - edgeType: TrieEdge, - nextNode: ac.count, + ac.Trie[node][idx] = Edge{ + Type: TrieEdge, + NextNode: ac.Count, } } - node = ac.trie[node][idx].nextNode + node = ac.Trie[node][idx].NextNode } - 
ac.exists[node] = MatchType{ - matchType: t, - exist: true, + ac.Exists[node] = MatchType{ + Type: t, + Exist: true, } switch t { case Domain: - ac.exists[node] = MatchType{ - matchType: Full, - exist: true, + ac.Exists[node] = MatchType{ + Type: Full, + Exist: true, } idx := char2Index['.'] - if ac.trie[node][idx].nextNode == 0 { - ac.count++ - if len(ac.trie) < ac.count+1 { - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + if ac.Trie[node][idx].NextNode == 0 { + ac.Count++ + if len(ac.Trie) < ac.Count+1 { + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) } - ac.trie[node][idx] = Edge{ - edgeType: TrieEdge, - nextNode: ac.count, + ac.Trie[node][idx] = Edge{ + Type: TrieEdge, + NextNode: ac.Count, } } - node = ac.trie[node][idx].nextNode - ac.exists[node] = MatchType{ - matchType: t, - exist: true, + node = ac.Trie[node][idx].NextNode + ac.Exists[node] = MatchType{ + Type: t, + Exist: true, } default: break @@ -192,8 +192,8 @@ func (ac *ACAutomaton) Add(domain string, t Type) { func (ac *ACAutomaton) Build() { queue := list.New() for i := 0; i < validCharCount; i++ { - if ac.trie[0][i].nextNode != 0 { - queue.PushBack(ac.trie[0][i]) + if ac.Trie[0][i].NextNode != 0 { + queue.PushBack(ac.Trie[0][i]) } } for { @@ -201,16 +201,16 @@ func (ac *ACAutomaton) Build() { if front == nil { break } else { - node := front.Value.(Edge).nextNode + node := front.Value.(Edge).NextNode queue.Remove(front) for i := 0; i < validCharCount; i++ { - if ac.trie[node][i].nextNode != 0 { - ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode - queue.PushBack(ac.trie[node][i]) + if ac.Trie[node][i].NextNode != 0 { + ac.Fail[ac.Trie[node][i].NextNode] = ac.Trie[ac.Fail[node]][i].NextNode + queue.PushBack(ac.Trie[node][i]) } else { - ac.trie[node][i] = Edge{ - edgeType: FailEdge, - 
nextNode: ac.trie[ac.fail[node]][i].nextNode, + ac.Trie[node][i] = Edge{ + Type: FailEdge, + NextNode: ac.Trie[ac.Fail[node]][i].NextNode, } } } @@ -230,9 +230,9 @@ func (ac *ACAutomaton) Match(s string) bool { return false } idx := char2Index[chr] - fullMatch = fullMatch && ac.trie[node][idx].edgeType - node = ac.trie[node][idx].nextNode - switch ac.exists[node].matchType { + fullMatch = fullMatch && ac.Trie[node][idx].Type + node = ac.Trie[node][idx].NextNode + switch ac.Exists[node].Type { case Substr: return true case Domain: @@ -243,5 +243,5 @@ func (ac *ACAutomaton) Match(s string) bool { break } } - return fullMatch && ac.exists[node].exist + return fullMatch && ac.Exists[node].Exist } diff --git a/common/strmatcher/matchers.go b/common/strmatcher/matchers.go index b5ab09c4cb9f..915927db8991 100644 --- a/common/strmatcher/matchers.go +++ b/common/strmatcher/matchers.go @@ -39,14 +39,18 @@ func (m domainMatcher) String() string { return "domain:" + string(m) } -type regexMatcher struct { - pattern *regexp.Regexp +type RegexMatcher struct { + Pattern string + reg *regexp.Regexp } -func (m *regexMatcher) Match(s string) bool { - return m.pattern.MatchString(s) +func (m *RegexMatcher) Match(s string) bool { + if m.reg == nil { + m.reg = regexp.MustCompile(m.Pattern) + } + return m.reg.MatchString(s) } -func (m *regexMatcher) String() string { - return "regexp:" + m.pattern.String() +func (m *RegexMatcher) String() string { + return "regexp:" + m.Pattern } diff --git a/common/strmatcher/mph_matcher.go b/common/strmatcher/mph_matcher.go index 3c10cb4920bd..ff3dea65c5a7 100644 --- a/common/strmatcher/mph_matcher.go +++ b/common/strmatcher/mph_matcher.go @@ -25,40 +25,40 @@ func RollingHash(s string) uint32 { // 2. `substr` patterns are matched by ac automaton; // 3. `regex` patterns are matched with the regex library. 
type MphMatcherGroup struct { - ac *ACAutomaton - otherMatchers []matcherEntry - rules []string - level0 []uint32 - level0Mask int - level1 []uint32 - level1Mask int - count uint32 - ruleMap *map[string]uint32 + Ac *ACAutomaton + OtherMatchers []MatcherEntry + Rules []string + Level0 []uint32 + Level0Mask int + Level1 []uint32 + Level1Mask int + Count uint32 + RuleMap *map[string]uint32 } func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { h := RollingHash(pattern) switch t { case Domain: - (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') + (*g.RuleMap)["."+pattern] = h*PrimeRK + uint32('.') fallthrough case Full: - (*g.ruleMap)[pattern] = h + (*g.RuleMap)[pattern] = h default: } } func NewMphMatcherGroup() *MphMatcherGroup { return &MphMatcherGroup{ - ac: nil, - otherMatchers: nil, - rules: nil, - level0: nil, - level0Mask: 0, - level1: nil, - level1Mask: 0, - count: 1, - ruleMap: &map[string]uint32{}, + Ac: nil, + OtherMatchers: nil, + Rules: nil, + Level0: nil, + Level0Mask: 0, + Level1: nil, + Level1Mask: 0, + Count: 1, + RuleMap: &map[string]uint32{}, } } @@ -66,10 +66,10 @@ func NewMphMatcherGroup() *MphMatcherGroup { func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { switch t { case Substr: - if g.ac == nil { - g.ac = NewACAutomaton() + if g.Ac == nil { + g.Ac = NewACAutomaton() } - g.ac.Add(pattern, t) + g.Ac.Add(pattern, t) case Full, Domain: pattern = strings.ToLower(pattern) g.AddFullOrDomainPattern(pattern, t) @@ -78,39 +78,39 @@ func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { if err != nil { return 0, err } - g.otherMatchers = append(g.otherMatchers, matcherEntry{ - m: ®exMatcher{pattern: r}, - id: g.count, + g.OtherMatchers = append(g.OtherMatchers, MatcherEntry{ + M: &RegexMatcher{Pattern: pattern, reg: r}, + Id: g.Count, }) default: panic("Unknown type") } - return g.count, nil + return g.Count, nil } // Build builds a minimal perfect hash table and ac automaton 
from insert rules func (g *MphMatcherGroup) Build() { - if g.ac != nil { - g.ac.Build() + if g.Ac != nil { + g.Ac.Build() } - keyLen := len(*g.ruleMap) + keyLen := len(*g.RuleMap) if keyLen == 0 { keyLen = 1 - (*g.ruleMap)["empty___"] = RollingHash("empty___") + (*g.RuleMap)["empty___"] = RollingHash("empty___") } - g.level0 = make([]uint32, nextPow2(keyLen/4)) - g.level0Mask = len(g.level0) - 1 - g.level1 = make([]uint32, nextPow2(keyLen)) - g.level1Mask = len(g.level1) - 1 - sparseBuckets := make([][]int, len(g.level0)) + g.Level0 = make([]uint32, nextPow2(keyLen/4)) + g.Level0Mask = len(g.Level0) - 1 + g.Level1 = make([]uint32, nextPow2(keyLen)) + g.Level1Mask = len(g.Level1) - 1 + sparseBuckets := make([][]int, len(g.Level0)) var ruleIdx int - for rule, hash := range *g.ruleMap { - n := int(hash) & g.level0Mask - g.rules = append(g.rules, rule) + for rule, hash := range *g.RuleMap { + n := int(hash) & g.Level0Mask + g.Rules = append(g.Rules, rule) sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) ruleIdx++ } - g.ruleMap = nil + g.RuleMap = nil var buckets []indexBucket for n, vals := range sparseBuckets { if len(vals) > 0 { @@ -119,7 +119,7 @@ func (g *MphMatcherGroup) Build() { } sort.Sort(bySize(buckets)) - occ := make([]bool, len(g.level1)) + occ := make([]bool, len(g.Level1)) var tmpOcc []int for _, bucket := range buckets { seed := uint32(0) @@ -127,7 +127,7 @@ func (g *MphMatcherGroup) Build() { findSeed := true tmpOcc = tmpOcc[:0] for _, i := range bucket.vals { - n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask + n := int(strhashFallback(unsafe.Pointer(&g.Rules[i]), uintptr(seed))) & g.Level1Mask if occ[n] { for _, n := range tmpOcc { occ[n] = false @@ -138,10 +138,10 @@ func (g *MphMatcherGroup) Build() { } occ[n] = true tmpOcc = append(tmpOcc, n) - g.level1[n] = uint32(i) + g.Level1[n] = uint32(i) } if findSeed { - g.level0[bucket.n] = seed + g.Level0[bucket.n] = seed break } } @@ -159,11 +159,11 @@ func 
nextPow2(v int) int { // Lookup searches for s in t and returns its index and whether it was found. func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { - i0 := int(h) & g.level0Mask - seed := g.level0[i0] - i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask - n := g.level1[i1] - return s == g.rules[int(n)] + i0 := int(h) & g.Level0Mask + seed := g.Level0[i0] + i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.Level1Mask + n := g.Level1[i1] + return s == g.Rules[int(n)] } // Match implements IndexMatcher.Match. @@ -183,13 +183,13 @@ func (g *MphMatcherGroup) Match(pattern string) []uint32 { result = append(result, 1) return result } - if g.ac != nil && g.ac.Match(pattern) { + if g.Ac != nil && g.Ac.Match(pattern) { result = append(result, 1) return result } - for _, e := range g.otherMatchers { - if e.m.Match(pattern) { - result = append(result, e.id) + for _, e := range g.OtherMatchers { + if e.M.Match(pattern) { + result = append(result, e.Id) return result } } @@ -302,3 +302,7 @@ func readUnaligned64(p unsafe.Pointer) uint64 { q := (*[8]byte)(p) return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 } + +func (g *MphMatcherGroup) Size() uint32 { + return g.Count +} diff --git a/common/strmatcher/mph_matcher_compact.go b/common/strmatcher/mph_matcher_compact.go new file mode 100644 index 000000000000..a40b9f568a14 --- /dev/null +++ b/common/strmatcher/mph_matcher_compact.go @@ -0,0 +1,47 @@ +package strmatcher + +import ( + "bytes" + "encoding/gob" + "io" +) + +func init() { + gob.Register(&RegexMatcher{}) + gob.Register(fullMatcher("")) + gob.Register(substrMatcher("")) + gob.Register(domainMatcher("")) +} + +func (g *MphMatcherGroup) Serialize(w io.Writer) error { + data := MphMatcherGroup{ + Ac: g.Ac, + OtherMatchers: g.OtherMatchers, + Rules: g.Rules, + Level0: g.Level0, + Level0Mask: g.Level0Mask, + Level1: 
g.Level1, + Level1Mask: g.Level1Mask, + Count: g.Count, + } + return gob.NewEncoder(w).Encode(data) +} + +func NewMphMatcherGroupFromBuffer(data []byte) (*MphMatcherGroup, error) { + var gData MphMatcherGroup + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&gData); err != nil { + return nil, err + } + + g := NewMphMatcherGroup() + g.Ac = gData.Ac + g.OtherMatchers = gData.OtherMatchers + g.Rules = gData.Rules + g.Level0 = gData.Level0 + g.Level0Mask = gData.Level0Mask + g.Level1 = gData.Level1 + g.Level1Mask = gData.Level1Mask + g.Count = gData.Count + + return g, nil +} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 4035acc3b2f9..89e7dae68053 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -41,8 +41,9 @@ func (t Type) New(pattern string) (Matcher, error) { if err != nil { return nil, err } - return ®exMatcher{ - pattern: r, + return &RegexMatcher{ + Pattern: pattern, + reg: r, }, nil default: return nil, errors.New("unk type") @@ -53,11 +54,13 @@ func (t Type) New(pattern string) (Matcher, error) { type IndexMatcher interface { // Match returns the index of a matcher that matches the input. It returns empty array if no such matcher exists. Match(input string) []uint32 + // Size returns the number of matchers in the group. + Size() uint32 } -type matcherEntry struct { - m Matcher - id uint32 +type MatcherEntry struct { + M Matcher + Id uint32 } // MatcherGroup is an implementation of IndexMatcher. @@ -66,7 +69,7 @@ type MatcherGroup struct { count uint32 fullMatcher FullMatcherGroup domainMatcher DomainMatcherGroup - otherMatchers []matcherEntry + otherMatchers []MatcherEntry } // Add adds a new Matcher into the MatcherGroup, and returns its index. The index will never be 0. 
@@ -80,9 +83,9 @@ func (g *MatcherGroup) Add(m Matcher) uint32 { case domainMatcher: g.domainMatcher.addMatcher(tm, c) default: - g.otherMatchers = append(g.otherMatchers, matcherEntry{ - m: m, - id: c, + g.otherMatchers = append(g.otherMatchers, MatcherEntry{ + M: m, + Id: c, }) } @@ -95,8 +98,8 @@ func (g *MatcherGroup) Match(pattern string) []uint32 { result = append(result, g.fullMatcher.Match(pattern)...) result = append(result, g.domainMatcher.Match(pattern)...) for _, e := range g.otherMatchers { - if e.m.Match(pattern) { - result = append(result, e.id) + if e.M.Match(pattern) { + result = append(result, e.Id) } } return result @@ -106,3 +109,33 @@ func (g *MatcherGroup) Match(pattern string) []uint32 { func (g *MatcherGroup) Size() uint32 { return g.count } + +type IndexMatcherGroup struct { + Matchers []IndexMatcher +} + +func (g *IndexMatcherGroup) Match(input string) []uint32 { + var offset uint32 + for _, m := range g.Matchers { + if res := m.Match(input); len(res) > 0 { + if offset == 0 { + return res + } + shifted := make([]uint32, len(res)) + for i, id := range res { + shifted[i] = id + offset + } + return shifted + } + offset += m.Size() + } + return nil +} + +func (g *IndexMatcherGroup) Size() uint32 { + var count uint32 + for _, m := range g.Matchers { + count += m.Size() + } + return count +} diff --git a/infra/conf/router.go b/infra/conf/router.go index 1e1f6b803758..bc3246108855 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/serial" "google.golang.org/protobuf/proto" @@ -204,6 +205,13 @@ func loadIP(file, code string) ([]*router.CIDR, error) { } func loadSite(file, code string) ([]*router.Domain, error) { + + // Check if domain matcher cache is provided 
via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + return []*router.Domain{{}}, nil + } + bs, err := loadFile(file, code) if err != nil { return nil, err diff --git a/infra/conf/xray.go b/infra/conf/xray.go index 9e5c1394ab0d..39a1f76365b6 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -1,16 +1,21 @@ package conf import ( + "bytes" "context" "encoding/json" + "os" "path/filepath" + "sort" "strings" "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/proxyman" + "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/app/stats" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform" "github.com/xtls/xray-core/common/serial" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/transport/internet" @@ -607,6 +612,187 @@ func (c *Config) Build() (*core.Config, error) { return config, nil } +func (c *Config) BuildMPHCache(customMatcherFilePath *string) error { + var geosite []*router.GeoSite + deps := make(map[string][]string) + uniqueGeosites := make(map[string]bool) + uniqueTags := make(map[string]bool) + matcherFilePath := platform.GetAssetLocation("matcher.cache") + + if customMatcherFilePath != nil { + matcherFilePath = *customMatcherFilePath + } + + processGeosite := func(dStr string) bool { + prefix := "" + if strings.HasPrefix(dStr, "geosite:") { + prefix = "geosite:" + } else if strings.HasPrefix(dStr, "ext-domain:") { + prefix = "ext-domain:" + } + if prefix == "" { + return false + } + key := strings.ToLower(dStr) + country := strings.ToUpper(dStr[len(prefix):]) + if !uniqueGeosites[country] { + ds, err := loadGeositeWithAttr("geosite.dat", country) + if err == nil { + uniqueGeosites[country] = true + geosite = append(geosite, &router.GeoSite{CountryCode: key, Domain: ds}) + } + } + return true + } + + processDomains := func(tag 
string, rawDomains []string) { + var manualDomains []*router.Domain + var dDeps []string + for _, dStr := range rawDomains { + if processGeosite(dStr) { + dDeps = append(dDeps, strings.ToLower(dStr)) + } else { + ds, err := parseDomainRule(dStr) + if err == nil { + manualDomains = append(manualDomains, ds...) + } + } + } + if len(manualDomains) > 0 { + if !uniqueTags[tag] { + uniqueTags[tag] = true + geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualDomains}) + } + } + if len(dDeps) > 0 { + deps[tag] = append(deps[tag], dDeps...) + } + } + + // process rules + if c.RouterConfig != nil { + for _, rawRule := range c.RouterConfig.RuleList { + type SimpleRule struct { + RuleTag string `json:"ruleTag"` + Domain *StringList `json:"domain"` + Domains *StringList `json:"domains"` + } + var sr SimpleRule + json.Unmarshal(rawRule, &sr) + if sr.RuleTag == "" { + continue + } + var allDomains []string + if sr.Domain != nil { + allDomains = append(allDomains, *sr.Domain...) + } + if sr.Domains != nil { + allDomains = append(allDomains, *sr.Domains...) 
+ } + processDomains(sr.RuleTag, allDomains) + } + } + + // process dns servers + if c.DNSConfig != nil { + for _, ns := range c.DNSConfig.Servers { + if ns.Tag == "" { + continue + } + processDomains(ns.Tag, ns.Domains) + } + } + + var hostIPs map[string][]string + if c.DNSConfig != nil && c.DNSConfig.Hosts != nil { + hostIPs = make(map[string][]string) + var hostDeps []string + var hostPatterns []string + + // use raw map to avoid expanding geosites + var domains []string + for domain := range c.DNSConfig.Hosts.Hosts { + domains = append(domains, domain) + } + sort.Strings(domains) + + manualHostGroups := make(map[string][]*router.Domain) + manualHostIPs := make(map[string][]string) + manualHostNames := make(map[string]string) + + for _, domain := range domains { + ha := c.DNSConfig.Hosts.Hosts[domain] + m := getHostMapping(ha) + + var ips []string + if m.ProxiedDomain != "" { + ips = append(ips, m.ProxiedDomain) + } else { + for _, ip := range m.Ip { + ips = append(ips, net.IPAddress(ip).String()) + } + } + + if processGeosite(domain) { + tag := strings.ToLower(domain) + hostDeps = append(hostDeps, tag) + hostIPs[tag] = ips + hostPatterns = append(hostPatterns, domain) + } else { + // build manual domains by their destination IPs + sort.Strings(ips) + ipKey := strings.Join(ips, ",") + ds, err := parseDomainRule(domain) + if err == nil { + manualHostGroups[ipKey] = append(manualHostGroups[ipKey], ds...) 
+ manualHostIPs[ipKey] = ips + if _, ok := manualHostNames[ipKey]; !ok { + manualHostNames[ipKey] = domain + } + } + } + } + + // create manual host groups + var ipKeys []string + for k := range manualHostGroups { + ipKeys = append(ipKeys, k) + } + sort.Strings(ipKeys) + + for _, k := range ipKeys { + tag := manualHostNames[k] + geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualHostGroups[k]}) + hostDeps = append(hostDeps, tag) + hostIPs[tag] = manualHostIPs[k] + + // record the tag; the special _ORDER key later links each matcher to its IP addresses by position + hostPatterns = append(hostPatterns, tag) + } + + deps["HOSTS"] = hostDeps + hostIPs["_ORDER"] = hostPatterns + } + + f, err := os.Create(matcherFilePath) + if err != nil { + return err + } + defer f.Close() + + var buf bytes.Buffer + + if err := router.SerializeGeoSiteList(geosite, deps, hostIPs, &buf); err != nil { + return err + } + + if _, err := f.Write(buf.Bytes()); err != nil { + return err + } + + return nil +} + + // Convert string to Address. func ParseSendThough(Addr *string) *Address { var addr Address diff --git a/main/commands/all/buildmphcache.go b/main/commands/all/buildmphcache.go new file mode 100644 index 000000000000..6c45205ec663 --- /dev/null +++ b/main/commands/all/buildmphcache.go @@ -0,0 +1,52 @@ +package all + +import ( + "os" + + "github.com/xtls/xray-core/common/platform" + "github.com/xtls/xray-core/infra/conf/serial" + "github.com/xtls/xray-core/main/commands/base" +) + +var cmdBuildMphCache = &base.Command{ + UsageLine: `{{.Exec}} buildMphCache [-c config.json] [-o domain.cache]`, + Short: `Build domain matcher cache`, + Long: ` +Build domain matcher cache from a configuration file. 
+ +Example: {{.Exec}} buildMphCache -c config.json -o domain.cache +`, +} + +func init() { + cmdBuildMphCache.Run = executeBuildMphCache +} + +var ( + configPath = cmdBuildMphCache.Flag.String("c", "config.json", "Config file path") + outputPath = cmdBuildMphCache.Flag.String("o", "domain.cache", "Output cache file path") +) + +func executeBuildMphCache(cmd *base.Command, args []string) { + cf, err := os.Open(*configPath) + if err != nil { + base.Fatalf("failed to open config file: %v", err) + } + defer cf.Close() + + // prevent using existing cache + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + os.Setenv("XRAY_MPH_CACHE", "") + defer os.Setenv("XRAY_MPH_CACHE", domainMatcherPath) + } + + config, err := serial.DecodeJSONConfig(cf) + if err != nil { + base.Fatalf("failed to decode config file: %v", err) + } + + if err := config.BuildMPHCache(outputPath); err != nil { + base.Fatalf("failed to build MPH cache: %v", err) + } +} diff --git a/main/commands/all/commands.go b/main/commands/all/commands.go index fba3a4b8bb43..20b92bb01b97 100644 --- a/main/commands/all/commands.go +++ b/main/commands/all/commands.go @@ -19,5 +19,6 @@ func init() { cmdMLDSA65, cmdMLKEM768, cmdVLESSEnc, + cmdBuildMphCache, ) }