diff --git a/pkg/query_utils.go b/pkg/query_utils.go index 6cacbea..4a7e35e 100644 --- a/pkg/query_utils.go +++ b/pkg/query_utils.go @@ -26,16 +26,15 @@ type SpiderResolver struct { r *net.Resolver filter []*regexp.Regexp contains []string - - lock sync.Mutex + timeout int + lock sync.Mutex } func DefaultResolver() *SpiderResolver { - ctx, _ := context.WithTimeout(context.Background(), time.Duration(DnsTimeout)*time.Second) // I don't think if a inside cluster dns query has more than 2s latency. return &SpiderResolver{ dns: "default-dns", + timeout: DnsTimeout, r: net.DefaultResolver, - ctx: ctx, filter: []*regexp.Regexp{}, contains: []string{}, lock: sync.Mutex{}, @@ -43,9 +42,9 @@ func DefaultResolver() *SpiderResolver { } func WarpDnsServer(dnsServer string) *SpiderResolver { - ctx, _ := context.WithTimeout(context.Background(), time.Duration(DnsTimeout)*time.Second) return &SpiderResolver{ - dns: dnsServer, + dns: dnsServer, + timeout: DnsTimeout, r: &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -53,7 +52,6 @@ func WarpDnsServer(dnsServer string) *SpiderResolver { return d.DialContext(ctx, network, dnsServer) }, }, - ctx: ctx, filter: []*regexp.Regexp{}, contains: []string{}, lock: sync.Mutex{}, @@ -110,7 +108,8 @@ func (s *SpiderResolver) CurrentDNS() string { func (s *SpiderResolver) PTRRecord(ip net.IP) []string { s.lock.Lock() defer s.lock.Unlock() - names, err := s.r.LookupAddr(s.ctx, ip.String()) + ctx, _ := context.WithTimeout(context.Background(), time.Duration(s.timeout)*time.Second) + names, err := s.r.LookupAddr(ctx, ip.String()) if err != nil { log.Debugf("LookupAddr failed: %v", err) return nil @@ -122,7 +121,8 @@ func (s *SpiderResolver) PTRRecord(ip net.IP) []string { func (s *SpiderResolver) SRVRecord(svcDomain string) (string, []*net.SRV, error) { s.lock.Lock() defer s.lock.Unlock() - cname, srvs, err := s.r.LookupSRV(s.ctx, "", "", svcDomain) + ctx, _ := context.WithTimeout(context.Background(), time.Duration(s.timeout)*time.Second) + cname, srvs, err := s.r.LookupSRV(ctx, "", "", svcDomain) var finalsrv []*net.SRV for _, srv := range srvs { if s.filterString(srv.Target) { @@ -137,7 +137,8 @@ func (s *SpiderResolver) SRVRecord(svcDomain string) (string, []*net.SRV, error) func (s *SpiderResolver) CustomSRVRecord(svcDomain string, service, proto string) (string, []*net.SRV, error) { s.lock.Lock() defer s.lock.Unlock() - cname, srvs, err := s.r.LookupSRV(s.ctx, service, proto, svcDomain) + ctx, _ := context.WithTimeout(context.Background(), time.Duration(s.timeout)*time.Second) + cname, srvs, err := s.r.LookupSRV(ctx, service, proto, svcDomain) time.Sleep(time.Duration(Latency) * time.Millisecond) return cname, srvs, err } @@ -146,14 +147,16 @@ func (s *SpiderResolver) ARecord(domain string) ([]net.IP, error) { s.lock.Lock() defer s.lock.Unlock() time.Sleep(time.Duration(Latency) * time.Millisecond) - return s.r.LookupIP(s.ctx, "ip", domain) + ctx, _ := context.WithTimeout(context.Background(), time.Duration(s.timeout)*time.Second) + return s.r.LookupIP(ctx, "ip", domain) } func (s *SpiderResolver) TXTRecord(domain string) ([]string, error) { s.lock.Lock() defer s.lock.Unlock() time.Sleep(time.Duration(Latency) * time.Millisecond) - return s.r.LookupTXT(s.ctx, domain) + ctx, _ := context.WithTimeout(context.Background(), time.Duration(s.timeout)*time.Second) + return s.r.LookupTXT(ctx, domain) } func PTRRecord(ip net.IP) []string {