Skip to content

Commit

Permalink
DNS: fix typo & refine code (#1183)
Browse files Browse the repository at this point in the history
Co-authored-by: loyalsoldier <[email protected]>
  • Loading branch information
rurirei and Loyalsoldier committed Aug 10, 2021
1 parent ecaf597 commit 73470e8
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 106 deletions.
2 changes: 1 addition & 1 deletion app/dns/dnscommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ L:
case dnsmessage.TypeAAAA:
ans, err := parser.AAAAResource()
if err != nil {
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog()
break L
}
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
Expand Down
47 changes: 23 additions & 24 deletions app/dns/nameserver_doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
// thus most of the DOH implementation is copied from udpns.go
type DoHNameServer struct {
sync.RWMutex
ips map[string]record
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
Expand Down Expand Up @@ -112,7 +112,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {

func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
s := &DoHNameServer{
ips: make(map[string]record),
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + url.Host,
dohURL: url.String(),
Expand Down Expand Up @@ -156,7 +156,7 @@ func (s *DoHNameServer) Cleanup() error {
}

if len(s.ips) == 0 {
s.ips = make(map[string]record)
s.ips = make(map[string]*record)
}

return nil
Expand All @@ -166,7 +166,10 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start)

s.Lock()
rec := s.ips[req.domain]
rec, found := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false

switch req.reqType {
Expand All @@ -176,7 +179,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
updated = true
}
case dnsmessage.TypeAAAA:
addr := make([]net.Address, 0)
addr := make([]net.Address, 0, len(ipRec.IP))
for _, ip := range ipRec.IP {
if len(ip.IP()) == net.IPv6len {
addr = append(addr, ip)
Expand Down Expand Up @@ -295,34 +298,30 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
return nil, errRecordNotFound
}

var err4 error
var err6 error
var ips []net.Address
var lastErr error
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}

if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
var ip6 []net.Address

switch {
case option.IPv4Enable:
ips, err4 = record.A.getIPs()
fallthrough
case option.IPv6Enable:
ip6, err6 = record.AAAA.getIPs()
ips = append(ips, ip6...)
}

if len(ips) > 0 {
return toNetIP(ips)
}

if lastErr != nil {
return nil, lastErr
if err4 != nil {
return nil, err4
}

if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
return nil, dns_feature.ErrEmptyResponse
if err6 != nil {
return nil, err6
}

return nil, errRecordNotFound
Expand Down
49 changes: 24 additions & 25 deletions app/dns/nameserver_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ const handshakeIdleTimeout = time.Second * 8
// QUICNameServer implemented DNS over QUIC
type QUICNameServer struct {
sync.RWMutex
ips map[string]record
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
name string
destination net.Destination
destination *net.Destination
session quic.Session
}

Expand All @@ -57,10 +57,10 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
dest := net.UDPDestination(net.DomainAddress(url.Hostname()), port)

s := &QUICNameServer{
ips: make(map[string]record),
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: url.String(),
destination: dest,
destination: &dest,
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
Expand Down Expand Up @@ -102,7 +102,7 @@ func (s *QUICNameServer) Cleanup() error {
}

if len(s.ips) == 0 {
s.ips = make(map[string]record)
s.ips = make(map[string]*record)
}

return nil
Expand All @@ -112,7 +112,10 @@ func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start)

s.Lock()
rec := s.ips[req.domain]
rec, found := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false

switch req.reqType {
Expand Down Expand Up @@ -232,34 +235,30 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp
return nil, errRecordNotFound
}

var err4 error
var err6 error
var ips []net.Address
var lastErr error
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}

if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
var ip6 []net.Address

switch {
case option.IPv4Enable:
ips, err4 = record.A.getIPs()
fallthrough
case option.IPv6Enable:
ip6, err6 = record.AAAA.getIPs()
ips = append(ips, ip6...)
}

if len(ips) > 0 {
return toNetIP(ips)
}

if lastErr != nil {
return nil, lastErr
if err4 != nil {
return nil, err4
}

if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
return nil, dns_feature.ErrEmptyResponse
if err6 != nil {
return nil, err6
}

return nil, errRecordNotFound
Expand Down
53 changes: 28 additions & 25 deletions app/dns/nameserver_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
type TCPNameServer struct {
sync.RWMutex
name string
destination net.Destination
ips map[string]record
destination *net.Destination
ips map[string]*record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
Expand All @@ -45,7 +45,7 @@ func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServ
}

s.dial = func(ctx context.Context) (net.Conn, error) {
link, err := dispatcher.Dispatch(ctx, s.destination)
link, err := dispatcher.Dispatch(ctx, *s.destination)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +67,7 @@ func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
}

s.dial = func(ctx context.Context) (net.Conn, error) {
return internet.DialSystem(ctx, s.destination, nil)
return internet.DialSystem(ctx, *s.destination, nil)
}

return s, nil
Expand All @@ -85,8 +85,8 @@ func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)

s := &TCPNameServer{
destination: dest,
ips: make(map[string]record),
destination: &dest,
ips: make(map[string]*record),
pub: pubsub.NewService(),
name: prefix + "//" + dest.NetAddr(),
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func (s *TCPNameServer) Cleanup() error {
}

if len(s.ips) == 0 {
s.ips = make(map[string]record)
s.ips = make(map[string]*record)
}

return nil
Expand All @@ -140,7 +140,10 @@ func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start)

s.Lock()
rec := s.ips[req.domain]
rec, found := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false

switch req.reqType {
Expand Down Expand Up @@ -274,30 +277,30 @@ func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
return nil, errRecordNotFound
}

var err4 error
var err6 error
var ips []net.Address
var lastErr error
if option.IPv4Enable {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
}

if option.IPv6Enable {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
var ip6 []net.Address

switch {
case option.IPv4Enable:
ips, err4 = record.A.getIPs()
fallthrough
case option.IPv6Enable:
ip6, err6 = record.AAAA.getIPs()
ips = append(ips, ip6...)
}

if len(ips) > 0 {
return toNetIP(ips)
}

if lastErr != nil {
return nil, lastErr
if err4 != nil {
return nil, err4
}

if err6 != nil {
return nil, err6
}

return nil, dns_feature.ErrEmptyResponse
Expand Down
Loading

0 comments on commit 73470e8

Please sign in to comment.