Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions app/dns/cache_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,29 @@ const (
)

type CacheController struct {
name string
disableCache bool
serveStale bool
serveExpiredTTL int32

ips map[string]*record
dirtyips map[string]*record

sync.RWMutex
ips map[string]*record
dirtyips map[string]*record
pub *pubsub.Service
cacheCleanup *task.Periodic
name string
disableCache bool
highWatermark int
requestGroup singleflight.Group
}

func NewCacheController(name string, disableCache bool) *CacheController {
func NewCacheController(name string, disableCache bool, serveStale bool, serveExpiredTTL uint32) *CacheController {
c := &CacheController{
name: name,
disableCache: disableCache,
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: name,
disableCache: disableCache,
serveStale: serveStale,
serveExpiredTTL: -int32(serveExpiredTTL),
ips: make(map[string]*record),
pub: pubsub.NewService(),
}

c.cacheCleanup = &task.Periodic{
Expand Down Expand Up @@ -78,6 +84,10 @@ func (c *CacheController) collectExpiredKeys() ([]string, error) {
}

now := time.Now()
if c.serveStale && c.serveExpiredTTL != 0 {
now = now.Add(time.Duration(c.serveExpiredTTL) * time.Second)
}

expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate

for domain, rec := range c.ips {
Expand Down Expand Up @@ -105,6 +115,10 @@ func (c *CacheController) writeAndShrink(expiredKeys []string) {
}

now := time.Now()
if c.serveStale && c.serveExpiredTTL != 0 {
now = now.Add(time.Duration(c.serveExpiredTTL) * time.Second)
}

for _, domain := range expiredKeys {
rec := c.ips[domain]
if rec == nil {
Expand Down Expand Up @@ -280,15 +294,17 @@ func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
c.Unlock()

if pubRecord != nil {
_, _ /*ttl*/, err := pubRecord.getIPs()
if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) {
_, ttl, err := pubRecord.getIPs()
if ttl > 0 && !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+pubSuffix, pubRecord)
}
}

errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)

common.Must(c.cacheCleanup.Start())
if !c.serveStale || c.serveExpiredTTL != 0 {
common.Must(c.cacheCleanup.Start())
}
}

func (c *CacheController) findRecords(domain string) *record {
Expand Down
180 changes: 112 additions & 68 deletions app/dns/config.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions app/dns/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ message NameServer {
string tag = 9;
uint64 timeoutMs = 10;
bool disableCache = 11;
bool serveStale = 15;
optional uint32 serveExpiredTTL = 16;
bool finalQuery = 12;
repeated xray.app.router.GeoIP unexpected_geoip = 13;
bool actUnprior = 14;
Expand Down Expand Up @@ -80,6 +82,8 @@ message Config {

// DisableCache disables DNS cache
bool disableCache = 8;
bool serveStale = 12;
uint32 serveExpiredTTL = 13;

QueryStrategy query_strategy = 9;

Expand Down
7 changes: 6 additions & 1 deletion app/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
}

disableCache := config.DisableCache || ns.DisableCache
serveStale := config.ServeStale || ns.ServeStale
serveExpiredTTL := config.ServeExpiredTTL
if ns.ServeExpiredTTL != nil {
serveExpiredTTL = *ns.ServeExpiredTTL
}

var tag = defaultTag
if len(ns.Tag) > 0 {
Expand All @@ -128,7 +133,7 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
return nil, errors.New("no QueryStrategy available for ", ns.Address)
}

client, err := NewClient(ctx, ns, myClientIP, disableCache, tag, clientIPOption, &matcherInfos, updateDomain)
client, err := NewClient(ctx, ns, myClientIP, disableCache, serveStale, serveExpiredTTL, tag, clientIPOption, &matcherInfos, updateDomain)
if err != nil {
return nil, errors.New("failed to create client").Base(err)
}
Expand Down
13 changes: 5 additions & 8 deletions app/dns/dnscommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dns
import (
"context"
"encoding/binary"
"math"
"strings"
"time"

Expand All @@ -13,6 +14,7 @@ import (
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/core"
dns_feature "github.com/xtls/xray-core/features/dns"

"golang.org/x/net/dns/dnsmessage"
)

Expand All @@ -39,19 +41,14 @@ type IPRecord struct {
RawHeader *dnsmessage.Header
}

func (r *IPRecord) getIPs() ([]net.IP, uint32, error) {
func (r *IPRecord) getIPs() ([]net.IP, int32, error) {
if r == nil {
return nil, 0, errRecordNotFound
}

untilExpire := time.Until(r.Expire).Seconds()
if untilExpire <= 0 {
return nil, 0, errRecordNotFound
}
ttl := int32(math.Ceil(untilExpire))

ttl := uint32(untilExpire) + 1
if ttl == 1 {
r.Expire = time.Now().Add(time.Second) // To ensure that two consecutive requests get the same result
}
if r.RCode != dnsmessage.RCodeSuccess {
return nil, ttl, dns_feature.RCodeError(r.RCode)
}
Expand Down
22 changes: 11 additions & 11 deletions app/dns/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Client struct {
}

// NewServer creates a name server object according to the network destination url.
func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, disableCache bool, clientIP net.IP) (Server, error) {
func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dispatcher, disableCache bool, serveStale bool, serveExpiredTTL uint32, clientIP net.IP) (Server, error) {
if address := dest.Address; address.Family().IsDomain() {
u, err := url.Parse(address.Domain())
if err != nil {
Expand All @@ -51,19 +51,19 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
case strings.EqualFold(u.String(), "localhost"):
return NewLocalNameServer(), nil
case strings.EqualFold(u.Scheme, "https"): // DNS-over-HTTPS Remote mode
return NewDoHNameServer(u, dispatcher, false, disableCache, clientIP), nil
return NewDoHNameServer(u, dispatcher, false, disableCache, serveStale, serveExpiredTTL, clientIP), nil
case strings.EqualFold(u.Scheme, "h2c"): // DNS-over-HTTPS h2c Remote mode
return NewDoHNameServer(u, dispatcher, true, disableCache, clientIP), nil
return NewDoHNameServer(u, dispatcher, true, disableCache, serveStale, serveExpiredTTL, clientIP), nil
case strings.EqualFold(u.Scheme, "https+local"): // DNS-over-HTTPS Local mode
return NewDoHNameServer(u, nil, false, disableCache, clientIP), nil
return NewDoHNameServer(u, nil, false, disableCache, serveStale, serveExpiredTTL, clientIP), nil
case strings.EqualFold(u.Scheme, "h2c+local"): // DNS-over-HTTPS h2c Local mode
return NewDoHNameServer(u, nil, true, disableCache, clientIP), nil
return NewDoHNameServer(u, nil, true, disableCache, serveStale, serveExpiredTTL, clientIP), nil
case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
return NewQUICNameServer(u, disableCache, clientIP)
return NewQUICNameServer(u, disableCache, serveStale, serveExpiredTTL, clientIP)
case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
return NewTCPNameServer(u, dispatcher, disableCache, clientIP)
return NewTCPNameServer(u, dispatcher, disableCache, serveStale, serveExpiredTTL, clientIP)
case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
return NewTCPLocalNameServer(u, disableCache, clientIP)
return NewTCPLocalNameServer(u, disableCache, serveStale, serveExpiredTTL, clientIP)
case strings.EqualFold(u.String(), "fakedns"):
var fd dns.FakeDNSEngine
err = core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {
Expand All @@ -79,7 +79,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis
dest.Network = net.Network_UDP
}
if dest.Network == net.Network_UDP { // UDP classic DNS mode
return NewClassicNameServer(dest, dispatcher, disableCache, clientIP), nil
return NewClassicNameServer(dest, dispatcher, disableCache, serveStale, serveExpiredTTL, clientIP), nil
}
return nil, errors.New("No available name server could be created from ", dest).AtWarning()
}
Expand All @@ -89,7 +89,7 @@ func NewClient(
ctx context.Context,
ns *NameServer,
clientIP net.IP,
disableCache bool,
disableCache bool, serveStale bool, serveExpiredTTL uint32,
tag string,
ipOption dns.IPOption,
matcherInfos *[]*DomainMatcherInfo,
Expand All @@ -99,7 +99,7 @@ func NewClient(

err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
// Create a new server for each client for now
server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, clientIP)
server, err := NewServer(ctx, ns.Address.AsDestination(), dispatcher, disableCache, serveStale, serveExpiredTTL, clientIP)
if err != nil {
return errors.New("failed to create nameserver").Base(err).AtWarning()
}
Expand Down
38 changes: 31 additions & 7 deletions app/dns/nameserver_cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,17 @@ func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.
if rec := cache.findRecords(fqdn); rec != nil {
ips, ttl, err := merge(option, rec.A, rec.AAAA)
if !go_errors.Is(err, errRecordNotFound) {
// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
return ips, ttl, err
if ttl > 0 {
// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
return ips, uint32(ttl), err
}
if cache.serveStale && (cache.serveExpiredTTL == 0 || cache.serveExpiredTTL < ttl) {
// errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheOptimiste, Elapsed: 0, Error: err})
go pull(ctx, s, fqdn, option)
return ips, 1, err
}
}
}
} else {
Expand All @@ -39,8 +47,15 @@ func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.
return fetch(ctx, s, fqdn, option)
}

func pull(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) {
nctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 8*time.Second)
defer cancel()

fetch(nctx, s, fqdn, option)
}

func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) {
key := fqdn + "f"
key := fqdn
switch {
case option.IPv4Enable && option.IPv6Enable:
key = key + "46"
Expand Down Expand Up @@ -99,13 +114,22 @@ func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IP
}

ips, ttl, err := merge(option, rec4, rec6, errs...)
var rTTL uint32
if ttl > 0 {
rTTL = uint32(ttl)
} else if ttl == 0 && go_errors.Is(err, errRecordNotFound) {
rTTL = 0
} else { // edge case: where a fast rep's ttl expires during the rtt of a slower, parallel query
rTTL = 1
}

log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
return result{ips, ttl, err}
return result{ips, rTTL, err}
}

func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, uint32, error) {
func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, int32, error) {
var allIPs []net.IP
var rTTL uint32 = dns.DefaultTTL
var rTTL int32 = dns.DefaultTTL

mergeReq := option.IPv4Enable && option.IPv6Enable

Expand Down
20 changes: 14 additions & 6 deletions app/dns/nameserver_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ type DoHNameServer struct {
}

// NewDoHNameServer creates DOH/DOHL client object for remote/local resolving.
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, h2c bool, disableCache bool, clientIP net.IP) *DoHNameServer {
func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, h2c bool, disableCache bool, serveStale bool, serveExpiredTTL uint32, clientIP net.IP) *DoHNameServer {
url.Scheme = "https"
mode := "DOH"
if dispatcher == nil {
mode = "DOHL"
}
errors.LogInfo(context.Background(), "DNS: created ", mode, " client for ", url.String(), ", with h2c ", h2c)
s := &DoHNameServer{
cacheController: NewCacheController(mode+"//"+url.Host, disableCache),
cacheController: NewCacheController(mode+"//"+url.Host, disableCache, serveStale, serveExpiredTTL),
dohURL: url.String(),
clientIP: clientIP,
}
Expand Down Expand Up @@ -131,7 +131,9 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er

if s.Name()+"." == "DOH//"+fqdn {
errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
if noResponseErrCh != nil {
noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
}
return
}

Expand Down Expand Up @@ -172,19 +174,25 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
b, err := dns.PackMessage(r.msg)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", fqdn)
noResponseErrCh <- err
if noResponseErrCh != nil {
noResponseErrCh <- err
}
return
}
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
if err != nil {
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", fqdn)
noResponseErrCh <- err
if noResponseErrCh != nil {
noResponseErrCh <- err
}
return
}
rec, err := parseResponse(resp)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", fqdn)
noResponseErrCh <- err
if noResponseErrCh != nil {
noResponseErrCh <- err
}
return
}
s.cacheController.updateRecord(r, rec)
Expand Down
8 changes: 4 additions & 4 deletions app/dns/nameserver_doh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestDOHNameServer(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
s := NewDoHNameServer(url, nil, false, false, false, 0, net.IP(nil))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
IPv4Enable: true,
Expand All @@ -34,7 +34,7 @@ func TestDOHNameServerWithCache(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
s := NewDoHNameServer(url, nil, false, false, false, 0, net.IP(nil))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
IPv4Enable: true,
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestDOHNameServerWithIPv4Override(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
s := NewDoHNameServer(url, nil, false, false, false, 0, net.IP(nil))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
IPv4Enable: true,
Expand All @@ -85,7 +85,7 @@ func TestDOHNameServerWithIPv6Override(t *testing.T) {
url, err := url.Parse("https+local://1.1.1.1/dns-query")
common.Must(err)

s := NewDoHNameServer(url, nil, false, false, net.IP(nil))
s := NewDoHNameServer(url, nil, false, false, false, 0, net.IP(nil))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ips, _, err := s.QueryIP(ctx, "google.com", dns_feature.IPOption{
IPv4Enable: false,
Expand Down
Loading