From 5f8a46ac56fff35c15982c594e4ff9c801ca7f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 28 Aug 2021 23:43:11 +0800 Subject: [PATCH] Revert "DNS: fix typo & refine code (#1183)" This reverts commit 73470e8dd8d020f3f19d77545cce00a9dbb96c2e. --- app/dns/dnscommon.go | 2 +- app/dns/nameserver_doh.go | 47 ++++++++++++++------------- app/dns/nameserver_quic.go | 49 ++++++++++++++-------------- app/dns/nameserver_tcp.go | 53 +++++++++++++++--------------- app/dns/nameserver_udp.go | 66 ++++++++++++++++++-------------------- 5 files changed, 106 insertions(+), 111 deletions(-) diff --git a/app/dns/dnscommon.go b/app/dns/dnscommon.go index 0624a4db0e6..056ceda439c 100644 --- a/app/dns/dnscommon.go +++ b/app/dns/dnscommon.go @@ -214,7 +214,7 @@ L: case dnsmessage.TypeAAAA: ans, err := parser.AAAAResource() if err != nil { - newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog() + newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() break L } ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index 88350f2ee11..5d6e8c8b6e7 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -33,7 +33,7 @@ import ( // thus most of the DOH implementation is copied from udpns.go type DoHNameServer struct { sync.RWMutex - ips map[string]*record + ips map[string]record pub *pubsub.Service cleanup *task.Periodic reqID uint32 @@ -113,7 +113,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer { func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { s := &DoHNameServer{ - ips: make(map[string]*record), + ips: make(map[string]record), pub: pubsub.NewService(), name: prefix + "//" + url.Host, dohURL: url.String(), @@ -157,7 +157,7 @@ func (s *DoHNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]*record) + s.ips = make(map[string]record) } return nil @@ -167,10 +167,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { elapsed := time.Since(req.start) s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } + rec := s.ips[req.domain] updated := false switch req.reqType { @@ -180,7 +177,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { updated = true } case dnsmessage.TypeAAAA: - addr := make([]net.Address, 0, len(ipRec.IP)) + addr := make([]net.Address, 0) for _, ip := range ipRec.IP { if len(ip.IP()) == net.IPv6len { addr = append(addr, ip) @@ -299,30 +296,34 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt return nil, errRecordNotFound } - var err4 error - var err6 error var ips []net.Address - var ip6 []net.Address - - switch { - case option.IPv4Enable: - ips, err4 = record.A.getIPs() - fallthrough - case option.IPv6Enable: - ip6, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) + var lastErr error + if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) + } + + if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess { + a, err := record.A.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, a...) } if len(ips) > 0 { return toNetIP(ips) } - if err4 != nil { - return nil, err4 + if lastErr != nil { + return nil, lastErr } - if err6 != nil { - return nil, err6 + if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { + return nil, dns_feature.ErrEmptyResponse } return nil, errRecordNotFound diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index f2e2573edfa..2fe6a4ee2e8 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -34,12 +34,12 @@ const handshakeIdleTimeout = time.Second * 8 // QUICNameServer implemented DNS over QUIC type QUICNameServer struct { sync.RWMutex - ips map[string]*record + ips map[string]record pub *pubsub.Service cleanup *task.Periodic reqID uint32 name string - destination *net.Destination + destination net.Destination session quic.Session } @@ -58,10 +58,10 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) { dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) s := &QUICNameServer{ - ips: make(map[string]*record), + ips: make(map[string]record), pub: pubsub.NewService(), name: url.String(), - destination: &dest, + destination: dest, } s.cleanup = &task.Periodic{ Interval: time.Minute, @@ -103,7 +103,7 @@ func (s *QUICNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]*record) + s.ips = make(map[string]record) } return nil @@ -113,10 +113,7 @@ func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { elapsed := time.Since(req.start) s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } + rec := s.ips[req.domain] updated := false switch req.reqType { @@ -236,30 +233,34 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp return nil, errRecordNotFound } - var err4 error - var err6 error var ips []net.Address - var ip6 []net.Address - - switch { - case option.IPv4Enable: - ips, err4 = record.A.getIPs() - fallthrough - case option.IPv6Enable: - ip6, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) + var lastErr error + if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) + } + + if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess { + a, err := record.A.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, a...) } if len(ips) > 0 { return toNetIP(ips) } - if err4 != nil { - return nil, err4 + if lastErr != nil { + return nil, lastErr } - if err6 != nil { - return nil, err6 + if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { + return nil, dns_feature.ErrEmptyResponse } return nil, errRecordNotFound diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index be51ffb57c8..ff34509c153 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -30,8 +30,8 @@ import ( type TCPNameServer struct { sync.RWMutex name string - destination *net.Destination - ips map[string]*record + destination net.Destination + ips map[string]record pub *pubsub.Service cleanup *task.Periodic reqID uint32 @@ -46,7 +46,7 @@ func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServ } s.dial = func(ctx context.Context) (net.Conn, error) { - link, err := dispatcher.Dispatch(ctx, *s.destination) + link, err := dispatcher.Dispatch(ctx, s.destination) if err != nil { return nil, err } @@ -68,7 +68,7 @@ func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) { } s.dial = func(ctx context.Context) (net.Conn, error) { - return internet.DialSystem(ctx, *s.destination, nil) + return internet.DialSystem(ctx, s.destination, nil) } return s, nil @@ -86,8 +86,8 @@ func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) { dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) s := &TCPNameServer{ - destination: &dest, - ips: make(map[string]*record), + destination: dest, + ips: make(map[string]record), pub: pubsub.NewService(), name: prefix + "//" + dest.NetAddr(), } @@ -131,7 +131,7 @@ func (s *TCPNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]*record) + s.ips = make(map[string]record) } return nil @@ -141,10 +141,7 @@ func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { elapsed := time.Since(req.start) s.Lock() - rec, found := s.ips[req.domain] - if !found { - rec = &record{} - } + rec := s.ips[req.domain] updated := false switch req.reqType { @@ -278,30 +275,30 @@ func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt return nil, errRecordNotFound } - var err4 error - var err6 error var ips []net.Address - var ip6 []net.Address - - switch { - case option.IPv4Enable: - ips, err4 = record.A.getIPs() - fallthrough - case option.IPv6Enable: - ip6, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) + var lastErr error + if option.IPv4Enable { + a, err := record.A.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, a...) } - if len(ips) > 0 { - return toNetIP(ips) + if option.IPv6Enable { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) } - if err4 != nil { - return nil, err4 + if len(ips) > 0 { + return toNetIP(ips) } - if err6 != nil { - return nil, err6 + if lastErr != nil { + return nil, lastErr } return nil, dns_feature.ErrEmptyResponse diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 5d88da148e3..6610e806844 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -29,9 +29,9 @@ import ( type ClassicNameServer struct { sync.RWMutex name string - address *net.Destination - ips map[string]*record - requests map[uint16]*dnsRequest + address net.Destination + ips map[string]record + requests map[uint16]dnsRequest pub *pubsub.Service udpServer *udp.Dispatcher cleanup *task.Periodic @@ -46,9 +46,9 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher } s := &ClassicNameServer{ - address: &address, - ips: make(map[string]*record), - requests: make(map[uint16]*dnsRequest), + address: address, + ips: make(map[string]record), + requests: make(map[uint16]dnsRequest), pub: pubsub.NewService(), name: strings.ToUpper(address.String()), } @@ -85,7 +85,6 @@ func (s *ClassicNameServer) Cleanup() error { } if record.A == nil && record.AAAA == nil { - newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() delete(s.ips, domain) } else { s.ips[domain] = record @@ -93,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]*record) + s.ips = make(map[string]record) } for id, req := range s.requests { @@ -103,7 +102,7 @@ func (s *ClassicNameServer) Cleanup() error { } if len(s.requests) == 0 { - s.requests = make(map[uint16]*dnsRequest) + s.requests = make(map[uint16]dnsRequest) } return nil @@ -141,17 +140,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot elapsed := time.Since(req.start) newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { - s.updateIP(req.domain, &rec) + s.updateIP(req.domain, rec) } } -func (s *ClassicNameServer) updateIP(domain string, newRec *record) { +func (s *ClassicNameServer) updateIP(domain string, newRec record) { s.Lock() - rec, found := s.ips[domain] - if !found { - rec = &record{} - } + newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() + rec := s.ips[domain] updated := false if isNewer(rec.A, newRec.A) { @@ -164,7 +161,6 @@ func (s *ClassicNameServer) updateIP(domain string, newRec *record) { } if updated { - newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() s.ips[domain] = rec } if newRec.A != nil { @@ -187,7 +183,7 @@ func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { id := req.msg.ID req.expire = time.Now().Add(time.Second * 8) - s.requests[id] = req + s.requests[id] = *req } func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { @@ -205,7 +201,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client udpCtx = session.ContextWithContent(udpCtx, &session.Content{ Protocol: "dns", }) - s.udpServer.Dispatch(udpCtx, *s.address, b) + s.udpServer.Dispatch(udpCtx, s.address, b) } } @@ -218,30 +214,30 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.I return nil, errRecordNotFound } - var err4 error - var err6 error var ips []net.Address - var ip6 []net.Address - - switch { - case option.IPv4Enable: - ips, err4 = record.A.getIPs() - fallthrough - case option.IPv6Enable: - ip6, err6 = record.AAAA.getIPs() - ips = append(ips, ip6...) + var lastErr error + if option.IPv4Enable { + a, err := record.A.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, a...) } - if len(ips) > 0 { - return toNetIP(ips) + if option.IPv6Enable { + aaaa, err := record.AAAA.getIPs() + if err != nil { + lastErr = err + } + ips = append(ips, aaaa...) } - if err4 != nil { - return nil, err4 + if len(ips) > 0 { + return toNetIP(ips) } - if err6 != nil { - return nil, err6 + if lastErr != nil { + return nil, lastErr } return nil, dns_feature.ErrEmptyResponse