Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions app/dns/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns"
"sort"
)

// StaticHosts represents static domain-ip mapping in DNS server.
Expand Down Expand Up @@ -59,16 +60,14 @@ func filterIP(ips []net.Address, option dns.IPOption) []net.Address {
}

func (h *StaticHosts) lookupInternal(domain string) []net.Address {
ips := make([]net.Address, 0)
found := false
for _, id := range h.matchers.Match(domain) {
ips = append(ips, h.ips[id]...)
found = true
}
if !found {
MatchSlice := h.matchers.Match(domain)
sort.Slice(MatchSlice, func(i, j int) bool {
return MatchSlice[i] < MatchSlice[j]
})
if len(MatchSlice) == 0 {
return nil
}
return ips
return h.ips[MatchSlice[0]]
}

func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) []net.Address {
Expand Down
72 changes: 56 additions & 16 deletions infra/conf/dns.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package conf

import (
"bytes"
"encoding/json"
"sort"
"strings"

"github.com/xtls/xray-core/app/dns"
Expand Down Expand Up @@ -192,7 +192,8 @@ func (h *HostAddress) UnmarshalJSON(data []byte) error {
}

type HostsWrapper struct {
Hosts map[string]*HostAddress
Domains []string
Hosts map[string]*HostAddress
}

func getHostMapping(ha *HostAddress) *dns.Config_HostMapping {
Expand Down Expand Up @@ -223,31 +224,70 @@ func getHostMapping(ha *HostAddress) *dns.Config_HostMapping {

// MarshalJSON implements encoding/json.Marshaler.MarshalJSON
func (m *HostsWrapper) MarshalJSON() ([]byte, error) {
return json.Marshal(m.Hosts)
var buf bytes.Buffer
buf.WriteString("{")
for i, domain := range m.Domains {
if i > 0 {
buf.WriteString(",")
}
keyBytes, err := json.Marshal(domain)
if err != nil {
return nil, err
}
buf.Write(keyBytes)
buf.WriteString(":")
valueBytes, err := json.Marshal(m.Hosts[domain])
if err != nil {
return nil, err
}
buf.Write(valueBytes)
}
buf.WriteString("}")
return buf.Bytes(), nil
}

// UnmarshalJSON implements encoding/json.Unmarshaler.UnmarshalJSON
func (m *HostsWrapper) UnmarshalJSON(data []byte) error {
hosts := make(map[string]*HostAddress)
err := json.Unmarshal(data, &hosts)
if err == nil {
m.Hosts = hosts
return nil
m.Hosts = make(map[string]*HostAddress)
m.Domains = []string{}

var tempMap map[string]*HostAddress
if err := json.Unmarshal(data, &tempMap); err != nil {
return err
}

dec := json.NewDecoder(bytes.NewReader(data))
t, err := dec.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("unexpected token")
}
for dec.More() {
key, err := dec.Token()
if err != nil {
return err
}
domain, ok := key.(string)
if !ok {
return errors.New("invalid key")
}
m.Domains = append(m.Domains, domain)
var ha *HostAddress
if err := dec.Decode(&ha); err != nil {
return err
}
m.Hosts[domain] = ha
}
return errors.New("invalid DNS hosts").Base(err)
return nil
}

// Build implements Buildable
func (m *HostsWrapper) Build() ([]*dns.Config_HostMapping, error) {
mappings := make([]*dns.Config_HostMapping, 0, 20)

domains := make([]string, 0, len(m.Hosts))
for domain := range m.Hosts {
domains = append(domains, domain)
}
sort.Strings(domains)

for _, domain := range domains {
for _, domain := range m.Domains {
switch {
case strings.HasPrefix(domain, "domain:"):
domainName := domain[7:]
Expand Down