From ad2b06fd0004f9fae388ee4ba6a992b5cdb45a2e Mon Sep 17 00:00:00 2001 From: rook1e Date: Thu, 13 Jul 2023 16:57:37 +0800 Subject: [PATCH] fix: refactor wildcard record checker Use the record (the first element of answer), not the result of recursive querying. --- .gitignore | 1 + cmd/sf/sf.go | 7 +++-- internal/engine/checker.go | 53 ++++++++++--------------------------- internal/engine/engine.go | 12 ++++----- internal/engine/recorder.go | 20 +++++--------- internal/engine/resolver.go | 2 +- internal/module/module.go | 4 +-- 7 files changed, 31 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index 2823ac4..cf639e5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pprof .DS_Store ._* +*.txt dist/ diff --git a/cmd/sf/sf.go b/cmd/sf/sf.go index c3e019d..c523f42 100644 --- a/cmd/sf/sf.go +++ b/cmd/sf/sf.go @@ -85,12 +85,11 @@ It is recommended to determine if the rate is appropriate by the send/recv stati c.Target, c.Wordlist, c.Resolver, c.Concurrent, c.Rate, c.Retry, c.ValidCheck) startAt := time.Now() - app := engine.New(c) - valid, invalid := app.Run() + res := engine.New().Run() - logrus.Infof("found %d valid, %d invalid. %.2f seconds in total.\n", len(valid), len(invalid), time.Since(startAt).Seconds()) + logrus.Infof("found %d subdomains. time: %.2f seconds.\n", len(res), time.Since(startAt).Seconds()) - saveResult(output, valid) + saveResult(output, res) } func saveResult(path string, data []string) { diff --git a/internal/engine/checker.go b/internal/engine/checker.go index bca1d52..2a827e5 100644 --- a/internal/engine/checker.go +++ b/internal/engine/checker.go @@ -10,60 +10,35 @@ import ( // existWildcard checks if there is a wildcard record func (e *Engine) existWildcard() bool { - m := new(dns.Msg) - m.SetQuestion(conf.C.Target, dns.TypeNS) - r, err := dns.Exchange(m, conf.C.Resolver) - if err != nil || r.Rcode != dns.RcodeSuccess || len(r.Answer) == 0 { + m := &dns.Msg{} + m.SetQuestion("*."+conf.C.Target, dns.TypeA) + resp, err := dns.Exchange(m, conf.C.Resolver) + if err != nil || resp.Rcode != dns.RcodeSuccess || len(resp.Answer) == 0 { return false } - for _, v := range r.Answer { - n, ok := v.(*dns.NS) - if !ok { - continue - } - m := &dns.Msg{} - m.SetQuestion("*."+conf.C.Target, dns.TypeA) - resp, err := dns.Exchange(m, n.Ns+":53") - if err != nil || resp.Rcode != dns.RcodeSuccess || len(resp.Answer) == 0 { - continue - } - e.wildcardRecord = resp.Answer - break - } - if len(e.wildcardRecord) == 0 { - return false - } - e.wildcardRecord[0].Header().Name = "" // for easier comparison + e.wildcardRecord = resp.Answer[0] + logrus.Debug("found wildcard record: " + e.wildcardRecord.String()) + e.wildcardRecord.Header().Name = "" // for easier comparison return true } -// checker checks if domain is valid +// checker checks if domain is valid: // -// more: https://github.com/0x2E/sf/issues/12 +// 1. not wildcard record func (e *Engine) checker(wg *sync.WaitGroup) { defer func() { close(e.toRecorder) wg.Done() }() - logger := logrus.WithField("step", "checker") - for t := range e.toChecker { - if len(t.Answer) == len(e.wildcardRecord) { - matchAll := true - t.Answer[0].Header().Name = "" // for easier comparison - for i, v := range e.wildcardRecord { - if !dns.IsDuplicate(v, t.Answer[i]) { - matchAll = false - break - } - } - if matchAll { - t.Valid = false - logger.Debug("invalid: " + t.DomainName) - } + t.Record.Header().Name = "" // for easier comparison + + if dns.IsDuplicate(t.Record, e.wildcardRecord) { + continue } + e.toRecorder <- t } } diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 7e578c8..ee25a87 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -22,15 +22,14 @@ var ( type Engine struct { needCheck bool // wildcardRecord is the wildcard record (`*.example.com`) - wildcardRecord []dns.RR + wildcardRecord dns.RR toResolver chan *module.Task toChecker chan *module.Task toRecorder chan *module.Task - validResults []string - invalidResults []string + results []string } -func New(config *conf.Config) *Engine { +func New() *Engine { return &Engine{ toResolver: make(chan *module.Task, QueueMaxLen), toChecker: make(chan *module.Task, QueueMaxLen), @@ -38,13 +37,12 @@ func New(config *conf.Config) *Engine { } } -func (e *Engine) Run() ([]string, []string) { +func (e *Engine) Run() []string { wg := sync.WaitGroup{} e.needCheck = conf.C.ValidCheck && e.existWildcard() if e.needCheck { wg.Add(1) go e.checker(&wg) - logrus.Debugf("wirldcard record: %#v", e.wildcardRecord) } else { logrus.Debug("turn off checker") close(e.toChecker) @@ -88,5 +86,5 @@ func (e *Engine) Run() ([]string, []string) { close(e.toResolver) wg.Wait() - return e.validResults, e.invalidResults + return e.results } diff --git a/internal/engine/recorder.go b/internal/engine/recorder.go index 515c772..70c87e9 100644 --- a/internal/engine/recorder.go +++ b/internal/engine/recorder.go @@ -8,23 +8,15 @@ import ( func (e *Engine) recorder(wg *sync.WaitGroup) { defer wg.Done() - validSet, invalidSet := make(map[string]struct{}), make(map[string]struct{}) + res := make(map[string]struct{}) for t := range e.toRecorder { subdomain := t.DomainName[:len(t.DomainName)-1] - if t.Valid { - fmt.Println(subdomain) - validSet[subdomain] = struct{}{} - } else { - invalidSet[subdomain] = struct{}{} - } + fmt.Println(subdomain) + res[subdomain] = struct{}{} } - e.validResults = make([]string, 0, len(validSet)) - for d := range validSet { - e.validResults = append(e.validResults, d) - } - e.invalidResults = make([]string, 0, len(invalidSet)) - for d := range invalidSet { - e.invalidResults = append(e.invalidResults, d) + e.results = make([]string, 0, len(res)) + for d := range res { + e.results = append(e.results, d) } } diff --git a/internal/engine/resolver.go b/internal/engine/resolver.go index 854f0f4..ba46721 100644 --- a/internal/engine/resolver.go +++ b/internal/engine/resolver.go @@ -160,7 +160,7 @@ func (w *resolverWorker) receiver(wg *sync.WaitGroup, toNext chan<- *module.Task if msg.Rcode != dns.RcodeSuccess || len(msg.Answer) == 0 { continue } - task.Answer = msg.Answer + task.Record = msg.Answer[0] toNext <- task } } diff --git a/internal/module/module.go b/internal/module/module.go index 2102097..a462e17 100644 --- a/internal/module/module.go +++ b/internal/module/module.go @@ -6,15 +6,13 @@ import ( type Task struct { DomainName string - Answer []dns.RR + Record dns.RR LastQueryAt int64 Received bool - Valid bool } func putTask(toNext chan<- *Task, dn string) { toNext <- &Task{ DomainName: dns.Fqdn(dn), - Valid: true, } }