Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement unthrottled concurrency using task queue #106

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 38 additions & 45 deletions cmd/gau/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,90 +2,83 @@ package main

import (
"bufio"
"io"
"os"
"sync"

"github.com/lc/gau/v2/pkg/output"
"github.com/lc/gau/v2/runner"
"github.com/lc/gau/v2/runner/flags"
log "github.com/sirupsen/logrus"
"io"
"os"
"sync"
)

func main() {
flag := flags.New()
cfg, err := flag.ReadInConfig()
cfg, err := flags.New().ReadInConfig()
if err != nil {
if cfg.Verbose {
log.Warnf("error reading config: %v", err)
}
}

pMap := make(runner.ProvidersMap)
for _, provider := range cfg.Providers {
pMap[provider] = cfg.Filters
log.Warnf("error reading config: %v", err)
}

config, err := cfg.ProviderConfig()
if err != nil {
log.Fatal(err)
}

gau := &runner.Runner{}
gau := new(runner.Runner)

if err = gau.Init(config, pMap); err != nil {
if err = gau.Init(config, cfg.Providers, cfg.Filters); err != nil {
log.Warn(err)
}

results := make(chan string)

var out io.Writer
// Handle results in background
if config.Output == "" {
out = os.Stdout
} else {
ofp, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if config.Output != "" {
out, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatalf("Could not open output file: %v\n", err)
}
defer ofp.Close()
out = ofp
defer out.Close()
} else {
out = os.Stdout
}

writeWg := &sync.WaitGroup{}
writeWg := new(sync.WaitGroup)
writeWg.Add(1)
if config.JSON {
go func() {
defer writeWg.Done()
go func(JSON bool) {
defer writeWg.Done()
if JSON {
output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters)
}()
} else {
go func() {
defer writeWg.Done()
if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}()
}
} else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil {
log.Fatalf("error writing results: %v\n", err)
}
}(config.JSON)

domains := make(chan string)
gau.Start(domains, results)
workChan := make(chan runner.Work)
gau.Start(workChan, results)

if len(flags.Args()) > 0 {
for _, domain := range flags.Args() {
domains <- domain
domains := flags.Args()
if len(domains) > 0 {
for _, provider := range gau.Providers {
for _, domain := range domains {
workChan <- runner.NewWork(domain, provider)
}
}
} else {
sc := bufio.NewScanner(os.Stdin)
for sc.Scan() {
domains <- sc.Text()
}
for _, provider := range gau.Providers {
for sc.Scan() {
workChan <- runner.NewWork(sc.Text(), provider)

if err := sc.Err(); err != nil {
log.Fatal(err)
if err := sc.Err(); err != nil {
log.Fatal(err)
}
}
}

}

close(domains)
close(workChan)

// wait for providers to fetch URLs
gau.Wait()
Expand Down
3 changes: 1 addition & 2 deletions pkg/providers/commoncrawl/commoncrawl.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string)
return nil
}

paginate:
for page := uint(0); page < p.Pages; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
Expand Down
6 changes: 2 additions & 4 deletions pkg/providers/otx/otx.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ func (c *Client) Name() string {
}

func (c *Client) Fetch(ctx context.Context, domain string, results chan string) error {
paginate:
for page := uint(1); ; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page - 1}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, page)
Expand All @@ -68,11 +67,10 @@ paginate:
}

if !result.HasNext {
break paginate
return nil
}
}
}
return nil
}

func (c *Client) formatURL(domain string, page uint) string {
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/valyala/fasthttp"
)

const Version = `2.1.2`
const Version = `2.2.0`

// Provider is a generic interface for all archive fetchers
type Provider interface {
Expand Down
17 changes: 6 additions & 11 deletions pkg/providers/urlscan/urlscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ const (
Name = "urlscan"
)

var _ providers.Provider = (*Client)(nil)

type Client struct {
config *providers.Config
}
Expand All @@ -41,11 +39,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string)
header.Value = c.config.URLScan.APIKey
}

