Skip to content

Commit

Permalink
Refine DNS
Browse files Browse the repository at this point in the history
  • Loading branch information
Loyalsoldier committed Aug 9, 2021
1 parent 3872172 commit ce2e116
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 65 deletions.
23 changes: 10 additions & 13 deletions app/dns/nameserver_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
// thus most of the DOH implementation is copied from udpns.go
type DoHNameServer struct {
sync.RWMutex
ips map[string]record
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
Expand Down Expand Up @@ -112,7 +112,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {

func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
s := &DoHNameServer{
ips: make(map[string]record),
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + url.Host,
dohURL: url.String(),
Expand Down Expand Up @@ -140,17 +140,10 @@ func (s *DoHNameServer) Cleanup() error {
}

for domain, record := range s.ips {
if record.A != nil && len(record.A.IP) == 0 {
if record.A != nil && (len(record.A.IP) == 0 || record.A.Expire.Before(now)) {
record.A = nil
}
if record.AAAA != nil && len(record.AAAA.IP) == 0 {
record.AAAA = nil
}

if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
if record.AAAA != nil && (len(record.AAAA.IP) == 0 || record.AAAA.Expire.Before(now)) {
record.AAAA = nil
}

Expand All @@ -162,6 +155,10 @@ func (s *DoHNameServer) Cleanup() error {
}
}

if len(s.ips) == 0 {
s.ips = make(map[string]*record)
}

return nil
}

Expand All @@ -171,7 +168,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
s.Lock()
rec, found := s.ips[req.domain]
if !found {
rec = record{}
rec = &record{}
}
updated := false

Expand All @@ -182,7 +179,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
updated = true
}
case dnsmessage.TypeAAAA:
addr := make([]net.Address, 0)
addr := make([]net.Address, 0, len(ipRec.IP))
for _, ip := range ipRec.IP {
if len(ip.IP()) == net.IPv6len {
addr = append(addr, ip)
Expand Down
25 changes: 11 additions & 14 deletions app/dns/nameserver_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ const handshakeIdleTimeout = time.Second * 8
// QUICNameServer implemented DNS over QUIC
type QUICNameServer struct {
sync.RWMutex
ips map[string]record
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
name string
destination net.Destination
destination *net.Destination
session quic.Session
}

Expand All @@ -57,10 +57,10 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
dest := net.UDPDestination(net.DomainAddress(url.Hostname()), port)

s := &QUICNameServer{
ips: make(map[string]record),
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: url.String(),
destination: dest,
destination: &dest,
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
Expand All @@ -86,17 +86,10 @@ func (s *QUICNameServer) Cleanup() error {
}

for domain, record := range s.ips {
if record.A != nil && len(record.A.IP) == 0 {
if record.A != nil && (len(record.A.IP) == 0 || record.A.Expire.Before(now)) {
record.A = nil
}
if record.AAAA != nil && len(record.AAAA.IP) == 0 {
record.AAAA = nil
}

if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
if record.AAAA != nil && (len(record.AAAA.IP) == 0 || record.AAAA.Expire.Before(now)) {
record.AAAA = nil
}

Expand All @@ -108,6 +101,10 @@ func (s *QUICNameServer) Cleanup() error {
}
}

if len(s.ips) == 0 {
s.ips = make(map[string]*record)
}

return nil
}

Expand All @@ -117,7 +114,7 @@ func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
s.Lock()
rec, found := s.ips[req.domain]
if !found {
rec = record{}
rec = &record{}
}
updated := false

Expand Down
29 changes: 13 additions & 16 deletions app/dns/nameserver_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
type TCPNameServer struct {
sync.RWMutex
name string
destination net.Destination
ips map[string]record
destination *net.Destination
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
Expand All @@ -45,7 +45,7 @@ func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServ
}

s.dial = func(ctx context.Context) (net.Conn, error) {
link, err := dispatcher.Dispatch(ctx, s.destination)
link, err := dispatcher.Dispatch(ctx, *s.destination)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +67,7 @@ func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
}

s.dial = func(ctx context.Context) (net.Conn, error) {
return internet.DialSystem(ctx, s.destination, nil)
return internet.DialSystem(ctx, *s.destination, nil)
}

return s, nil
Expand All @@ -85,8 +85,8 @@ func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)

s := &TCPNameServer{
destination: dest,
ips: make(map[string]record),
destination: &dest,
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + dest.NetAddr(),
}
Expand Down Expand Up @@ -114,17 +114,10 @@ func (s *TCPNameServer) Cleanup() error {
}

for domain, record := range s.ips {
if record.A != nil && len(record.A.IP) == 0 {
if record.A != nil && (len(record.A.IP) == 0 || record.A.Expire.Before(now)) {
record.A = nil
}
if record.AAAA != nil && len(record.AAAA.IP) == 0 {
record.AAAA = nil
}

if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
if record.AAAA != nil && (len(record.AAAA.IP) == 0 || record.AAAA.Expire.Before(now)) {
record.AAAA = nil
}

Expand All @@ -136,6 +129,10 @@ func (s *TCPNameServer) Cleanup() error {
}
}

if len(s.ips) == 0 {
s.ips = make(map[string]*record)
}

return nil
}

Expand All @@ -145,7 +142,7 @@ func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
s.Lock()
rec, found := s.ips[req.domain]
if !found {
rec = record{}
rec = &record{}
}
updated := false

Expand Down
42 changes: 20 additions & 22 deletions app/dns/nameserver_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import (
type ClassicNameServer struct {
sync.RWMutex
name string
address net.Destination
ips map[string]record
requests map[uint16]dnsRequest
address *net.Destination
ips map[string]*record
requests map[uint16]*dnsRequest
pub *pubsub.Service
udpServer *udp.Dispatcher
cleanup *task.Periodic
Expand All @@ -45,9 +45,9 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
}

s := &ClassicNameServer{
address: address,
ips: make(map[string]record),
requests: make(map[uint16]dnsRequest),
address: &address,
ips: make(map[string]*record),
requests: make(map[uint16]*dnsRequest),
pub: pubsub.NewService(),
name: strings.ToUpper(address.String()),
}
Expand Down Expand Up @@ -76,35 +76,33 @@ func (s *ClassicNameServer) Cleanup() error {
}

for domain, record := range s.ips {
if record.A != nil && len(record.A.IP) == 0 {
if record.A != nil && (len(record.A.IP) == 0 || record.A.Expire.Before(now)) {
record.A = nil
}
if record.AAAA != nil && len(record.AAAA.IP) == 0 {
record.AAAA = nil
}

if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
if record.AAAA != nil && (len(record.AAAA.IP) == 0 || record.AAAA.Expire.Before(now)) {
record.AAAA = nil
}

if record.A == nil && record.AAAA == nil {
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
delete(s.ips, domain)
} else {
s.ips[domain] = record
}
}

if len(s.ips) == 0 {
s.ips = make(map[string]*record)
}

for id, req := range s.requests {
if req.expire.Before(now) {
delete(s.requests, id)
}
}

if len(s.requests) == 0 {
s.requests = make(map[uint16]dnsRequest)
s.requests = make(map[uint16]*dnsRequest)
}

return nil
Expand Down Expand Up @@ -142,16 +140,16 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
elapsed := time.Since(req.start)
newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(req.domain, rec)
s.updateIP(req.domain, &rec)
}
}

func (s *ClassicNameServer) updateIP(domain string, newRec record) {
func (s *ClassicNameServer) updateIP(domain string, newRec *record) {
s.Lock()

rec, found := s.ips[domain]
if !found {
rec = record{}
rec = &record{}
}

updated := false
Expand All @@ -164,7 +162,7 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
updated = true
}

if updated && ((newRec.A != nil && len(newRec.A.IP) > 0) || (newRec.AAAA != nil && len(newRec.AAAA.IP) > 0)) {
if updated && (len(newRec.A.IP) > 0 || len(newRec.AAAA.IP) > 0) {
newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
s.ips[domain] = rec
}
Expand All @@ -188,7 +186,7 @@ func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {

id := req.msg.ID
req.expire = time.Now().Add(time.Second * 8)
s.requests[id] = *req
s.requests[id] = req
}

func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
Expand All @@ -206,7 +204,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
udpCtx = session.ContextWithContent(udpCtx, &session.Content{
Protocol: "dns",
})
s.udpServer.Dispatch(udpCtx, s.address, b)
s.udpServer.Dispatch(udpCtx, *s.address, b)
}
}

Expand Down

0 comments on commit ce2e116

Please sign in to comment.