diff --git a/cmd/nuclei/main.go b/cmd/nuclei/main.go index 1a51c9f010..4a1ee5b2fc 100644 --- a/cmd/nuclei/main.go +++ b/cmd/nuclei/main.go @@ -406,6 +406,7 @@ on extensive configurability, massive extensibility and ease of use.`) flagSet.CreateGroup("rate-limit", "Rate-Limit", flagSet.IntVarP(&options.RateLimit, "rate-limit", "rl", 150, "maximum number of requests to send per second"), flagSet.DurationVarP(&options.RateLimitDuration, "rate-limit-duration", "rld", time.Second, "maximum number of requests to send per second"), + flagSet.BoolVar(&options.PerHostRateLimit, "per-host-rate-limit", false, "enable per-host rate limiting (global rate limit becomes unlimited when enabled)"), flagSet.IntVarP(&options.RateLimitMinute, "rate-limit-minute", "rlm", 0, "maximum number of requests to send per minute (DEPRECATED)"), flagSet.IntVarP(&options.BulkSize, "bulk-size", "bs", 25, "maximum number of hosts to be analyzed in parallel per template"), flagSet.IntVarP(&options.TemplateThreads, "concurrency", "c", 25, "maximum number of templates to be executed in parallel"), @@ -434,6 +435,9 @@ on extensive configurability, massive extensibility and ease of use.`) }), flagSet.DurationVarP(&options.InputReadTimeout, "input-read-timeout", "irt", time.Duration(3*time.Minute), "timeout on input read"), flagSet.BoolVarP(&options.DisableHTTPProbe, "no-httpx", "nh", false, "disable httpx probing for non-url input"), + flagSet.BoolVar(&options.PreflightPortScan, "preflight-portscan", false, "run preflight resolve + TCP portscan and filter targets before scanning (disabled by default)"), + flagSet.BoolVar(&options.PerHostClientPool, "per-host-client-pool", false, "enable per-host HTTP client pooling for better connection reuse"), + flagSet.BoolVar(&options.HTTPClientShards, "http-client-shards", false, "enable HTTP client sharding for connection pooling (auto-calculates optimal shard count, max 256)"), flagSet.BoolVar(&options.DisableStdin, "no-stdin", false, "disable stdin processing"), ) diff --git a/go.mod b/go.mod index 3b6928f212..f6e17b3fc0 100644 --- a/go.mod +++ b/go.mod @@ -81,6 +81,7 @@ require ( github.com/goccy/go-json v0.10.5 github.com/google/uuid v1.6.0 github.com/h2non/filetype v1.1.3 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/invopop/yaml v0.3.1 github.com/jcmturner/gokrb5/v8 v8.4.4 github.com/kitabisa/go-ci v1.0.3 @@ -256,7 +257,6 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/hashicorp/go-version v1.7.0 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hbakhtiyor/strsim v0.0.0-20190107154042-4d2bbb273edf // indirect github.com/hdm/jarm-go v0.0.7 // indirect github.com/iangcarroll/cookiemonster v1.6.0 // indirect diff --git a/internal/runner/preflight_portscan.go b/internal/runner/preflight_portscan.go new file mode 100644 index 0000000000..9e274be8ec --- /dev/null +++ b/internal/runner/preflight_portscan.go @@ -0,0 +1,646 @@ +package runner + +import ( + "context" + "fmt" + "net" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader" + "github.com/projectdiscovery/nuclei/v3/pkg/input/provider" + inputtypes "github.com/projectdiscovery/nuclei/v3/pkg/input/types" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" + "github.com/projectdiscovery/nuclei/v3/pkg/templates" + "github.com/projectdiscovery/utils/errkit" + iputil "github.com/projectdiscovery/utils/ip" + mapsutil "github.com/projectdiscovery/utils/maps" + sliceutil "github.com/projectdiscovery/utils/slice" + stringsutil "github.com/projectdiscovery/utils/strings" + syncutil "github.com/projectdiscovery/utils/sync" + urlutil "github.com/projectdiscovery/utils/url" +) + +const preflightWorkers = 100 + +// preflightDialTimeout is intentionally short: this is a coarse filter to skip +// obviously-dead targets before the real scan (which will use full timeouts). +const preflightDialTimeout = 750 * time.Millisecond + +type filteringInputProvider struct { + base provider.InputProvider + allowed *mapsutil.SyncLockMap[string, struct{}] + allowCnt int64 + execID string +} + +func (f *filteringInputProvider) Count() int64 { return f.allowCnt } +func (f *filteringInputProvider) InputType() string { return f.base.InputType() } +func (f *filteringInputProvider) Close() { f.base.Close() } +func (f *filteringInputProvider) Set(executionId string, value string) { + f.base.Set(executionId, value) +} +func (f *filteringInputProvider) SetWithProbe(executionId string, value string, probe inputtypes.InputLivenessProbe) error { + return f.base.SetWithProbe(executionId, value, probe) +} +func (f *filteringInputProvider) SetWithExclusions(executionId string, value string) error { + return f.base.SetWithExclusions(executionId, value) +} +func (f *filteringInputProvider) Iterate(callback func(value *contextargs.MetaInput) bool) { + f.base.Iterate(func(mi *contextargs.MetaInput) bool { + key, err := mi.MarshalString() + if err != nil { + return callback(mi) + } + if _, ok := f.allowed.Get(key); !ok { + return true + } + return callback(mi) + }) +} + +// preflightResolveAndPortScan resolves hostname targets and performs a TCP connect scan for ports +// required by loaded templates. Targets that are non-resolvable hostnames or have no relevant open +// ports are filtered out from the input provider. +func (r *Runner) preflightResolveAndPortScan(store *loader.Store) error { + if r.inputProvider == nil { + return nil + } + // MultiFormat inputs may represent complete requests; skip preflight for now. + if r.inputProvider.InputType() == provider.MultiFormatInputProvider { + return nil + } + + finalTemplates := []*templates.Template{} + finalTemplates = append(finalTemplates, store.Templates()...) + finalTemplates = append(finalTemplates, store.Workflows()...) + if len(finalTemplates) == 0 { + return nil + } + + dialers := protocolstate.GetDialersWithId(r.options.ExecutionId) + if dialers == nil { + return fmt.Errorf("dialers not initialized for %s", r.options.ExecutionId) + } + + portsPopularity := portsPopularityFromTemplates(finalTemplates) + // Also include ports explicitly present in input list (ip:port or URL with port), + // so that a user-provided port isn't dropped even if templates didn't specify it. + var inputs []preflightTarget + portsFromInputs := map[string]struct{}{} + var totalTargets atomic.Int64 + r.inputProvider.Iterate(func(mi *contextargs.MetaInput) bool { + totalTargets.Add(1) + key, err := mi.MarshalString() + if err != nil { + return true + } + inputs = append(inputs, preflightTarget{key: key, raw: mi.Input}) + extractPortsFromInput(portsFromInputs, mi.Input) + return true + }) + + portsToScan := sliceutil.Dedupe(append(keysOfPopularity(portsPopularity), keysOf(portsFromInputs)...)) + portsToScan = filterValidPorts(portsToScan) + // Sort by "likely-open" order (nmap-ish/common ports first), then numeric for determinism. + likelyRank := likelyOpenPortRank() + sort.Slice(portsToScan, func(i, j int) bool { + pi, pj := portsToScan[i], portsToScan[j] + ri, okRi := likelyRank[pi] + rj, okRj := likelyRank[pj] + // ranked ports first + if okRi != okRj { + return okRi + } + // among ranked ports, lower rank wins + if okRi && okRj && ri != rj { + return ri < rj + } + // numeric asc + ni, _ := strconv.Atoi(pi) + nj, _ := strconv.Atoi(pj) + return ni < nj + }) + + // If no ports were found, nothing to scan -> keep all. + if len(portsToScan) == 0 { + return nil + } + + r.Logger.Info().Msgf("Running preflight portscan (workers=%d, ports=%d, targets=%d)", preflightWorkers, len(portsToScan), totalTargets.Load()) + + swg, err := syncutil.New(syncutil.WithSize(preflightWorkers)) + if err != nil { + return err + } + + // Resolve all targets (once) up-front so we can optionally run a single batched TCP dial scan. + // Map original input key -> resolved IPs (deduped). + // SyncLockMap requires comparable values; store resolved IPs as a comma-separated string. + resolvedIPsByKey := mapsutil.NewSyncLockMap[string, string]() + allIPsSet := mapsutil.NewSyncLockMap[string, struct{}]() + var resolveProcessed atomic.Int64 + var resolveDNSFail atomic.Int64 + + for _, t := range inputs { + swg.Add() + go func(t preflightTarget) { + defer swg.Done() + host, _, _, _ := hostForResolveAndScan(t.raw) + if host == "" { + resolveDNSFail.Add(1) + resolveProcessed.Add(1) + return + } + ips := []string{} + if iputil.IsIP(host) { + ips = append(ips, host) + } else { + dns, err := dialers.Fastdialer.GetDNSData(host) + if err != nil || (len(dns.A) == 0 && len(dns.AAAA) == 0) { + resolveDNSFail.Add(1) + resolveProcessed.Add(1) + return + } + ips = append(ips, dns.A...) + ips = append(ips, dns.AAAA...) + } + ips = sliceutil.Dedupe(ips) + if len(ips) == 0 { + resolveDNSFail.Add(1) + resolveProcessed.Add(1) + return + } + + // store + // (small contention; acceptable at preflight scale) + _ = resolvedIPsByKey.Set(t.key, strings.Join(ips, ",")) + for _, ip := range ips { + _ = allIPsSet.Set(ip, struct{}{}) + } + resolveProcessed.Add(1) + }(t) + } + swg.Wait() + + // Prepare list of all IPs for scan. + allIPsMap := allIPsSet.GetAll() + allIPs := make([]string, 0, len(allIPsMap)) + for ip := range allIPsMap { + allIPs = append(allIPs, ip) + } + sort.Strings(allIPs) + + // we do fast TCP dial scanning against resolved IPs. + if !r.options.Silent { + r.Logger.Info().Msgf("Preflight resolution: total=%d resolvable=%d unresolvable=%d", totalTargets.Load(), int64(len(resolvedIPsByKey.GetAll())), resolveDNSFail.Load()) + } + + allowed := mapsutil.NewSyncLockMap[string, struct{}]() + + var dnsFail atomic.Int64 + var portFail atomic.Int64 + var kept atomic.Int64 + var processed atomic.Int64 + + perPortOpen := mapsutil.NewSyncLockMap[string, *atomic.Int64]() + + // Periodic progress logging + // Always enabled unless running in silent mode. + debugProgress := true + stopProgress := make(chan struct{}) + if debugProgress && !r.options.Silent { + start := time.Now() + go func() { + t := time.NewTicker(1 * time.Second) + defer t.Stop() + var lastProcessed int64 + for { + select { + case <-t.C: + p := processed.Load() + if p == lastProcessed { + continue + } + lastProcessed = p + total := totalTargets.Load() + k := kept.Load() + df := dnsFail.Load() + pf := portFail.Load() + dropped := p - k + r.Logger.Info().Msgf("Preflight progress: %d/%d processed (kept=%d dropped=%d dns_fail=%d port_fail=%d elapsed=%s)", + p, total, k, dropped, df, pf, time.Since(start).Truncate(time.Second)) + case <-stopProgress: + return + } + } + }() + } + + for _, t := range inputs { + swg.Add() + go func(t preflightTarget) { + defer swg.Done() + ok, openPort, reason := r.preflightOneResolved(t.key, t.raw, portsToScan, resolvedIPsByKey, dialers) + processed.Add(1) + if ok { + _ = allowed.Set(t.key, struct{}{}) + kept.Add(1) + if openPort != "" { + counter, _ := perPortOpen.Get(openPort) + if counter == nil { + counter = &atomic.Int64{} + _ = perPortOpen.Set(openPort, counter) + } + counter.Add(1) + } + return + } + switch reason { + case preflightReasonDNS: + dnsFail.Add(1) + case preflightReasonPorts: + portFail.Add(1) + } + }(t) + } + swg.Wait() + close(stopProgress) + + // Apply filtering wrapper + allowedAll := allowed.GetAll() + r.inputProvider = &filteringInputProvider{ + base: r.inputProvider, + allowed: allowed, + allowCnt: int64(len(allowedAll)), + execID: r.options.ExecutionId, + } + + // Summary + if !r.options.Silent { + dropped := totalTargets.Load() - kept.Load() + r.Logger.Info().Msgf("Preflight summary: total=%d kept=%d filtered_dns=%d filtered_ports=%d", + totalTargets.Load(), kept.Load(), dnsFail.Load(), portFail.Load()) + r.Logger.Info().Msgf("Preflight targets: dropped=%d left=%d", dropped, kept.Load()) + perPortOpenAll := perPortOpen.GetAll() + if len(perPortOpenAll) > 0 { + type kv struct { + port string + count int64 + } + kvs := make([]kv, 0, len(perPortOpenAll)) + for p, c := range perPortOpenAll { + if c == nil { + continue + } + kvs = append(kvs, kv{port: p, count: c.Load()}) + } + sort.Slice(kvs, func(i, j int) bool { + if kvs[i].count == kvs[j].count { + return kvs[i].port < kvs[j].port + } + return kvs[i].count > kvs[j].count + }) + parts := make([]string, 0, len(kvs)) + for _, item := range kvs { + parts = append(parts, fmt.Sprintf("%s=%d", item.port, item.count)) + } + r.Logger.Info().Msgf("Preflight open-port distribution: %s", strings.Join(parts, " ")) + } + } + + _ = gologger.DefaultLogger // ensure logger imported even when silent builds vary + return nil +} + +type preflightTarget struct { + key string + raw string +} + +type preflightReason int + +const ( + preflightReasonNone preflightReason = iota + preflightReasonDNS + preflightReasonPorts +) + +func (r *Runner) preflightOneResolved(key string, raw string, ports []string, resolved *mapsutil.SyncLockMap[string, string], dialers *protocolstate.Dialers) (ok bool, openPort string, reason preflightReason) { + resolvedIPsCSV, _ := resolved.Get(key) + if resolvedIPsCSV == "" { + return false, "", preflightReasonDNS + } + ips := strings.Split(resolvedIPsCSV, ",") + + // TCP dial scan using resolved IPs + host, schemePort, hasSchemePort, _ := hostForResolveAndScan(raw) + ordered := ports + if hasSchemePort && schemePort != "" { + ordered = append([]string{schemePort}, ports...) + ordered = sliceutil.Dedupe(ordered) + } + + timeout := preflightDialTimeout + if r.options.Timeout > 0 { + t := time.Duration(r.options.Timeout) * time.Second + if t > 0 && t < timeout { + timeout = t + } + } + // Use net.Dialer directly for strict timeout enforcement. + // We already resolved IPs, so we don't need fastdialer DNS behavior here. + // This avoids rare cases where proxy dialers / custom dial stacks may not respect ctx cancellation promptly. + d := &net.Dialer{Timeout: timeout} + + // Per-host parallelism: probe up to 3 ports concurrently, stop on first success. + ctx, cancelAll := context.WithCancel(context.Background()) + defer cancelAll() + + type hit struct { + port string + } + resultCh := make(chan hit, 1) + portsCh := make(chan string) + + worker := func() { + for p := range portsCh { + // Stop quickly if someone already found an open port. + select { + case <-ctx.Done(): + return + default: + } + + for _, ip := range ips { + _ = host // keep for debugging parity + dctx, cancel := context.WithTimeout(ctx, timeout) + conn, err := d.DialContext(dctx, "tcp", net.JoinHostPort(ip, p)) + cancel() + if err == nil { + _ = conn.Close() + // Best-effort: first hit wins. + select { + case resultCh <- hit{port: p}: + cancelAll() + default: + } + return + } + // If ctx cancelled (other worker won), stop early. + select { + case <-ctx.Done(): + return + default: + } + } + } + } + + var wg sync.WaitGroup + workers := 3 + if len(ordered) < workers { + workers = len(ordered) + } + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + worker() + }() + } + + go func() { + defer close(portsCh) + for _, p := range ordered { + select { + case <-ctx.Done(): + return + case portsCh <- p: + } + } + }() + + // Wait for either a hit or all workers to finish. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case h := <-resultCh: + <-done + return true, h.port, preflightReasonNone + case <-done: + return false, "", preflightReasonPorts + } +} + +func portsPopularityFromTemplates(tpls []*templates.Template) map[string]int { + out := map[string]int{} + for _, tpl := range tpls { + // HTTP templates imply 80/443 for preflight. + if len(tpl.RequestsHTTP) > 0 || len(tpl.RequestsWithHTTP) > 0 || len(tpl.RequestsHeadless) > 0 { + out["80"]++ + out["443"]++ + } + // Network templates declare ports directly. + for _, req := range tpl.RequestsNetwork { + for _, p := range splitPorts(req.Port) { + out[p]++ + } + } + for _, req := range tpl.RequestsWithTCP { + for _, p := range splitPorts(req.Port) { + out[p]++ + } + } + // Javascript templates may include args.Port (comma-separated). + for _, req := range tpl.RequestsJavascript { + for _, p := range extractPortsFromJSArgs(req.Args) { + out[p]++ + } + } + } + return out +} + +func keysOfPopularity(m map[string]int) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +// likelyOpenPortRank returns a heuristic "most likely to be open" ranking for common TCP ports. +// This is intentionally static (fast + deterministic) and loosely aligns with what scanners like nmap +// tend to prioritize (common services first). +func likelyOpenPortRank() map[string]int { + // Lower index = higher priority. + // Keep this list small-ish but useful; anything not in here falls back to template popularity + numeric. + common := []string{ + "80", "443", + "22", "21", "23", + "25", "110", "143", "465", "587", "993", "995", + "53", + "3389", + "445", "139", + "135", + "3306", "5432", "1433", "1521", + "6379", "27017", + "9200", "9300", + "8080", "8443", "8000", "8008", "8081", "8888", + "9201", + "161", "162", + "389", "636", + "5900", + "11211", + "69", "123", + "1194", + "500", "4500", + } + rank := make(map[string]int, len(common)) + for i, p := range common { + // do not overwrite if duplicates + if _, ok := rank[p]; !ok { + rank[p] = i + } + } + return rank +} + +func extractPortsFromJSArgs(args map[string]interface{}) []string { + if args == nil { + return nil + } + for k, v := range args { + if strings.EqualFold(k, "Port") { + s := fmt.Sprint(v) + return splitPorts(s) + } + } + return nil +} + +func splitPorts(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} + +func extractPortsFromInput(dst map[string]struct{}, input string) { + if dst == nil { + return + } + in := strings.TrimSpace(input) + if in == "" { + return + } + low := strings.ToLower(in) + if strings.HasPrefix(low, "http://") { + dst["80"] = struct{}{} + } + if strings.HasPrefix(low, "https://") { + dst["443"] = struct{}{} + } + // URL parsing (best effort) + if u, err := urlutil.Parse(in); err == nil && u != nil { + if p := u.Port(); p != "" { + dst[p] = struct{}{} + } else if u.Scheme == "http" { + dst["80"] = struct{}{} + } else if u.Scheme == "https" { + dst["443"] = struct{}{} + } + return + } + // host:port + _, p, err := net.SplitHostPort(in) + if err == nil && p != "" { + dst[p] = struct{}{} + } +} + +func hostForResolveAndScan(raw string) (host string, schemeDefaultPort string, hasSchemePort bool, err error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", false, errkit.New("empty input") + } + // If it looks like URL, parse and extract hostname. + if stringsutil.ContainsAny(raw, "://") { + u, perr := urlutil.ParseAbsoluteURL(raw, false) + if perr == nil && u != nil { + host = u.Hostname() + if u.Port() != "" { + return host, "", false, nil + } + switch strings.ToLower(u.Scheme) { + case "http": + return host, "80", true, nil + case "https": + return host, "443", true, nil + } + return host, "", false, nil + } + } + // Try host:port form + h, _, serr := net.SplitHostPort(raw) + if serr == nil && h != "" { + return h, "", false, nil + } + // Bare host/ip + return raw, "", false, nil +} + +func keysOf(m map[string]struct{}) []string { + out := make([]string, 0, len(m)) + for k := range m { + out = append(out, k) + } + return out +} + +func filterValidPorts(ports []string) []string { + out := make([]string, 0, len(ports)) + for _, p := range ports { + if p == "" { + continue + } + // allow numeric only + if !isNumeric(p) { + continue + } + i, err := strconv.Atoi(p) + if err != nil || i < 1 || i > 65535 { + continue + } + out = append(out, p) + } + return sliceutil.Dedupe(out) +} + +func isNumeric(s string) bool { + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 648f052e51..a1965a40f3 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -54,6 +54,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/hosterrorscache" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/interactsh" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolinit" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/uncover" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/utils/excludematchers" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/headless/engine" @@ -391,7 +392,12 @@ func New(options *types.Options) (*Runner, error) { if options.RateLimit > 0 && options.RateLimitDuration == 0 { options.RateLimitDuration = time.Second } - runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration) + // If per-host rate limiting is enabled, make global rate limiter unlimited + if options.PerHostRateLimit { + runner.rateLimiter = utils.GetRateLimiter(context.Background(), 0, 0) + } else { + runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration) + } // Initialization successful, disable cleanup on error cleanupOnError = false @@ -678,6 +684,14 @@ func (r *Runner) RunEnumeration() error { _ = r.inputProvider.SetWithExclusions(r.options.ExecutionId, host) } } + + // Preflight: resolve hosts + portscan for ports required by loaded templates, then filter inputs. + // This reduces time spent on non-resolvable targets or targets with no relevant open ports. + if r.options.PreflightPortScan { + if err := r.preflightResolveAndPortScan(store); err != nil { + return errors.Wrap(err, "preflight resolve/portscan failed") + } + } // display execution info like version , templates used etc r.displayExecutionInfo(store) @@ -700,13 +714,17 @@ func (r *Runner) RunEnumeration() error { executorOpts.InputHelper.InputsHTTP = inputHelpers } + // Set input count in dialers for sharding calculation + inputCount := int(r.inputProvider.Count()) + protocolstate.SetInputCount(r.options.ExecutionId, inputCount) + // initialize stats worker ( this is no-op unless nuclei is built with stats build tag) // during execution a directory with 2 files will be created in the current directory // config.json - containing below info // events.jsonl - containing all start and end times of all templates events.InitWithConfig(&events.ScanConfig{ Name: "nuclei-stats", // make this configurable - TargetCount: int(r.inputProvider.Count()), + TargetCount: inputCount, TemplatesCount: len(store.Templates()) + len(store.Workflows()), TemplateConcurrency: r.options.TemplateThreads, PayloadConcurrency: r.options.PayloadConcurrency, @@ -748,6 +766,52 @@ func (r *Runner) RunEnumeration() error { r.progress.Stop() timeTaken := time.Since(now) + + // Print per-host pool stats if available + if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil && dialers.PerHostHTTPPool != nil { + if pool, ok := dialers.PerHostHTTPPool.(interface{ PrintStats() }); ok { + pool.PrintStats() + } + } + // Print per-host rate limit pool stats if available + if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil && dialers.PerHostRateLimitPool != nil { + if pool, ok := dialers.PerHostRateLimitPool.(interface{ PrintStats() }); ok { + pool.PrintStats() + } + if pool, ok := dialers.PerHostRateLimitPool.(interface{ PrintPerHostPPSStats() }); ok { + pool.PrintPerHostPPSStats() + } + } + // Always print connection reuse stats (tracker is initialized early for all HTTP requests) + if dialers := protocolstate.GetDialersWithId(r.options.ExecutionId); dialers != nil { + // Ensure tracker exists (it should already be initialized, but create if needed) + if dialers.ConnectionReuseTracker == nil { + _ = httpclientpool.GetConnectionReuseTracker(r.options) + } + if dialers.ConnectionReuseTracker != nil { + if tracker, ok := dialers.ConnectionReuseTracker.(interface{ PrintStats() }); ok { + tracker.PrintStats() + } + if tracker, ok := dialers.ConnectionReuseTracker.(interface{ PrintPerHostStats() }); ok { + tracker.PrintPerHostStats() + } + } + + // Print sharded pool stats if sharding is enabled + if r.options.HTTPClientShards && dialers.ShardedHTTPPool != nil { + if pool, ok := dialers.ShardedHTTPPool.(interface{ PrintStats() }); ok { + pool.PrintStats() + } + } + + // Print HTTP-to-HTTPS port tracker stats + if dialers.HTTPToHTTPSPortTracker != nil { + if tracker, ok := dialers.HTTPToHTTPSPortTracker.(interface{ PrintStats() }); ok { + tracker.PrintStats() + } + } + } + // todo: error propagation without canonical straight error check is required by cloud? // use safe dereferencing to avoid potential panics in case of previous unchecked errors if v := ptrutil.Safe(results); !v.Load() { diff --git a/lib/tests/sdk_test.go b/lib/tests/sdk_test.go index 75309bbf8e..e1ce6e2260 100644 --- a/lib/tests/sdk_test.go +++ b/lib/tests/sdk_test.go @@ -19,6 +19,9 @@ var knownLeaks = []goleak.Option{ // net/http transport maintains idle connections which are closed with cooldown // hence they don't count as leaks goleak.IgnoreAnyFunction("net/http.(*http2ClientConn).readLoop"), + // expirable LRU cache creates a background goroutine for TTL expiration that persists + // see: https://github.com/hashicorp/golang-lru/blob/770151e9c8cdfae1797826b7b74c33d6f103fbd8/expirable/expirable_lru.go#L79 + goleak.IgnoreAnyContainingPkg("github.com/hashicorp/golang-lru/v2/expirable"), } func TestSimpleNuclei(t *testing.T) { diff --git a/pkg/protocols/common/protocolstate/dialers.go b/pkg/protocols/common/protocolstate/dialers.go index 91bdbae514..6269d80134 100644 --- a/pkg/protocols/common/protocolstate/dialers.go +++ b/pkg/protocols/common/protocolstate/dialers.go @@ -15,6 +15,12 @@ type Dialers struct { RawHTTPClient *rawhttp.Client DefaultHTTPClient *retryablehttp.Client HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] + PerHostHTTPPool any + PerHostRateLimitPool any + ConnectionReuseTracker any + HTTPToHTTPSPortTracker any // *httpclientpool.HTTPToHTTPSPortTracker + ShardedHTTPPool any // *httpclientpool.ShardedClientPool + InputCount int // Total number of input targets for sharding calculation NetworkPolicy *networkpolicy.NetworkPolicy LocalFileAccessAllowed bool RestrictLocalNetworkAccess bool diff --git a/pkg/protocols/common/protocolstate/memguardian_test.go b/pkg/protocols/common/protocolstate/memguardian_test.go index 7306b81e23..72a41cea4b 100644 --- a/pkg/protocols/common/protocolstate/memguardian_test.go +++ b/pkg/protocols/common/protocolstate/memguardian_test.go @@ -18,6 +18,9 @@ func TestMemGuardianGoroutineLeak(t *testing.T) { goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"), goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"), goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"), + // expirable LRU cache creates a background goroutine for TTL expiration that persists + // see: https://github.com/hashicorp/golang-lru/blob/770151e9c8cdfae1797826b7b74c33d6f103fbd8/expirable/expirable_lru.go#L79 + goleak.IgnoreAnyContainingPkg("github.com/hashicorp/golang-lru/v2/expirable"), ) // Initialize memguardian if not already initialized diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index 61232df1a5..21174df6ef 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -215,9 +215,25 @@ func initDialers(options *types.Options) error { SetLfaAllowed(options) + // Set input count for sharding calculation (will be updated later when input provider is ready) + dialersInstance.InputCount = 0 + return nil } +// SetInputCount sets the input count for sharding calculation +func SetInputCount(executionId string, count int) { + dialers := GetDialersWithId(executionId) + if dialers == nil { + return + } + + dialers.Lock() + defer dialers.Unlock() + + dialers.InputCount = count +} + // isIpAssociatedWithInterface checks if the given IP is associated with the given interface. func isIpAssociatedWithInterface(sourceIP, interfaceName string) (bool, error) { addrs, err := interfaceAddresses(interfaceName) diff --git a/pkg/protocols/http/build_request.go b/pkg/protocols/http/build_request.go index c8f4d447fe..ac91487b8d 100644 --- a/pkg/protocols/http/build_request.go +++ b/pkg/protocols/http/build_request.go @@ -450,9 +450,27 @@ func (r *requestGenerator) fillRequest(req *retryablehttp.Request, values map[st } } - // In case of multiple threads the underlying connection should remain open to allow reuse - if r.request.Threads <= 0 && req.Header.Get("Connection") == "" && r.options.Options.ScanStrategy != scanstrategy.HostSpray.String() { + // Respect connection reuse policy from smart analyzer + // If policy is ReuseUnsafe, ensure connection is closed + switch r.request.connectionReusePolicy { + case ReuseUnsafe: + // Explicitly set Connection: close header if not already present + if req.Header.Get("Connection") == "" { + req.Header.Set("Connection", "close") + } req.Close = true + case ReuseSafe: + // For safe requests, ensure connection can be reused + // Don't set req.Close = true, allow connection pooling + // Remove any existing "Connection: close" header if present + if strings.EqualFold(req.Header.Get("Connection"), "close") { + req.Header.Del("Connection") + } + default: + // Legacy behavior: In case of multiple threads the underlying connection should remain open to allow reuse + if r.request.Threads <= 0 && req.Header.Get("Connection") == "" && r.options.Options.ScanStrategy != scanstrategy.HostSpray.String() { + req.Close = true + } } // Check if the user requested a request body diff --git a/pkg/protocols/http/http.go b/pkg/protocols/http/http.go index d461db5920..e6a1c960c8 100644 --- a/pkg/protocols/http/http.go +++ b/pkg/protocols/http/http.go @@ -158,6 +158,10 @@ type Request struct { // - "AWS" Signature SignatureTypeHolder `yaml:"signature,omitempty" json:"signature,omitempty" jsonschema:"title=signature is the http request signature method,description=Signature is the HTTP Request signature Method,enum=AWS"` + // connectionReusePolicy stores the analyzed connection reuse policy + // This is set during Compile() based on template analysis + connectionReusePolicy ConnectionReusePolicy `yaml:"-" json:"-"` + // description: | // SkipSecretFile skips the authentication or authorization configured in the secret file. SkipSecretFile bool `yaml:"skip-secret-file,omitempty" json:"skip-secret-file,omitempty" jsonschema:"title=bypass secret file,description=Skips the authentication or authorization configured in the secret file"` @@ -304,13 +308,33 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { return errors.Wrap(err, "validation error") } + // Analyze connection reuse policy to determine if we can safely reuse connections + reusePolicy := request.AnalyzeConnectionReuse() + request.connectionReusePolicy = reusePolicy + + // Determine if keep-alive should be disabled + // If policy is ReuseUnsafe, we must disable keep-alive to preserve existing behavior + // Otherwise, use the standard logic (which may enable keep-alive) + var disableKeepAlive bool + switch reusePolicy { + case ReuseUnsafe: + // Preserve existing behavior: disable keep-alive for unsafe requests + disableKeepAlive = true + case ReuseSafe: + // Enable keep-alive for safe requests to allow connection pooling/sharding + disableKeepAlive = false + default: + // If ReuseUnknown, use the standard logic + disableKeepAlive = httputil.ShouldDisableKeepAlive(options.Options) + } + connectionConfiguration := &httpclientpool.Configuration{ Threads: request.Threads, MaxRedirects: request.MaxRedirects, NoTimeout: false, DisableCookie: request.DisableCookie, Connection: &httpclientpool.ConnectionConfiguration{ - DisableKeepAlive: httputil.ShouldDisableKeepAlive(options.Options), + DisableKeepAlive: disableKeepAlive, }, RedirectFlow: httpclientpool.DontFollowRedirect, } @@ -345,6 +369,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { } request.connConfiguration = connectionConfiguration + // At compile time, no hostname is available yet, so pass empty string client, err := httpclientpool.Get(options.Options, connectionConfiguration) if err != nil { return errors.Wrap(err, "could not get dns client") @@ -536,6 +561,18 @@ const ( SetThreadToCountZero = "set-thread-count-to-zero" ) +// ConnectionReusePolicy determines whether a request can safely reuse connections +type ConnectionReusePolicy int + +const ( + // ReuseUnknown indicates the policy hasn't been analyzed yet + ReuseUnknown ConnectionReusePolicy = iota + // ReuseSafe indicates the request can safely reuse connections (enable pooling/sharding) + ReuseSafe + // ReuseUnsafe indicates the request must close connections (preserve existing behavior) + ReuseUnsafe +) + func init() { stats.NewEntry(SetThreadToCountZero, "Setting thread count to 0 for %d templates, dynamic extractors are not supported with payloads yet") } @@ -549,3 +586,88 @@ func (r *Request) UpdateOptions(opts *protocols.ExecutorOptions) { func (request *Request) HasFuzzing() bool { return len(request.Fuzzing) > 0 } + +// AnalyzeConnectionReuse determines if a request can safely reuse connections. +// Returns ReuseUnsafe if connection closure is required, ReuseSafe otherwise. +// This analysis ensures backward compatibility by preserving connection-close behavior +// when necessary while enabling pooling/sharding for other requests. +func (r *Request) AnalyzeConnectionReuse() ConnectionReusePolicy { + // Priority 1: Check for explicit "Connection: close" header in raw requests + for _, raw := range r.Raw { + if hasConnectionCloseHeader(raw) { + return ReuseUnsafe + } + } + + // Priority 2: Check for "Connection: close" in regular headers + for key, value := range r.Headers { + if strings.EqualFold(key, "Connection") && strings.Contains(strings.ToLower(value), "close") { + return ReuseUnsafe + } + } + + // Priority 3: Check for time-based analyzers that require connection closure + // Time-based attacks need fresh connections to measure timing accurately + if r.Analyzer != nil && r.Analyzer.Name == "time_delay" { + return ReuseUnsafe + } + + // Priority 4: Check for raw HTTP (unsafe) with explicit connection control + if r.Unsafe { + // Analyze raw request for connection directives + for _, raw := range r.Raw { + if hasConnectionCloseHeader(raw) { + return ReuseUnsafe + } + } + } + + // Priority 5: Check for specific request patterns that require closure + // This can be extended based on template analysis + if requiresConnectionClosure(r) { + return ReuseUnsafe + } + + // Default: Safe to reuse - enable connection pooling/sharding + return ReuseSafe +} + +// hasConnectionCloseHeader checks if a raw HTTP request contains "Connection: close" +// Case-insensitive check for both "Connection:" and "close" +func hasConnectionCloseHeader(raw string) bool { + rawLower := strings.ToLower(raw) + // Check for "connection:" header + if !strings.Contains(rawLower, "connection:") { + return false + } + // Check for "close" value after "connection:" + // Handle various formats: "Connection: close", "Connection:Close", "Connection: close\r\n", etc. + connIdx := strings.Index(rawLower, "connection:") + if connIdx == -1 { + return false + } + // Extract the value after "connection:" + valueStart := connIdx + len("connection:") + // Skip whitespace + for valueStart < len(rawLower) && (rawLower[valueStart] == ' ' || rawLower[valueStart] == '\t') { + valueStart++ + } + // Check if the value contains "close" + value := rawLower[valueStart:] + // Find end of line or end of string + if newlineIdx := strings.IndexAny(value, "\r\n"); newlineIdx != -1 { + value = value[:newlineIdx] + } + return strings.Contains(value, "close") +} + +// requiresConnectionClosure checks for specific patterns that require connection closure +// This can be extended based on template analysis +func requiresConnectionClosure(r *Request) bool { + // Add specific patterns that require connection closure + // Example: If request has specific headers that indicate stateful protocol + // Example: If request uses specific authentication that requires fresh connections + + // For now, no special requirements beyond what's already checked + return false +} diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 0520985bef..80667166f0 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -22,7 +22,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" "github.com/projectdiscovery/nuclei/v3/pkg/types" - "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy" + "github.com/projectdiscovery/ratelimit" "github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/retryablehttp-go" urlutil "github.com/projectdiscovery/utils/url" @@ -38,6 +38,12 @@ func Init(options *types.Options) error { forceMaxRedirects = options.MaxRedirects } + // Initialize connection reuse tracker early to ensure it's always available for tracking + _ = GetConnectionReuseTracker(options) + + // Initialize HTTP-to-HTTPS port tracker early to ensure it's always available + _ = GetHTTPToHTTPSPortTracker(options) + return nil } @@ -180,6 +186,18 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C return wrappedGet(options, configuration) } +func isMultiThreadWithJar(configuration *Configuration) bool { + return configuration.Threads > 0 && configuration.Connection != nil && configuration.Connection.HasCookieJar() +} + +func hashWithCookieJar(hash string, configuration *Configuration) string { + if isMultiThreadWithJar(configuration) { + jar := configuration.Connection.GetCookieJar() + return hash + fmt.Sprintf("cookieptr%p", jar) + } + return hash +} + // wrappedGet wraps a get operation without normal client check func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { var err error @@ -189,7 +207,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) } - hash := configuration.Hash() + hash := hashWithCookieJar(configuration.Hash(), configuration) if client, ok := dialers.HTTPClientPool.Get(hash); ok { return client, nil } @@ -204,12 +222,13 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl // because this won't work on slow hosts retryableHttpOptions.NoAdjustTimeout = true - if configuration.Threads > 0 || options.ScanStrategy == scanstrategy.HostSpray.String() { - // Single host + // with threading always allow connection reuse + if configuration.Threads > 0 { retryableHttpOptions = retryablehttp.DefaultOptionsSingle disableKeepAlives = false maxIdleConnsPerHost = 500 maxConnsPerHost = 500 + maxIdleConns = 500 } retryableHttpOptions.RetryWaitMax = 10 * time.Second @@ -358,15 +377,213 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl } client.CheckRetry = retryablehttp.HostSprayRetryPolicy() - // Only add to client pool if we don't have a cookie jar in place. - if jar == nil { + if jar == nil || isMultiThreadWithJar(configuration) { if err := dialers.HTTPClientPool.Set(hash, client); err != nil { return nil, err } } + return client, nil } +// GetForTarget creates or gets a client for a specific target +// Supports three modes: +// 1. Per-host pooling (--per-host-client-pool flag) +// 2. Sharded pooling (--http-client-shards flag) +// 3. Standard pooling (default) +// Respects connection reuse policy: if DisableKeepAlive is true, uses standard pool +func GetForTarget(options *types.Options, configuration *Configuration, targetURL string) (*retryablehttp.Client, error) { + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) + } + + // Check connection reuse policy: if keep-alive is disabled, skip pooling/sharding + // This preserves existing behavior for templates requiring connection closure + if configuration.Connection != nil && configuration.Connection.DisableKeepAlive { + // Use standard client pool (no connection reuse) + return Get(options, configuration) + } + + // Priority 1: Per-host pooling (if flag is set) + if options.PerHostClientPool { + dialers.Lock() + if dialers.PerHostHTTPPool == nil { + dialers.PerHostHTTPPool = NewPerHostClientPool(1024, 5*time.Minute, 30*time.Minute) + } + dialers.Unlock() + + pool, ok := dialers.PerHostHTTPPool.(*PerHostClientPool) + if ok && pool != nil { + return pool.GetOrCreate(targetURL, func() (*retryablehttp.Client, error) { + cfg := configuration.Clone() + if cfg.Connection == nil { + cfg.Connection = &ConnectionConfiguration{} + } + cfg.Connection.DisableKeepAlive = false + + // Override Threads to force connection pool settings + originalThreads := cfg.Threads + cfg.Threads = 1 + client, err := wrappedGet(options, cfg) + cfg.Threads = originalThreads + + return client, err + }) + } + } + + // Priority 2: Sharded pooling (if flag is set) + if options.HTTPClientShards { + dialers.Lock() + if dialers.ShardedHTTPPool == nil { + // Calculate optimal shard count based on input size + numShards := 0 // 0 triggers automatic calculation + inputSize := dialers.InputCount + + pool, err := NewShardedClientPool(numShards, options, configuration, inputSize) + if err != nil { + dialers.Unlock() + return nil, fmt.Errorf("failed to create sharded client pool: %w", err) + } + dialers.ShardedHTTPPool = pool + } + dialers.Unlock() + + pool, ok := dialers.ShardedHTTPPool.(*ShardedClientPool) + if ok && pool != nil { + client, _ := pool.GetClientForHost(targetURL) + return client, nil + } + } + + // Priority 3: Standard client pool (default) + return Get(options, configuration) +} + +// GetPerHostRateLimiter gets or creates a rate limiter for a specific host +// Returns nil if per-host rate limiting is not enabled +func GetPerHostRateLimiter(options *types.Options, hostname string) (*ratelimit.Limiter, error) { + if !options.PerHostRateLimit { + return nil, nil + } + + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) + } + + dialers.Lock() + if dialers.PerHostRateLimitPool == nil { + // Keep entries for the entire scan duration - no TTL-based eviction during scan + // maxIdleTime: 24 hours (entries only evicted after 24 hours of inactivity) + // maxLifetime: 24 hours (maximum lifetime of entries) + // This ensures all hosts are tracked throughout the entire scan, even for very long scans + dialers.PerHostRateLimitPool = NewPerHostRateLimitPool(1024, 24*time.Hour, 24*time.Hour, options) + } + dialers.Unlock() + + pool, ok := dialers.PerHostRateLimitPool.(*PerHostRateLimitPool) + if !ok || pool == nil { + return nil, nil + } + + return pool.GetOrCreate(hostname) +} + +// RecordPerHostRateLimitRequest records a request for pps stats calculation +func RecordPerHostRateLimitRequest(options *types.Options, hostname string) { + if !options.PerHostRateLimit || hostname == "" { + return + } + + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return + } + + pool, ok := dialers.PerHostRateLimitPool.(*PerHostRateLimitPool) + if !ok || pool == nil { + return + } + + pool.RecordRequest(hostname) +} + +// GetConnectionReuseTracker gets or creates the connection reuse tracker +func GetConnectionReuseTracker(options *types.Options) *ConnectionReuseTracker { + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil + } + + dialers.Lock() + if dialers.ConnectionReuseTracker == nil { + // Keep entries for the entire scan duration - no TTL-based eviction during scan + // maxIdleTime: 24 hours (entries only evicted after 24 hours of inactivity) + // maxLifetime: 24 hours (maximum lifetime of entries) + // This ensures all hosts are tracked throughout the entire scan, even for very long scans + dialers.ConnectionReuseTracker = NewConnectionReuseTracker(1024, 24*time.Hour, 24*time.Hour) + } + dialers.Unlock() + + tracker, ok := dialers.ConnectionReuseTracker.(*ConnectionReuseTracker) + if !ok || tracker == nil { + return nil + } + + return tracker +} + +// GetHTTPToHTTPSPortTracker gets or creates the HTTP-to-HTTPS port tracker +func GetHTTPToHTTPSPortTracker(options *types.Options) *HTTPToHTTPSPortTracker { + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil + } + + dialers.Lock() + if dialers.HTTPToHTTPSPortTracker == nil { + dialers.HTTPToHTTPSPortTracker = NewHTTPToHTTPSPortTracker() + } + dialers.Unlock() + + tracker, ok := dialers.HTTPToHTTPSPortTracker.(*HTTPToHTTPSPortTracker) + if !ok || tracker == nil { + return nil + } + + return tracker +} + +// RecordConnectionReuse records a connection reuse event +func RecordConnectionReuse(options *types.Options, hostname string, reused bool) { + if hostname == "" { + return + } + + tracker := GetConnectionReuseTracker(options) + if tracker == nil { + return + } + + tracker.RecordConnection(hostname, reused) +} + +// RecordHTTPToHTTPSPortMismatch records that a host:port requires HTTPS +func RecordHTTPToHTTPSPortMismatch(options *types.Options, hostname string) { + if hostname == "" { + return + } + + tracker := GetHTTPToHTTPSPortTracker(options) + if tracker == nil { + return + } + + tracker.RecordHTTPToHTTPSPort(hostname) +} + type RedirectFlow uint8 const ( diff --git a/pkg/protocols/http/httpclientpool/connection_reuse_tracker.go b/pkg/protocols/http/httpclientpool/connection_reuse_tracker.go new file mode 100644 index 0000000000..77e30f3809 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/connection_reuse_tracker.go @@ -0,0 +1,477 @@ +package httpclientpool + +import ( + "fmt" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/projectdiscovery/gologger" + urlutil "github.com/projectdiscovery/utils/url" +) + +type ConnectionReuseTracker struct { + cache *expirable.LRU[string, *connectionReuseEntry] + capacity int + mu sync.Mutex + + totalConnections atomic.Uint64 + totalReused atomic.Uint64 + totalNewConnections atomic.Uint64 + + // Protocol-specific counters + totalHTTPConnections atomic.Uint64 + totalHTTPSConnections atomic.Uint64 + totalHTTPReused atomic.Uint64 + totalHTTPSReused atomic.Uint64 + totalHTTPNewConnections atomic.Uint64 + totalHTTPSNewConnections atomic.Uint64 +} + +type connectionReuseEntry struct { + host string + createdAt time.Time + totalConnections atomic.Uint64 + totalReused atomic.Uint64 + totalNewConnections atomic.Uint64 + accessCount atomic.Uint64 + + // Protocol-specific counters per host + totalHTTPConnections atomic.Uint64 + totalHTTPSConnections atomic.Uint64 + totalHTTPReused atomic.Uint64 + totalHTTPSReused atomic.Uint64 + totalHTTPNewConnections atomic.Uint64 + totalHTTPSNewConnections atomic.Uint64 +} + +func NewConnectionReuseTracker(size int, maxIdleTime, maxLifetime time.Duration) *ConnectionReuseTracker { + if size <= 0 { + size = 1024 + } + // For global scan tracking, use very long TTL to keep entries for entire scan duration + // Default to 24 hours if not specified, which should cover even very long scans + if maxIdleTime == 0 { + maxIdleTime = 24 * time.Hour + } + if maxLifetime == 0 { + maxLifetime = 24 * time.Hour + } + + ttl := maxIdleTime + if maxLifetime < maxIdleTime { + ttl = maxLifetime + } + + tracker := &ConnectionReuseTracker{ + cache: expirable.NewLRU[string, *connectionReuseEntry]( + size, + func(key string, value *connectionReuseEntry) { + gologger.Debug().Msgf("[connection-reuse-tracker] Evicted entry for %s (age: %v, connections: %d, reused: %d)", + key, time.Since(value.createdAt), value.totalConnections.Load(), value.totalReused.Load()) + }, + ttl, + ), + capacity: size, + } + + return tracker +} + +// RecordConnection records a connection event (new or reused) for a host +func (t *ConnectionReuseTracker) RecordConnection(hostname string, reused bool) { + if hostname == "" { + return + } + + normalizedHost := normalizeHostForConnectionReuse(hostname) + if normalizedHost == "" { + return + } + + // Detect protocol (HTTP vs HTTPS) from the original hostname/URL + isHTTPS := isHTTPSConnection(hostname) + + t.totalConnections.Add(1) + if reused { + t.totalReused.Add(1) + } else { + t.totalNewConnections.Add(1) + } + + // Update protocol-specific global counters + if isHTTPS { + t.totalHTTPSConnections.Add(1) + if reused { + t.totalHTTPSReused.Add(1) + } else { + t.totalHTTPSNewConnections.Add(1) + } + } else { + t.totalHTTPConnections.Add(1) + if reused { + t.totalHTTPReused.Add(1) + } else { + t.totalHTTPNewConnections.Add(1) + } + } + + entry := t.getOrCreateEntry(normalizedHost) + if entry == nil { + return + } + + entry.totalConnections.Add(1) + entry.accessCount.Add(1) + if reused { + entry.totalReused.Add(1) + } else { + entry.totalNewConnections.Add(1) + } + + // Update protocol-specific per-host counters + if isHTTPS { + entry.totalHTTPSConnections.Add(1) + if reused { + entry.totalHTTPSReused.Add(1) + } else { + entry.totalHTTPSNewConnections.Add(1) + } + } else { + entry.totalHTTPConnections.Add(1) + if reused { + entry.totalHTTPReused.Add(1) + } else { + entry.totalHTTPNewConnections.Add(1) + } + } +} + +// isHTTPSConnection detects if a connection is HTTPS based on the URL/hostname +func isHTTPSConnection(hostname string) bool { + if hostname == "" { + return false + } + + // Check for https:// scheme prefix + if strings.HasPrefix(strings.ToLower(hostname), "https://") { + return true + } + + // Check if port is 443 (HTTPS default port) + if strings.HasSuffix(hostname, ":443") { + return true + } + + // Try to parse as URL to get scheme + parsed, err := urlutil.Parse(hostname) + if err == nil && parsed.Scheme == "https" { + return true + } + + // Default to HTTP if we can't determine + return false +} + +func (t *ConnectionReuseTracker) getOrCreateEntry(normalizedHost string) *connectionReuseEntry { + if entry, ok := t.cache.Get(normalizedHost); ok { + return entry + } + + t.mu.Lock() + defer t.mu.Unlock() + + // Double-check after acquiring lock + if entry, ok := t.cache.Peek(normalizedHost); ok { + return entry + } + + entry := &connectionReuseEntry{ + host: normalizedHost, + createdAt: time.Now(), + } + entry.totalConnections.Store(0) + entry.totalReused.Store(0) + entry.totalNewConnections.Store(0) + entry.accessCount.Store(0) + entry.totalHTTPConnections.Store(0) + entry.totalHTTPSConnections.Store(0) + entry.totalHTTPReused.Store(0) + entry.totalHTTPSReused.Store(0) + entry.totalHTTPNewConnections.Store(0) + entry.totalHTTPSNewConnections.Store(0) + + evicted := t.cache.Add(normalizedHost, entry) + if evicted { + _ = evicted + // Entry was evicted, but we still return the new entry + } + + return entry +} + +func (t *ConnectionReuseTracker) Size() int { + return t.cache.Len() +} + +func (t *ConnectionReuseTracker) Stats() ConnectionReuseStats { + return ConnectionReuseStats{ + TotalConnections: t.totalConnections.Load(), + TotalReused: t.totalReused.Load(), + TotalNewConnections: t.totalNewConnections.Load(), + Hosts: t.Size(), + TotalHTTPConnections: t.totalHTTPConnections.Load(), + TotalHTTPSConnections: t.totalHTTPSConnections.Load(), + TotalHTTPReused: t.totalHTTPReused.Load(), + TotalHTTPSReused: t.totalHTTPSReused.Load(), + TotalHTTPNewConnections: t.totalHTTPNewConnections.Load(), + TotalHTTPSNewConnections: t.totalHTTPSNewConnections.Load(), + } +} + +type ConnectionReuseStats struct { + TotalConnections uint64 + TotalReused uint64 + TotalNewConnections uint64 + Hosts int + TotalHTTPConnections uint64 + TotalHTTPSConnections uint64 + TotalHTTPReused uint64 + TotalHTTPSReused uint64 + TotalHTTPNewConnections uint64 + TotalHTTPSNewConnections uint64 +} + +func (t *ConnectionReuseTracker) PrintStats() { + stats := t.Stats() + reuseRate := float64(0) + if stats.TotalConnections > 0 { + reuseRate = float64(stats.TotalReused) * 100 / float64(stats.TotalConnections) + } + + httpReuseRate := float64(0) + if stats.TotalHTTPConnections > 0 { + httpReuseRate = float64(stats.TotalHTTPReused) * 100 / float64(stats.TotalHTTPConnections) + } + + httpsReuseRate := float64(0) + if stats.TotalHTTPSConnections > 0 { + httpsReuseRate = float64(stats.TotalHTTPSReused) * 100 / float64(stats.TotalHTTPSConnections) + } + + gologger.Info().Msgf("[connection-reuse-tracker] Connection reuse stats: Total=%d Reused=%d New=%d ReuseRate=%.1f%% Hosts=%d", + stats.TotalConnections, stats.TotalReused, stats.TotalNewConnections, reuseRate, stats.Hosts) + gologger.Info().Msgf("[connection-reuse-tracker] Protocol breakdown: HTTP=%d (Reused=%d, ReuseRate=%.1f%%) HTTPS=%d (Reused=%d, ReuseRate=%.1f%%)", + stats.TotalHTTPConnections, stats.TotalHTTPReused, httpReuseRate, + stats.TotalHTTPSConnections, stats.TotalHTTPSReused, httpsReuseRate) +} + +func (t *ConnectionReuseTracker) PrintPerHostStats() { + if t.Size() == 0 { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + hostStats := []struct { + host string + totalConnections uint64 + totalReused uint64 + totalNewConnections uint64 + reuseRate float64 + age time.Duration + totalHTTPConnections uint64 + totalHTTPSConnections uint64 + totalHTTPReused uint64 + totalHTTPSReused uint64 + httpReuseRate float64 + httpsReuseRate float64 + }{} + + for _, key := range t.cache.Keys() { + entry, ok := t.cache.Peek(key) + if !ok || entry == nil { + continue + } + + totalConn := entry.totalConnections.Load() + totalReused := entry.totalReused.Load() + totalNew := entry.totalNewConnections.Load() + reuseRate := float64(0) + if totalConn > 0 { + reuseRate = float64(totalReused) * 100 / float64(totalConn) + } + age := time.Since(entry.createdAt) + + httpConn := entry.totalHTTPConnections.Load() + httpsConn := entry.totalHTTPSConnections.Load() + httpReused := entry.totalHTTPReused.Load() + httpsReused := entry.totalHTTPSReused.Load() + + httpReuseRate := float64(0) + if httpConn > 0 { + httpReuseRate = float64(httpReused) * 100 / float64(httpConn) + } + + httpsReuseRate := float64(0) + if httpsConn > 0 { + httpsReuseRate = float64(httpsReused) * 100 / float64(httpsConn) + } + + hostStats = append(hostStats, struct { + host string + totalConnections uint64 + totalReused uint64 + totalNewConnections uint64 + reuseRate float64 + age time.Duration + totalHTTPConnections uint64 + totalHTTPSConnections uint64 + totalHTTPReused uint64 + totalHTTPSReused uint64 + httpReuseRate float64 + httpsReuseRate float64 + }{ + host: key, + totalConnections: totalConn, + totalReused: totalReused, + totalNewConnections: totalNew, + reuseRate: reuseRate, + age: age, + totalHTTPConnections: httpConn, + totalHTTPSConnections: httpsConn, + totalHTTPReused: httpReused, + totalHTTPSReused: httpsReused, + httpReuseRate: httpReuseRate, + httpsReuseRate: httpsReuseRate, + }) + } + + if len(hostStats) == 0 { + return + } + + gologger.Info().Msgf("[connection-reuse-tracker] Per-host connection reuse:") + for _, stat := range hostStats { + gologger.Info().Msgf(" %s: %d reused / %d total (%.1f%% reuse rate, age: %v)", + stat.host, stat.totalReused, stat.totalConnections, stat.reuseRate, stat.age.Round(time.Second)) + if stat.totalHTTPConnections > 0 || stat.totalHTTPSConnections > 0 { + protocolDetails := []string{} + if stat.totalHTTPConnections > 0 { + protocolDetails = append(protocolDetails, fmt.Sprintf("HTTP: %d reused / %d total (%.1f%%)", + stat.totalHTTPReused, stat.totalHTTPConnections, stat.httpReuseRate)) + } + if stat.totalHTTPSConnections > 0 { + protocolDetails = append(protocolDetails, fmt.Sprintf("HTTPS: %d reused / %d total (%.1f%%)", + stat.totalHTTPSReused, stat.totalHTTPSConnections, stat.httpsReuseRate)) + } + if len(protocolDetails) > 0 { + gologger.Info().Msgf(" Protocol breakdown: %s", strings.Join(protocolDetails, ", ")) + } + } + } +} + +func (t *ConnectionReuseTracker) Close() { + t.cache.Purge() +} + +// normalizeHostForConnectionReuse extracts and normalizes host:port from URL (same as rate limit) +func normalizeHostForConnectionReuse(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + // If parsing fails, try to extract host:port manually + return extractHostPortFromStringForReuse(rawURL) + } + + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + + // Extract just the hostname (without port) and port separately + hostname := parsed.Hostname() + if hostname == "" { + // Fallback: try to extract from Host field + host := parsed.Host + if host != "" { + // Split host:port if port is present + if h, _, err := net.SplitHostPort(host); err == nil { + hostname = h + } else { + hostname = host + } + } + } + + if hostname == "" { + return extractHostPortFromStringForReuse(rawURL) + } + + port := parsed.Port() + if port == "" { + // Use default ports based on scheme + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + // Return just hostname:port (no scheme prefix) + return fmt.Sprintf("%s:%s", hostname, port) +} + +// extractHostPortFromStringForReuse attempts to extract host:port from a string when URL parsing fails +func extractHostPortFromStringForReuse(s string) string { + original := s + scheme := "http" + + // Remove scheme prefix if present + if strings.HasPrefix(s, "http://") { + s = strings.TrimPrefix(s, "http://") + scheme = "http" + } else if strings.HasPrefix(s, "https://") { + s = strings.TrimPrefix(s, "https://") + scheme = "https" + } + + // Extract up to first /, ?, #, space, or newline (path/query/fragment separator) + if idx := strings.IndexAny(s, "/?# \n\r\t"); idx != -1 { + s = s[:idx] + } + + if s == "" { + return original // Return original if we can't extract anything + } + + // Validate and split host:port + host, port, err := net.SplitHostPort(s) + if err == nil { + // Valid host:port format + if port == "" { + // Port is empty, use default + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + // Return just host:port (no scheme prefix) + return fmt.Sprintf("%s:%s", host, port) + } + + // No port in string, add default port + if scheme == "https" { + return fmt.Sprintf("%s:443", s) + } + return fmt.Sprintf("%s:80", s) +} diff --git a/pkg/protocols/http/httpclientpool/http_to_https_tracker.go b/pkg/protocols/http/httpclientpool/http_to_https_tracker.go new file mode 100644 index 0000000000..7ec7241c92 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/http_to_https_tracker.go @@ -0,0 +1,204 @@ +package httpclientpool + +import ( + "fmt" + "net" + "strings" + "sync/atomic" + + "github.com/projectdiscovery/gologger" + mapsutil "github.com/projectdiscovery/utils/maps" + urlutil "github.com/projectdiscovery/utils/url" +) + +// HTTPToHTTPSPortTracker tracks host:port combinations that require HTTPS +// This is used to automatically detect and correct cases where HTTP requests +// are sent to HTTPS ports (detected via 400 error with specific message) +type HTTPToHTTPSPortTracker struct { + ports *mapsutil.SyncLockMap[string, bool] + + // Statistics + totalDetections atomic.Uint64 + totalCorrections atomic.Uint64 +} + +// NewHTTPToHTTPSPortTracker creates a new HTTP-to-HTTPS port tracker +func NewHTTPToHTTPSPortTracker() *HTTPToHTTPSPortTracker { + return &HTTPToHTTPSPortTracker{ + ports: mapsutil.NewSyncLockMap[string, bool](), + } +} + +// RecordHTTPToHTTPSPort records that a host:port requires HTTPS +func (t *HTTPToHTTPSPortTracker) RecordHTTPToHTTPSPort(hostPort string) { + if hostPort == "" { + return + } + + normalizedHostPort := normalizeHostPortForTracker(hostPort) + if normalizedHostPort == "" { + return + } + + // Check if already recorded + if _, exists := t.ports.Get(normalizedHostPort); exists { + return // Already recorded, no need to log again + } + + // Record the host:port as requiring HTTPS + _ = t.ports.Set(normalizedHostPort, true) + t.totalDetections.Add(1) + + gologger.Debug().Msgf("[http-to-https-tracker] Detected HTTP-to-HTTPS port mismatch for %s", normalizedHostPort) +} + +// RequiresHTTPS checks if a host:port requires HTTPS +func (t *HTTPToHTTPSPortTracker) RequiresHTTPS(hostPort string) bool { + if hostPort == "" { + return false + } + + normalizedHostPort := normalizeHostPortForTracker(hostPort) + if normalizedHostPort == "" { + return false + } + + requiresHTTPS, ok := t.ports.Get(normalizedHostPort) + if !ok { + return false + } + + if requiresHTTPS { + t.totalCorrections.Add(1) + } + + return requiresHTTPS +} + +// Stats returns statistics about the tracker +func (t *HTTPToHTTPSPortTracker) Stats() HTTPToHTTPSPortStats { + // Note: SyncLockMap doesn't have a direct Len() method + // We track detections instead, which gives us the number of unique host:port combinations + // For exact count, we'd need to maintain a separate counter + return HTTPToHTTPSPortStats{ + TotalDetections: t.totalDetections.Load(), + TotalCorrections: t.totalCorrections.Load(), + TrackedPorts: int(t.totalDetections.Load()), // Approximate: each detection is a unique host:port + } +} + +// HTTPToHTTPSPortStats contains statistics about the HTTP-to-HTTPS port tracker +type HTTPToHTTPSPortStats struct { + TotalDetections uint64 + TotalCorrections uint64 + TrackedPorts int +} + +// PrintStats prints statistics about the tracker +func (t *HTTPToHTTPSPortTracker) PrintStats() { + stats := t.Stats() + if stats.TotalDetections == 0 { + return + } + + gologger.Info().Msgf("[http-to-https-tracker] HTTP-to-HTTPS port corrections: Detections=%d Corrections=%d TrackedPorts=%d", + stats.TotalDetections, stats.TotalCorrections, stats.TrackedPorts) +} + +// normalizeHostPortForTracker extracts and normalizes host:port from URL +// Returns format: "hostname:port" (e.g., "example.com:443", "example.com:2087") +func normalizeHostPortForTracker(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + // If parsing fails, try to extract host:port manually + return extractHostPortFromStringForHTTPS(rawURL) + } + + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + + // Extract hostname + hostname := parsed.Hostname() + if hostname == "" { + // Fallback: try to extract from Host field + host := parsed.Host + if host != "" { + // Split host:port if port is present + if h, _, err := net.SplitHostPort(host); err == nil { + hostname = h + } else { + hostname = host + } + } + } + + if hostname == "" { + return extractHostPortFromStringForHTTPS(rawURL) + } + + port := parsed.Port() + if port == "" { + // Use default ports based on scheme + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + // Return just hostname:port (no scheme prefix) + return fmt.Sprintf("%s:%s", hostname, port) +} + +// extractHostPortFromStringForHTTPS attempts to extract host:port from a string when URL parsing fails +func extractHostPortFromStringForHTTPS(s string) string { + original := s + scheme := "http" + + // Remove scheme prefix if present + if strings.HasPrefix(s, "http://") { + s = strings.TrimPrefix(s, "http://") + scheme = "http" + } else if strings.HasPrefix(s, "https://") { + s = strings.TrimPrefix(s, "https://") + scheme = "https" + } + + // Extract up to first /, ?, #, space, or newline (path/query/fragment separator) + if idx := strings.IndexAny(s, "/?# \n\r\t"); idx != -1 { + s = s[:idx] + } + + if s == "" { + return original // Return original if we can't extract anything + } + + // Validate and split host:port + host, port, err := net.SplitHostPort(s) + if err == nil { + // Valid host:port format + if port == "" { + // Port is empty, use default + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + // Return just host:port (no scheme prefix) + return fmt.Sprintf("%s:%s", host, port) + } + + // No port in string, add default port + if scheme == "https" { + return fmt.Sprintf("%s:443", s) + } + return fmt.Sprintf("%s:80", s) +} + diff --git a/pkg/protocols/http/httpclientpool/perhost_pool.go b/pkg/protocols/http/httpclientpool/perhost_pool.go new file mode 100644 index 0000000000..b568fda283 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/perhost_pool.go @@ -0,0 +1,249 @@ +package httpclientpool + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/retryablehttp-go" + urlutil "github.com/projectdiscovery/utils/url" +) + +type PerHostClientPool struct { + cache *expirable.LRU[string, *clientEntry] + capacity int + mu sync.Mutex + + hits atomic.Uint64 + misses atomic.Uint64 + evictions atomic.Uint64 +} + +type clientEntry struct { + client *retryablehttp.Client + createdAt time.Time + accessCount atomic.Uint64 +} + +func NewPerHostClientPool(size int, maxIdleTime, maxLifetime time.Duration) *PerHostClientPool { + if size <= 0 { + size = 1024 + } + if maxIdleTime == 0 { + maxIdleTime = 5 * time.Minute + } + if maxLifetime == 0 { + maxLifetime = 30 * time.Minute + } + + ttl := maxIdleTime + if maxLifetime < maxIdleTime { + ttl = maxLifetime + } + + pool := &PerHostClientPool{ + cache: expirable.NewLRU[string, *clientEntry]( + size, + func(key string, value *clientEntry) { + gologger.Debug().Msgf("[perhost-pool] Evicted client for %s (age: %v, accesses: %d)", + key, time.Since(value.createdAt), value.accessCount.Load()) + }, + ttl, + ), + capacity: size, + } + + return pool +} + +func (p *PerHostClientPool) GetOrCreate( + host string, + createFunc func() (*retryablehttp.Client, error), +) (*retryablehttp.Client, error) { + normalizedHost := normalizeHost(host) + + if entry, ok := p.cache.Get(normalizedHost); ok { + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.client, nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + if entry, ok := p.cache.Peek(normalizedHost); ok { + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.client, nil + } + + p.misses.Add(1) + + client, err := createFunc() + if err != nil { + return nil, err + } + + entry := &clientEntry{ + client: client, + createdAt: time.Now(), + } + entry.accessCount.Store(1) + + evicted := p.cache.Add(normalizedHost, entry) + if evicted { + p.evictions.Add(1) + } + + return client, nil +} + +func (p *PerHostClientPool) EvictHost(host string) bool { + normalizedHost := normalizeHost(host) + existed := p.cache.Remove(normalizedHost) + + if existed { + p.evictions.Add(1) + } + return existed +} + +func (p *PerHostClientPool) EvictAll() { + count := p.cache.Len() + p.cache.Purge() + p.evictions.Add(uint64(count)) +} + +func (p *PerHostClientPool) Size() int { + return p.cache.Len() +} + +func (p *PerHostClientPool) Stats() PoolStats { + return PoolStats{ + Hits: p.hits.Load(), + Misses: p.misses.Load(), + Evictions: p.evictions.Load(), + Size: p.Size(), + } +} + +func (p *PerHostClientPool) Close() { + p.EvictAll() +} + +func normalizeHost(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + return rawURL + } + + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + + host := parsed.Host + if host == "" { + host = parsed.Hostname() + } + + port := parsed.Port() + if port != "" { + return fmt.Sprintf("%s://%s:%s", scheme, parsed.Hostname(), port) + } + + if scheme == "https" && port == "" { + return fmt.Sprintf("%s://%s:443", scheme, parsed.Hostname()) + } + if scheme == "http" && port == "" { + return fmt.Sprintf("%s://%s:80", scheme, parsed.Hostname()) + } + + return fmt.Sprintf("%s://%s", scheme, host) +} + +type PoolStats struct { + Hits uint64 + Misses uint64 + Evictions uint64 + Size int +} + +func (p *PerHostClientPool) GetClientForHost(host string) (*retryablehttp.Client, bool) { + normalizedHost := normalizeHost(host) + + if entry, ok := p.cache.Peek(normalizedHost); ok { + return entry.client, true + } + return nil, false +} + +func (p *PerHostClientPool) ListAllClients() []string { + return p.cache.Keys() +} + +type ClientInfo struct { + Host string + CreatedAt time.Time + AccessCount uint64 + Age time.Duration +} + +func (p *PerHostClientPool) GetClientInfo(host string) *ClientInfo { + normalizedHost := normalizeHost(host) + + entry, ok := p.cache.Peek(normalizedHost) + if !ok { + return nil + } + + now := time.Now() + + return &ClientInfo{ + Host: normalizedHost, + CreatedAt: entry.createdAt, + AccessCount: entry.accessCount.Load(), + Age: now.Sub(entry.createdAt), + } +} + +func (p *PerHostClientPool) GetAllClientInfo() []*ClientInfo { + infos := []*ClientInfo{} + for _, key := range p.cache.Keys() { + if info := p.GetClientInfo(key); info != nil { + infos = append(infos, info) + } + } + return infos +} + +func (p *PerHostClientPool) Resize(size int) int { + evicted := p.cache.Resize(size) + p.capacity = size + return evicted +} + +func (p *PerHostClientPool) Cap() int { + return p.capacity +} + +func (p *PerHostClientPool) PrintStats() { + stats := p.Stats() + if stats.Size == 0 { + return + } + gologger.Verbose().Msgf("[perhost-pool] Connection reuse stats: Hits=%d Misses=%d HitRate=%.1f%% Hosts=%d", + stats.Hits, stats.Misses, + float64(stats.Hits)*100/float64(stats.Hits+stats.Misses+1), + stats.Size) +} + +func (p *PerHostClientPool) PrintTransportStats() { +} diff --git a/pkg/protocols/http/httpclientpool/perhost_ratelimit_pool.go b/pkg/protocols/http/httpclientpool/perhost_ratelimit_pool.go new file mode 100644 index 0000000000..07963d69c7 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/perhost_ratelimit_pool.go @@ -0,0 +1,511 @@ +package httpclientpool + +import ( + "context" + "fmt" + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/nuclei/v3/pkg/types" + "github.com/projectdiscovery/nuclei/v3/pkg/utils" + "github.com/projectdiscovery/ratelimit" + urlutil "github.com/projectdiscovery/utils/url" +) + +type PerHostRateLimitPool struct { + cache *expirable.LRU[string, *rateLimitEntry] + capacity int + mu sync.Mutex + options *types.Options + maxLifetime time.Duration // Maximum lifetime for entries regardless of access + + hits atomic.Uint64 + misses atomic.Uint64 + evictions atomic.Uint64 +} + +type rateLimitEntry struct { + limiter *ratelimit.Limiter + createdAt time.Time + accessCount atomic.Uint64 + requestCount atomic.Uint64 + firstRequestAt atomic.Int64 // UnixNano timestamp + lastRequestAt atomic.Int64 // UnixNano timestamp + requestTimestamps []int64 // Ring buffer of recent request timestamps for pps calculation + requestMu sync.Mutex +} + +func NewPerHostRateLimitPool(size int, maxIdleTime, maxLifetime time.Duration, options *types.Options) *PerHostRateLimitPool { + if size <= 0 { + size = 1024 + } + // For global scan tracking, use very long TTL to keep entries for entire scan duration + // Default to 24 hours if not specified, which should cover even very long scans + if maxIdleTime == 0 { + maxIdleTime = 24 * time.Hour + } + if maxLifetime == 0 { + maxLifetime = 24 * time.Hour + } + + ttl := maxIdleTime + if maxLifetime < maxIdleTime { + ttl = maxLifetime + } + + pool := &PerHostRateLimitPool{ + cache: expirable.NewLRU[string, *rateLimitEntry]( + size, + func(key string, value *rateLimitEntry) { + if value.limiter != nil { + value.limiter.Stop() + } + gologger.Debug().Msgf("[perhost-ratelimit-pool] Evicted rate limiter for %s (age: %v, accesses: %d)", + key, time.Since(value.createdAt), value.accessCount.Load()) + }, + ttl, + ), + capacity: size, + options: options, + maxLifetime: maxLifetime, + } + + pool.cache.Purge() + + return pool +} + +func (p *PerHostRateLimitPool) GetOrCreate( + host string, +) (*ratelimit.Limiter, error) { + normalizedHost := normalizeHostForRateLimit(host) + + // Try to get entry (this refreshes TTL in expirable LRU) + if entry, ok := p.cache.Get(normalizedHost); ok { + // Check if entry has exceeded maxLifetime + if p.maxLifetime > 0 && time.Since(entry.createdAt) > p.maxLifetime { + // Entry is too old, need to evict and recreate + // Acquire lock to safely evict + p.mu.Lock() + // Double-check after acquiring lock (another goroutine might have evicted it) + if entry, ok := p.cache.Peek(normalizedHost); ok { + // Check maxLifetime again (entry might have been replaced) + if time.Since(entry.createdAt) > p.maxLifetime { + if entry.limiter != nil { + entry.limiter.Stop() + } + p.cache.Remove(normalizedHost) + p.evictions.Add(1) + // Fall through to create new entry + } else { + // Entry was replaced or is now valid + entry.accessCount.Add(1) + p.hits.Add(1) + p.mu.Unlock() + return entry.limiter, nil + } + } + // Entry was evicted or doesn't exist, continue to create new one + } else { + // Entry is valid (not expired by maxLifetime) + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.limiter, nil + } + } else { + // Entry doesn't exist, acquire lock to create + p.mu.Lock() + } + + // At this point we have the lock and need to create a new entry + defer p.mu.Unlock() + + // Double-check after acquiring lock (another goroutine might have created it) + if entry, ok := p.cache.Peek(normalizedHost); ok { + // Check maxLifetime + if p.maxLifetime > 0 && time.Since(entry.createdAt) > p.maxLifetime { + // Entry is too old, evict it + if entry.limiter != nil { + entry.limiter.Stop() + } + p.cache.Remove(normalizedHost) + p.evictions.Add(1) + } else { + // Entry exists and is valid + entry.accessCount.Add(1) + p.hits.Add(1) + return entry.limiter, nil + } + } + + p.misses.Add(1) + + // Create new rate limiter for this host + limiter := utils.GetRateLimiter(context.Background(), p.options.RateLimit, p.options.RateLimitDuration) + + entry := &rateLimitEntry{ + limiter: limiter, + createdAt: time.Now(), + requestTimestamps: make([]int64, 0, 100), // Track last 100 requests for pps calculation + } + entry.accessCount.Store(1) + + evicted := p.cache.Add(normalizedHost, entry) + if evicted { + p.evictions.Add(1) + } + + return limiter, nil +} + +func (p *PerHostRateLimitPool) EvictHost(host string) bool { + normalizedHost := normalizeHostForRateLimit(host) + + // Get entry before removing to stop limiter + entry, ok := p.cache.Peek(normalizedHost) + if ok && entry != nil && entry.limiter != nil { + entry.limiter.Stop() + } + + existed := p.cache.Remove(normalizedHost) + if existed { + p.evictions.Add(1) + } + return existed +} + +func (p *PerHostRateLimitPool) EvictAll() { + keys := p.cache.Keys() + for _, key := range keys { + if entry, ok := p.cache.Peek(key); ok && entry != nil && entry.limiter != nil { + entry.limiter.Stop() + } + } + count := p.cache.Len() + p.cache.Purge() + p.evictions.Add(uint64(count)) +} + +func (p *PerHostRateLimitPool) Size() int { + return p.cache.Len() +} + +func (p *PerHostRateLimitPool) Stats() RateLimitPoolStats { + return RateLimitPoolStats{ + Hits: p.hits.Load(), + Misses: p.misses.Load(), + Evictions: p.evictions.Load(), + Size: p.Size(), + } +} + +func (p *PerHostRateLimitPool) Close() { + p.EvictAll() +} + +// normalizeHostForRateLimit extracts and normalizes host:port from URL for rate limit pool +// This ensures all requests to the same host:port use the same rate limiter, regardless of path +func normalizeHostForRateLimit(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + // If parsing fails, try to extract host:port manually + // This handles cases where the URL might be malformed + return extractHostPortFromString(rawURL) + } + + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + + // Extract just the hostname (without port) and port separately + hostname := parsed.Hostname() + if hostname == "" { + // Fallback: try to extract from Host field + host := parsed.Host + if host != "" { + // Split host:port if port is present + if h, _, err := net.SplitHostPort(host); err == nil { + hostname = h + } else { + hostname = host + } + } + } + + if hostname == "" { + return extractHostPortFromString(rawURL) + } + + port := parsed.Port() + if port == "" { + // Use default ports based on scheme + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + // Return just hostname:port (no scheme prefix) + return fmt.Sprintf("%s:%s", hostname, port) +} + +// extractHostPortFromString attempts to extract host:port from a string when URL parsing fails +func extractHostPortFromString(s string) string { + original := s + scheme := "http" + + // Remove scheme prefix if present + if strings.HasPrefix(s, "http://") { + s = strings.TrimPrefix(s, "http://") + scheme = "http" + } else if strings.HasPrefix(s, "https://") { + s = strings.TrimPrefix(s, "https://") + scheme = "https" + } + + // Extract up to first /, ?, #, space, or newline (path/query/fragment separator) + if idx := strings.IndexAny(s, "/?# \n\r\t"); idx != -1 { + s = s[:idx] + } + + if s == "" { + return original // Return original if we can't extract anything + } + + // Validate and split host:port + host, port, err := net.SplitHostPort(s) + if err == nil { + // Valid host:port format + if port == "" { + // Port is empty, use default + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + // Return just host:port (no scheme prefix) + return fmt.Sprintf("%s:%s", host, port) + } + + // No port in string, add default port + if scheme == "https" { + return fmt.Sprintf("%s:443", s) + } + return fmt.Sprintf("%s:80", s) +} + +type RateLimitPoolStats struct { + Hits uint64 + Misses uint64 + Evictions uint64 + Size int +} + +func (p *PerHostRateLimitPool) GetLimiterForHost(host string) (*ratelimit.Limiter, bool) { + normalizedHost := normalizeHostForRateLimit(host) + + if entry, ok := p.cache.Peek(normalizedHost); ok { + return entry.limiter, true + } + return nil, false +} + +func (p *PerHostRateLimitPool) ListAllLimiters() []string { + return p.cache.Keys() +} + +type RateLimitInfo struct { + Host string + CreatedAt time.Time + AccessCount uint64 + Age time.Duration +} + +func (p *PerHostRateLimitPool) GetRateLimitInfo(host string) *RateLimitInfo { + normalizedHost := normalizeHostForRateLimit(host) + + entry, ok := p.cache.Peek(normalizedHost) + if !ok { + return nil + } + + now := time.Now() + + return &RateLimitInfo{ + Host: normalizedHost, + CreatedAt: entry.createdAt, + AccessCount: entry.accessCount.Load(), + Age: now.Sub(entry.createdAt), + } +} + +func (p *PerHostRateLimitPool) GetAllRateLimitInfo() []*RateLimitInfo { + infos := []*RateLimitInfo{} + for _, key := range p.cache.Keys() { + if info := p.GetRateLimitInfo(key); info != nil { + infos = append(infos, info) + } + } + return infos +} + +func (p *PerHostRateLimitPool) Resize(size int) int { + evicted := p.cache.Resize(size) + p.capacity = size + return evicted +} + +func (p *PerHostRateLimitPool) Cap() int { + return p.capacity +} + +// RecordRequest records a request timestamp for a host to calculate pps +func (p *PerHostRateLimitPool) RecordRequest(host string) { + normalizedHost := normalizeHostForRateLimit(host) + entry, ok := p.cache.Peek(normalizedHost) + if !ok || entry == nil { + return + } + + now := time.Now().UnixNano() + entry.requestCount.Add(1) + + // Set first request time if not set + if entry.firstRequestAt.Load() == 0 { + entry.firstRequestAt.Store(now) + } + entry.lastRequestAt.Store(now) + + // Track recent timestamps for pps calculation (keep last 100) + entry.requestMu.Lock() + entry.requestTimestamps = append(entry.requestTimestamps, now) + if len(entry.requestTimestamps) > 100 { + // Keep only last 100 timestamps + entry.requestTimestamps = entry.requestTimestamps[len(entry.requestTimestamps)-100:] + } + entry.requestMu.Unlock() +} + +// calculatePPS calculates requests per second for a host based on recent requests +func (p *PerHostRateLimitPool) calculatePPS(entry *rateLimitEntry) float64 { + if entry == nil { + return 0 + } + + entry.requestMu.Lock() + defer entry.requestMu.Unlock() + + if len(entry.requestTimestamps) < 2 { + // Need at least 2 requests to calculate pps + return 0 + } + + now := time.Now().UnixNano() + // Calculate pps based on requests in the last second + oneSecondAgo := now - int64(time.Second) + recentRequests := 0 + for i := len(entry.requestTimestamps) - 1; i >= 0; i-- { + if entry.requestTimestamps[i] >= oneSecondAgo { + recentRequests++ + } else { + break + } + } + + // If we have recent requests, use them; otherwise calculate from total time span + if recentRequests > 0 { + return float64(recentRequests) + } + + // Fallback: calculate average pps from first to last request + first := entry.firstRequestAt.Load() + last := entry.lastRequestAt.Load() + if first == 0 || last == 0 || last <= first { + return 0 + } + + duration := time.Duration(last - first) + if duration <= 0 { + return 0 + } + + totalRequests := entry.requestCount.Load() + if totalRequests < 2 { + return 0 + } + + return float64(totalRequests) / duration.Seconds() +} + +func (p *PerHostRateLimitPool) PrintStats() { + stats := p.Stats() + if stats.Size == 0 { + return + } + gologger.Info().Msgf("[perhost-ratelimit-pool] Rate limit stats: Hits=%d Misses=%d HitRate=%.1f%% Hosts=%d", + stats.Hits, stats.Misses, + float64(stats.Hits)*100/float64(stats.Hits+stats.Misses+1), + stats.Size) +} + +// PrintPerHostPPSStats prints requests per second for each host +func (p *PerHostRateLimitPool) PrintPerHostPPSStats() { + if p.Size() == 0 { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + hostStats := []struct { + host string + pps float64 + requests uint64 + age time.Duration + }{} + + for _, key := range p.cache.Keys() { + entry, ok := p.cache.Peek(key) + if !ok || entry == nil { + continue + } + + pps := p.calculatePPS(entry) + requests := entry.requestCount.Load() + age := time.Since(entry.createdAt) + + hostStats = append(hostStats, struct { + host string + pps float64 + requests uint64 + age time.Duration + }{ + host: key, + pps: pps, + requests: requests, + age: age, + }) + } + + if len(hostStats) == 0 { + return + } + + gologger.Info().Msgf("[perhost-ratelimit-pool] Per-host requests per second (pps):") + for _, stat := range hostStats { + gologger.Info().Msgf(" %s: %.2f pps (total: %d requests, age: %v)", + stat.host, stat.pps, stat.requests, stat.age.Round(time.Second)) + } +} diff --git a/pkg/protocols/http/httpclientpool/sharded_pool.go b/pkg/protocols/http/httpclientpool/sharded_pool.go new file mode 100644 index 0000000000..36121a7412 --- /dev/null +++ b/pkg/protocols/http/httpclientpool/sharded_pool.go @@ -0,0 +1,464 @@ +package httpclientpool + +import ( + "context" + "crypto/tls" + "fmt" + "hash/fnv" + "math" + "net" + "net/http" + "net/url" + "sync/atomic" + "time" + + "github.com/pkg/errors" + "golang.org/x/net/proxy" + + "github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate" + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/utils" + "github.com/projectdiscovery/nuclei/v3/pkg/types" + "github.com/projectdiscovery/retryablehttp-go" + urlutil "github.com/projectdiscovery/utils/url" +) + +const ( + // DefaultShardCount is the default number of shards when auto-calculated + DefaultShardCount = 16 + // MinShardCount is the minimum number of shards + MinShardCount = 4 + // MaxShardCount is the maximum number of shards + MaxShardCount = 256 +) + +// ShardedClientPool manages HTTP clients distributed across multiple shards +// Each shard handles a subset of hosts, enabling connection reuse while +// preventing overload of a single client +type ShardedClientPool struct { + shards []*ShardEntry + numShards int + + // Statistics + totalRequests atomic.Uint64 + shardRequests []atomic.Uint64 +} + +// ShardEntry represents a single shard with its HTTP client +type ShardEntry struct { + client *retryablehttp.Client + hostCount atomic.Int64 // Number of unique hosts using this shard + requestCount atomic.Uint64 // Total requests through this shard + createdAt time.Time + lastAccess atomic.Value // time.Time + maxIdleConnsPerHost int // Calculated: baseMaxIdle / estimatedHostsPerShard +} + +// calculateOptimalShardCount calculates the optimal number of shards based on input size +// Formula: min(256, max(4, sqrt(inputSize) * 2)) +func calculateOptimalShardCount(inputSize int) int { + if inputSize <= 0 { + // Use default if input size is unknown + return DefaultShardCount + } + + // Formula: sqrt(inputSize) * 2, clamped between 4 and 256 + // This scales shards with input size: more inputs = more shards + optimalShards := int(math.Sqrt(float64(inputSize)) * 2) + + // Ensure minimum of 4 shards for distribution + if optimalShards < MinShardCount { + optimalShards = MinShardCount + } + + // Cap at maximum of 256 shards + if optimalShards > MaxShardCount { + optimalShards = MaxShardCount + } + + return optimalShards +} + +// NewShardedClientPool creates a new sharded client pool with automatic shard calculation +func NewShardedClientPool(numShards int, options *types.Options, baseConfig *Configuration, inputSize int) (*ShardedClientPool, error) { + // If numShards is 0 or negative, calculate optimal number based on input size + if numShards <= 0 { + numShards = calculateOptimalShardCount(inputSize) + } else { + // Validate provided shard count + if numShards < MinShardCount { + numShards = MinShardCount + } + if numShards > MaxShardCount { + numShards = MaxShardCount + } + } + + // Base max idle conns per host (from existing logic: 500 when threading enabled) + baseMaxIdleConnsPerHost := 500 + if baseConfig.Threads == 0 { + // If no threading, we still want some pooling for sharding + baseMaxIdleConnsPerHost = 500 + } + + // Use a fixed maxIdleConnsPerHost per shard + // This provides good connection reuse without needing to estimate host distribution + // Each shard can handle multiple hosts efficiently + maxIdleConnsPerHost := baseMaxIdleConnsPerHost + + pool := &ShardedClientPool{ + shards: make([]*ShardEntry, numShards), + numShards: numShards, + shardRequests: make([]atomic.Uint64, numShards), + } + + // Initialize all shards with calculated maxIdleConnsPerHost + for i := 0; i < numShards; i++ { + client, err := createShardClient(options, baseConfig, maxIdleConnsPerHost) + if err != nil { + return nil, fmt.Errorf("failed to create shard %d client: %w", i, err) + } + + pool.shards[i] = &ShardEntry{ + client: client, + createdAt: time.Now(), + maxIdleConnsPerHost: maxIdleConnsPerHost, + } + pool.shards[i].lastAccess.Store(time.Now()) + } + + gologger.Debug().Msgf("[sharded-pool] Initialized %d HTTP client shards (maxIdleConnsPerHost=%d)", + numShards, maxIdleConnsPerHost) + return pool, nil +} + +// GetClientForHost returns the HTTP client for the given host based on consistent hashing +// Returns the client and the shard index +func (p *ShardedClientPool) GetClientForHost(host string) (*retryablehttp.Client, int) { + shardIndex := p.getShardIndex(host) + shard := p.shards[shardIndex] + + p.shardRequests[shardIndex].Add(1) + p.totalRequests.Add(1) + shard.requestCount.Add(1) + shard.lastAccess.Store(time.Now()) + + return shard.client, shardIndex +} + +// getShardIndex calculates the shard index for a host using consistent hashing +func (p *ShardedClientPool) getShardIndex(host string) int { + normalizedHost := normalizeHostForSharding(host) + + hash := fnv.New32a() + hash.Write([]byte(normalizedHost)) + + return int(hash.Sum32()) % p.numShards +} + +// normalizeHostForSharding normalizes a host URL for consistent sharding +// Returns host:port format (e.g., "example.com:443") +func normalizeHostForSharding(rawURL string) string { + if rawURL == "" { + return "" + } + + parsed, err := urlutil.Parse(rawURL) + if err != nil { + // Fallback: try to extract host:port manually + return extractHostPortFromStringForReuse(rawURL) + } + + hostname := parsed.Hostname() + if hostname == "" { + return extractHostPortFromStringForReuse(rawURL) + } + + port := parsed.Port() + if port == "" { + scheme := parsed.Scheme + if scheme == "" { + scheme = "http" + } + if scheme == "https" { + port = "443" + } else { + port = "80" + } + } + + return fmt.Sprintf("%s:%s", hostname, port) +} + +// createShardClient creates an HTTP client for a shard with custom maxIdleConnsPerHost +func createShardClient(options *types.Options, config *Configuration, maxIdleConnsPerHost int) (*retryablehttp.Client, error) { + cfg := config.Clone() + if cfg.Connection == nil { + cfg.Connection = &ConnectionConfiguration{} + } + + // Enable keep-alive for connection reuse + cfg.Connection.DisableKeepAlive = false + + // Disable cookies for sharded clients to avoid concurrent map writes + // cookiejar.Jar is not thread-safe and sharded clients are shared across goroutines + // If cookies are needed, use per-host pooling instead + cfg.DisableCookie = true + + // Set threading to enable connection pooling + originalThreads := cfg.Threads + cfg.Threads = 1 // Minimal threading, sharding provides concurrency + + // Create a modified hash that includes the custom maxIdle value + // This ensures shards with different maxIdle values get different clients + hash := hashWithCookieJar(cfg.Hash(), cfg) + hash = hash + fmt.Sprintf(":maxIdle:%d", maxIdleConnsPerHost) + + // Use wrappedGetWithCustomMaxIdle to create client with custom maxIdleConnsPerHost + client, err := wrappedGetWithCustomMaxIdle(options, cfg, maxIdleConnsPerHost, hash) + cfg.Threads = originalThreads + + return client, err +} + +// wrappedGetWithCustomMaxIdle creates an HTTP client with a custom maxIdleConnsPerHost value +// This is used for sharding to distribute idle connections evenly per host +func wrappedGetWithCustomMaxIdle(options *types.Options, configuration *Configuration, customMaxIdleConnsPerHost int, hash string) (*retryablehttp.Client, error) { + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + if dialers == nil { + return nil, fmt.Errorf("dialers not initialized for %s", options.ExecutionId) + } + + // Check if client already exists with this hash + if client, ok := dialers.HTTPClientPool.Get(hash); ok { + return client, nil + } + + // Use standard wrappedGet logic but override maxIdleConnsPerHost + retryableHttpOptions := retryablehttp.DefaultOptionsSingle + disableKeepAlives := false + maxIdleConns := 500 + maxConnsPerHost := customMaxIdleConnsPerHost + maxIdleConnsPerHost := customMaxIdleConnsPerHost // Use custom value + + retryableHttpOptions.RetryWaitMax = 10 * time.Second + retryableHttpOptions.RetryMax = options.Retries + retryableHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second + if configuration.ResponseHeaderTimeout > 0 && configuration.ResponseHeaderTimeout > retryableHttpOptions.Timeout { + retryableHttpOptions.Timeout = configuration.ResponseHeaderTimeout + } + + redirectFlow := configuration.RedirectFlow + maxRedirects := configuration.MaxRedirects + + if forceMaxRedirects > 0 { + switch { + case options.FollowHostRedirects: + redirectFlow = FollowSameHostRedirect + default: + redirectFlow = FollowAllRedirect + } + maxRedirects = forceMaxRedirects + } + if options.DisableRedirects { + options.FollowRedirects = false + options.FollowHostRedirects = false + redirectFlow = DontFollowRedirect + maxRedirects = 0 + } + + // Override connection's settings if required + if configuration.Connection != nil { + disableKeepAlives = configuration.Connection.DisableKeepAlive + } + + // Set the base TLS configuration definition + tlsConfig := &tls.Config{ + Renegotiation: tls.RenegotiateOnceAsClient, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + ClientSessionCache: tls.NewLRUClientSessionCache(1024), + } + + if options.SNI != "" { + tlsConfig.ServerName = options.SNI + } + + // Add the client certificate authentication to the request if it's configured + var err error + tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options) + if err != nil { + return nil, errors.Wrap(err, "could not create client certificate") + } + + // responseHeaderTimeout is max timeout for response headers to be read + responseHeaderTimeout := options.GetTimeouts().HttpResponseHeaderTimeout + if configuration.ResponseHeaderTimeout != 0 { + responseHeaderTimeout = configuration.ResponseHeaderTimeout + } + + if responseHeaderTimeout < retryableHttpOptions.Timeout { + responseHeaderTimeout = retryableHttpOptions.Timeout + } + + if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 { + responseHeaderTimeout = configuration.Connection.CustomMaxTimeout + } + + transport := &http.Transport{ + ForceAttemptHTTP2: options.ForceAttemptHTTP2, + DialContext: dialers.Fastdialer.Dial, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if options.TlsImpersonate { + return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) + } + if options.HasClientCertificates() || options.ForceAttemptHTTP2 { + return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) + } + return dialers.Fastdialer.DialTLS(ctx, network, addr) + }, + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, // Custom value for sharding + MaxConnsPerHost: maxConnsPerHost, // Same value for consistency + TLSClientConfig: tlsConfig, + DisableKeepAlives: disableKeepAlives, + ResponseHeaderTimeout: responseHeaderTimeout, + } + + if options.AliveHttpProxy != "" { + if proxyURL, err := url.Parse(options.AliveHttpProxy); err == nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + } else if options.AliveSocksProxy != "" { + socksURL, proxyErr := url.Parse(options.AliveSocksProxy) + if proxyErr != nil { + return nil, proxyErr + } + + dialer, err := proxy.FromURL(socksURL, proxy.Direct) + if err != nil { + return nil, err + } + + dc := dialer.(interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) + }) + + transport.DialContext = dc.DialContext + transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := dc.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + if tlsConfig.ServerName == "" { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tlsConfig.ServerName = host + } + return tls.Client(conn, tlsConfig), nil + } + } + + // CRITICAL: Never use cookiejars in sharded clients! + // cookiejar.Jar is not thread-safe and sharded clients are shared across goroutines. + // This causes "fatal error: concurrent map writes" when multiple goroutines access the same jar. + // Cookies are disabled for sharded clients (set in createShardClient via cfg.DisableCookie = true) + + httpclient := &http.Client{ + Transport: transport, + CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects), + } + if !configuration.NoTimeout { + httpclient.Timeout = options.GetTimeouts().HttpTimeout + if configuration.Connection != nil && configuration.Connection.CustomMaxTimeout > 0 { + httpclient.Timeout = configuration.Connection.CustomMaxTimeout + } + } + client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions) + // jar is always nil for sharded clients (thread safety) + client.CheckRetry = retryablehttp.HostSprayRetryPolicy() + + // Store in pool with modified hash + // Sharded clients never use cookiejars (disabled for thread safety), so always store in pool + if err := dialers.HTTPClientPool.Set(hash, client); err != nil { + return nil, errors.Wrap(err, "could not store client in pool") + } + + return client, nil +} + +// Stats returns statistics about the sharded pool +func (p *ShardedClientPool) Stats() ShardedPoolStats { + stats := ShardedPoolStats{ + NumShards: p.numShards, + TotalRequests: p.totalRequests.Load(), + ShardStats: make([]ShardStat, p.numShards), + } + + for i := 0; i < p.numShards; i++ { + shard := p.shards[i] + if shard == nil { + continue + } + + lastAccess := time.Time{} + if la := shard.lastAccess.Load(); la != nil { + lastAccess = la.(time.Time) + } + + stats.ShardStats[i] = ShardStat{ + Index: i, + RequestCount: shard.requestCount.Load(), + HostCount: shard.hostCount.Load(), + LastAccess: lastAccess, + } + } + + return stats +} + +// ShardedPoolStats contains statistics about the sharded pool +type ShardedPoolStats struct { + NumShards int + TotalRequests uint64 + ShardStats []ShardStat +} + +// ShardStat contains statistics for a single shard +type ShardStat struct { + Index int + RequestCount uint64 + HostCount int64 + LastAccess time.Time +} + +// PrintStats prints statistics about the sharded pool +func (p *ShardedClientPool) PrintStats() { + stats := p.Stats() + if stats.TotalRequests == 0 { + return + } + + gologger.Info().Msgf("[sharded-pool] HTTP client sharding stats: Shards=%d TotalRequests=%d", + stats.NumShards, stats.TotalRequests) + + // Print per-shard stats in verbose mode + // Note: Verbose logging is controlled by gologger's global level + // We'll always print per-shard stats if there are requests + for _, shardStat := range stats.ShardStats { + if shardStat.RequestCount > 0 { + gologger.Verbose().Msgf(" Shard %d: Requests=%d Hosts=%d LastAccess=%v", + shardStat.Index, shardStat.RequestCount, shardStat.HostCount, + shardStat.LastAccess.Round(time.Second)) + } + } +} + +// Close closes the sharded pool (clients are managed by the main HTTPClientPool) +func (p *ShardedClientPool) Close() { + // Clients are managed by the main HTTPClientPool, no cleanup needed + // This is just for interface compatibility +} diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 0b7a35bc4f..b2e5925c3a 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -8,6 +8,7 @@ import ( "io" "maps" "net/http" + "net/http/httptrace" "strconv" "strings" "sync" @@ -70,6 +71,22 @@ func (request *Request) Type() templateTypes.ProtocolType { return templateTypes.HTTPProtocol } +// rateLimitTake handles rate limiting, using per-host rate limiter if enabled, otherwise global +func (request *Request) rateLimitTake(hostname string) { + if request.options.Options.PerHostRateLimit && hostname != "" { + // Use per-host rate limiter + if limiter, err := httpclientpool.GetPerHostRateLimiter(request.options.Options, hostname); err == nil && limiter != nil { + limiter.Take() + // Record request for pps stats + httpclientpool.RecordPerHostRateLimitRequest(request.options.Options, hostname) + return + } + // Fallback to global if per-host fails + } + // Use global rate limiter (or unlimited if per-host is enabled but hostname is empty) + request.options.RateLimitTake() +} + // executeRaceRequest executes race condition request for a URL func (request *Request) executeRaceRequest(input *contextargs.Context, previous output.InternalEvent, callback protocols.OutputEventCallback) error { reqURL := input.MetaInput.Input @@ -267,7 +284,15 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV spmHandler.Release() continue } - request.options.RateLimitTake() + // Extract hostname for per-host rate limiting (use full URL - normalization happens in rateLimitTake) + hostname := t.updatedInput.MetaInput.Input + if t.req != nil && t.req.URL() != "" { + hostname = t.req.URL() + } else if t.req != nil && t.req.request != nil && t.req.request.URL != nil { + // Extract from request URL if available + hostname = t.req.request.String() + } + request.rateLimitTake(hostname) select { case <-spmHandler.Done(): spmHandler.Release() @@ -515,8 +540,6 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa executeFunc := func(data string, payloads, dynamicValue map[string]interface{}) (bool, error) { hasInteractMatchers := interactsh.HasMatchers(request.CompiledOperators) - request.options.RateLimitTake() - ctx := request.newContext(input) ctxWithTimeout, cancel := context.WithTimeoutCause(ctx, request.options.Options.GetTimeouts().HttpTimeout, ErrHttpEngineRequestDeadline) defer cancel() @@ -535,6 +558,14 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa // but this should be replaced once templateCtx is refactored properly updatedInput := contextargs.GetCopyIfHostOutdated(input, generatedHttpRequest.URL()) + // Extract hostname for per-host rate limiting (use generated request URL - normalization happens in rateLimitTake) + hostname := input.MetaInput.Input + if generatedHttpRequest.URL() != "" { + // Use the generated URL directly - the normalization function will extract host:port correctly + hostname = generatedHttpRequest.URL() + } + request.rateLimitTake(hostname) + if generatedHttpRequest.customCancelFunction != nil { defer generatedHttpRequest.customCancelFunction() } @@ -805,6 +836,14 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ } httpclient := request.httpClient + // Extract target URL for per-host pooling (use request URL or fallback to input) + targetURL := input.MetaInput.Input + if generatedRequest.request != nil && generatedRequest.request.URL != nil { + targetURL = generatedRequest.request.String() + } else if generatedRequest.request != nil { + targetURL = generatedRequest.request.String() + } + // this will be assigned/updated if this specific request has a custom configuration var modifiedConfig *httpclientpool.Configuration @@ -825,7 +864,20 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ modifiedConfig.ResponseHeaderTimeout = updatedTimeout.Timeout } - if modifiedConfig != nil { + // Prefer per-host pooled client or sharded client for better reuse when flags are enabled + // choose config to use (modified if present else default) + configToUse := modifiedConfig + if configToUse == nil { + configToUse = request.connConfiguration + } + if request.options.Options.PerHostClientPool || request.options.Options.HTTPClientShards { + if client, err := httpclientpool.GetForTarget(request.options.Options, configToUse, targetURL); err == nil { + httpclient = client + } else { + return errors.Wrap(err, "could not get http client") + } + } else if modifiedConfig != nil { + modifiedConfig.Threads = request.Threads client, err := httpclientpool.Get(request.options.Options, modifiedConfig) if err != nil { return errors.Wrap(err, "could not get http client") @@ -833,6 +885,45 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ httpclient = client } + // Check if HTTP-to-HTTPS port correction is needed before sending request + if generatedRequest.request != nil && generatedRequest.request.URL != nil { + tracker := httpclientpool.GetHTTPToHTTPSPortTracker(request.options.Options) + if tracker != nil { + requestURL := generatedRequest.request.String() + if tracker.RequiresHTTPS(requestURL) { + // Modify request URL scheme from http to https + if generatedRequest.request.Scheme == "http" { + generatedRequest.request.Scheme = "https" + gologger.Debug().Msgf("[http-to-https-tracker] Corrected HTTP to HTTPS for %s", requestURL) + } + } + } + } + + // Track connection reuse for all HTTP requests + if generatedRequest.request != nil { + // Extract hostname for connection reuse tracking (use actual request URL, same as rate limiting) + hostnameForReuse := input.MetaInput.Input + if generatedRequest.request.URL != nil { + // Use the actual request URL - normalization will extract host:port correctly + hostnameForReuse = generatedRequest.request.String() + } else if generatedRequest.URL() != "" { + // Fallback to generated request URL method + hostnameForReuse = generatedRequest.URL() + } else if targetURL != "" { + hostnameForReuse = targetURL + } + + trace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + // Record connection reuse event + httpclientpool.RecordConnectionReuse(request.options.Options, hostnameForReuse, info.Reused) + }, + } + ctx := httptrace.WithClientTrace(generatedRequest.request.Context(), trace) + generatedRequest.request = generatedRequest.request.WithContext(ctx) + } + resp, err = httpclient.Do(generatedRequest.request) } } @@ -964,8 +1055,25 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ bodyStr := respChain.BodyString() headersStr := respChain.HeadersString() + // Detect HTTP-to-HTTPS port mismatch (400 error with specific message) + statusCode := respChain.Response().StatusCode + if statusCode == 400 && strings.Contains(bodyStr, "The plain HTTP request was sent to HTTPS port") { + // Extract host:port from the request URL + var requestURL string + if generatedRequest.request != nil && generatedRequest.request.URL != nil { + requestURL = generatedRequest.request.String() + } else if generatedRequest.rawRequest != nil && generatedRequest.rawRequest.FullURL != "" { + requestURL = generatedRequest.rawRequest.FullURL + } else if respChain.Request() != nil && respChain.Request().URL != nil { + requestURL = respChain.Request().URL.String() + } + if requestURL != "" { + httpclientpool.RecordHTTPToHTTPSPortMismatch(request.options.Options, requestURL) + } + } + // log request stats - request.options.Output.RequestStatsLog(strconv.Itoa(respChain.Response().StatusCode), fullResponseStr) + request.options.Output.RequestStatsLog(strconv.Itoa(statusCode), fullResponseStr) // save response to projectfile onceFunc() diff --git a/pkg/protocols/http/request_fuzz.go b/pkg/protocols/http/request_fuzz.go index 3a7e2cc74a..1cfa6dfdb1 100644 --- a/pkg/protocols/http/request_fuzz.go +++ b/pkg/protocols/http/request_fuzz.go @@ -181,7 +181,9 @@ func (request *Request) executeGeneratedFuzzingRequest(gr fuzz.GeneratedRequest, if request.options.HostErrorsCache != nil && request.options.HostErrorsCache.Check(request.options.ProtocolType.String(), input) { return false } - request.options.RateLimitTake() + // Extract hostname for per-host rate limiting + hostname := input.MetaInput.Input + request.rateLimitTake(hostname) req := &generatedRequest{ request: gr.Request, dynamicValues: gr.DynamicValues, diff --git a/pkg/protocols/http/request_test.go b/pkg/protocols/http/request_test.go index 9eb7b100e4..9e22417313 100644 --- a/pkg/protocols/http/request_test.go +++ b/pkg/protocols/http/request_test.go @@ -388,6 +388,9 @@ func TestExecuteParallelHTTP_GoroutineLeaks(t *testing.T) { goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mpoolDrain"), goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).tCompaction"), goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mCompaction"), + // expirable LRU cache creates a background goroutine for TTL expiration that persists + // see: https://github.com/hashicorp/golang-lru/blob/770151e9c8cdfae1797826b7b74c33d6f103fbd8/expirable/expirable_lru.go#L79 + goleak.IgnoreAnyContainingPkg("github.com/hashicorp/golang-lru/v2/expirable"), ) options := testutils.DefaultOptions diff --git a/pkg/protocols/utils/http/requtils.go b/pkg/protocols/utils/http/requtils.go index bfc602a055..7a8d7dd6f5 100644 --- a/pkg/protocols/utils/http/requtils.go +++ b/pkg/protocols/utils/http/requtils.go @@ -48,5 +48,5 @@ func SetHeader(req *retryablehttp.Request, name, value string) { // ShouldDisableKeepAlive depending on scan strategy func ShouldDisableKeepAlive(options *types.Options) bool { // with host-spray strategy keep-alive must be enabled - return options.ScanStrategy != scanstrategy.HostSpray.String() + return options.TemplateThreads == 0 && options.ScanStrategy != scanstrategy.HostSpray.String() } diff --git a/pkg/types/types.go b/pkg/types/types.go index 0f6663f384..1d85b3f0a0 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -201,6 +201,22 @@ type Options struct { DebugResponse bool // DisableHTTPProbe disables http probing feature of input normalization DisableHTTPProbe bool + // PreflightPortScan enables a preflight resolve + TCP portscan and filters targets + // before running templates. Disabled by default. + PreflightPortScan bool + // PerHostClientPool enables per-host HTTP client pooling for better connection reuse. + // When enabled, each host gets its own client instance keyed by (host, configuration). + // Disabled by default. + PerHostClientPool bool + // HTTPClientShards enables HTTP client sharding for connection pooling. + // When enabled, hosts are distributed across a fixed number of HTTP client shards (auto-calculated, max 256). + // This provides a balance between connection reuse and memory efficiency. + // Disabled by default. + HTTPClientShards bool + // PerHostRateLimit enables per-host rate limiting for HTTP requests. + // When enabled, each host gets its own rate limiter and global rate limit becomes unlimited. + // Disabled by default. + PerHostRateLimit bool // LeaveDefaultPorts skips normalization of default ports LeaveDefaultPorts bool // AutomaticScan enables automatic tech based template execution