paginate:
for page := uint(0); ; page++ {
select {
case <-ctx.Done():
break paginate
return nil
default:
logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain)
apiURL := c.formatURL(domain, searchAfter)
Expand All @@ -62,7 +59,7 @@ paginate:
// rate limited
if result.Status == 429 {
logrus.WithField("provider", "urlscan").Warnf("urlscan responded with 429, probably being rate limited")
break paginate
return nil
}

total := len(result.Results)
Expand All @@ -73,20 +70,18 @@ paginate:

if i == total-1 {
sortParam := parseSort(res.Sort)
if sortParam != "" {
searchAfter = sortParam
} else {
break paginate
if sortParam == "" {
return nil
}
searchAfter = sortParam
}
}

if !result.HasMore {
break paginate
return nil
}
}
}
return nil
}

func (c *Client) formatURL(domain string, after string) string {
Expand Down
9 changes: 2 additions & 7 deletions runner/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,8 @@ func (o *Options) getFlagValues(c *Config) {
c.RemoveParameters = fp
}

if json {
c.JSON = true
}

if verbose {
c.Verbose = verbose
}
c.JSON = json
c.Verbose = verbose

// get filter flags
mc := o.viper.GetStringSlice("mc")
Expand Down
72 changes: 33 additions & 39 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,81 +13,75 @@ import (
)

type Runner struct {
providers []providers.Provider
wg sync.WaitGroup
sync.WaitGroup

config *providers.Config
Providers []providers.Provider
threads uint
ctx context.Context
cancelFunc context.CancelFunc
}

type ProvidersMap map[string]providers.Filters

// Init initializes the runner
func (r *Runner) Init(c *providers.Config, providerMap ProvidersMap) error {
r.config = c
func (r *Runner) Init(c *providers.Config, providers []string, filters providers.Filters) error {
r.threads = c.Threads
r.ctx, r.cancelFunc = context.WithCancel(context.Background())

for name, filters := range providerMap {
for _, name := range providers {
switch name {
case "urlscan":
r.providers = append(r.providers, urlscan.New(c))
r.Providers = append(r.Providers, urlscan.New(c))
case "otx":
o := otx.New(c)
r.providers = append(r.providers, o)
r.Providers = append(r.Providers, otx.New(c))
case "wayback":
r.providers = append(r.providers, wayback.New(c, filters))
r.Providers = append(r.Providers, wayback.New(c, filters))
case "commoncrawl":
cc, err := commoncrawl.New(c, filters)
if err != nil {
return fmt.Errorf("error instantiating commoncrawl: %v\n", err)
}
r.providers = append(r.providers, cc)
r.Providers = append(r.Providers, cc)
}
}

return nil
}

// Start starts the workers
func (r *Runner) Start(domains chan string, results chan string) {
for i := uint(0); i < r.config.Threads; i++ {
r.wg.Add(1)
func (r *Runner) Start(workChan chan Work, results chan string) {
for i := uint(0); i < r.threads; i++ {
r.Add(1)
go func() {
defer r.wg.Done()
r.worker(r.ctx, domains, results)
defer r.Done()
r.worker(r.ctx, workChan, results)
}()
}
}

// Wait waits for the providers to finish fetching
func (r *Runner) Wait() {
r.wg.Wait()
type Work struct {
domain string
provider providers.Provider
}

func NewWork(domain string, provider providers.Provider) Work {
return Work{domain, provider}
}

func (w *Work) Do(ctx context.Context, results chan string) error {
return w.provider.Fetch(ctx, w.domain, results)
}

// worker checks to see if the context is finished and executes the fetching process for each provider
func (r *Runner) worker(ctx context.Context, domains chan string, results chan string) {
work:
func (r *Runner) worker(ctx context.Context, workChan chan Work, results chan string) {
for {
select {
case <-ctx.Done():
break work
case domain, ok := <-domains:
if ok {
var wg sync.WaitGroup
for _, p := range r.providers {
wg.Add(1)
go func(p providers.Provider) {
defer wg.Done()
if err := p.Fetch(ctx, domain, results); err != nil {
logrus.WithField("provider", p.Name()).Warnf("%s - %v", domain, err)
}
}(p)
}
wg.Wait()
}
return
case work, ok := <-workChan:
if !ok {
break work
return
}
if err := work.Do(ctx, results); err != nil {
logrus.WithField("provider", work.provider.Name()).Warnf("%s - %v", work.domain, err)
}
}
}
Expand Down