Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion internal/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,7 @@ func (r *Runner) RunEnumeration() error {
}()
}

now := time.Now()
enumeration := false
var results *atomic.Bool
results, err = r.runStandardEnumeration(executorOpts, store, executorEngine)
Expand All @@ -725,11 +726,17 @@ func (r *Runner) RunEnumeration() error {
}
r.fuzzFrequencyCache.Close()

r.progress.Stop()
timeTaken := time.Since(now)
// todo: error propagation without canonical straight error check is required by cloud?
// use safe dereferencing to avoid potential panics in case of previous unchecked errors
if v := ptrutil.Safe(results); !v.Load() {
gologger.Info().Msgf("No results found. Better luck next time!")
gologger.Info().Msgf("Scan completed in %s. No results found.", shortDur(timeTaken))
} else {
matchCount := r.output.ResultCount()
gologger.Info().Msgf("Scan completed in %s. %d matches found.", shortDur(timeTaken), matchCount)
}

// check if a passive scan was requested but no target was provided
if r.options.OfflineHTTP && len(r.options.Targets) == 0 && r.options.TargetsFilePath == "" {
return errors.Wrap(err, "missing required input (http response) to run passive templates")
Expand All @@ -738,6 +745,24 @@ func (r *Runner) RunEnumeration() error {
return err
}

func shortDur(d time.Duration) string {
if d < time.Minute {
return d.String()
}

// Truncate to the nearest minute
d = d.Truncate(time.Minute)
s := d.String()

if strings.HasSuffix(s, "m0s") {
s = s[:len(s)-2]
}
if strings.HasSuffix(s, "h0m") {
s = s[:len(s)-2]
}
return s
}

func (r *Runner) isInputNonHTTP() bool {
var nonURLInput bool
r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool {
Expand Down
2 changes: 2 additions & 0 deletions pkg/core/execute_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*temp
defer wp.Wait()

for _, template := range templatesList {
template := template

select {
case <-ctx.Done():
return results
Expand Down
10 changes: 10 additions & 0 deletions pkg/output/multi_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,13 @@ func (mw *MultiWriter) RequestStatsLog(statusCode, response string) {
writer.RequestStatsLog(statusCode, response)
}
}

func (mw *MultiWriter) ResultCount() int {
count := 0
for _, writer := range mw.writers {
if count := writer.ResultCount(); count > 0 {
return count
}
}
return count
}
9 changes: 9 additions & 0 deletions pkg/output/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type Writer interface {
RequestStatsLog(statusCode, response string)
// WriteStoreDebugData writes the request/response debug data to file
WriteStoreDebugData(host, templateID, eventType string, data string)
// ResultCount returns the total number of results written
ResultCount() int
}

// StandardWriter is a writer writing output to file and screen for results.
Expand All @@ -79,6 +81,8 @@ type StandardWriter struct {
// JSONLogRequestHook is a hook that can be used to log request/response
// when using custom server code with output
JSONLogRequestHook func(*JSONLogRequest)

resultCount atomic.Int32
}

var _ Writer = &StandardWriter{}
Expand Down Expand Up @@ -287,6 +291,10 @@ func NewStandardWriter(options *types.Options) (*StandardWriter, error) {
return writer, nil
}

func (w *StandardWriter) ResultCount() int {
return int(w.resultCount.Load())
}

// Write writes the event to file and/or screen.
func (w *StandardWriter) Write(event *ResultEvent) error {
// Enrich the result event with extra metadata on the template-path and url.
Expand Down Expand Up @@ -336,6 +344,7 @@ func (w *StandardWriter) Write(event *ResultEvent) error {
_, _ = w.outputFile.Write([]byte("\n"))
}
}
w.resultCount.Add(1)
return nil
}

Expand Down
3 changes: 3 additions & 0 deletions pkg/output/output_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@ func (tw *StatsOutputWriter) RequestStatsLog(statusCode, response string) {
tw.Tracker.TrackStatusCode(statusCode)
tw.Tracker.TrackWAFDetected(response)
}
func (tw *StatsOutputWriter) ResultCount() int {
return 0
}
4 changes: 1 addition & 3 deletions pkg/progress/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ func (p *StatsTicker) IncrementRequests() {

// SetRequests sets the counter by incrementing it with a delta
func (p *StatsTicker) SetRequests(count uint64) {
value, _ := p.stats.GetCounter("requests")
delta := count - value
p.stats.IncrementCounter("requests", int(delta))
p.stats.IncrementCounter("requests", int(count))
}

// IncrementMatched increments the matched counter by 1.
Expand Down
18 changes: 13 additions & 5 deletions pkg/protocols/common/hosterrorscache/hosterrorscache.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ func (c *Cache) NormalizeCacheValue(value string) string {

u, err := url.ParseRequestURI(value)
if err != nil || u.Host == "" {
if strings.Contains(value, ":") {
return normalizedValue
}
u, err2 := url.ParseRequestURI("https://" + value)
if err2 != nil {
return normalizedValue
Expand Down Expand Up @@ -236,14 +239,19 @@ func (c *Cache) GetKeyFromContext(ctx *contextargs.Context, err error) string {
// should be reflected in contextargs but it is not yet reflected in some cases
// and needs refactor of ScanContext + ContextArgs to achieve that
// i.e why we use real address from error if present
address := ctx.MetaInput.Address()
// get address override from error
var address string

// 1. the address carried inside the error (if the transport sets it)
if err != nil {
tmp := errkit.GetAttrValue(err, "address")
if tmp.Any() != nil {
address = tmp.String()
if v := errkit.GetAttrValue(err, "address"); v.Any() != nil {
address = v.String()
}
}

if address == "" {
address = ctx.MetaInput.Address()
}

finalValue := c.NormalizeCacheValue(address)
return finalValue
}
Expand Down
8 changes: 0 additions & 8 deletions pkg/protocols/http/build_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,6 @@ func (g *generatedRequest) URL() string {
return ""
}

// Total returns the total number of requests for the generator
func (r *requestGenerator) Total() int {
if r.payloadIterator != nil {
return len(r.request.Raw) * r.payloadIterator.Remaining()
}
return len(r.request.Path)
}

// Make creates a http request for the provided input.
// It returns ErrNoMoreRequests as error when all the requests have been exhausted.
func (r *requestGenerator) Make(ctx context.Context, input *contextargs.Context, reqData string, payloads, dynamicValues map[string]interface{}) (gr *generatedRequest, err error) {
Expand Down
21 changes: 2 additions & 19 deletions pkg/protocols/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,6 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
}

return nil
}

Expand All @@ -517,24 +516,8 @@ func (request *Request) RebuildGenerator() error {

// Requests returns the total number of requests the YAML rule will perform
func (request *Request) Requests() int {
if request.generator != nil {
payloadRequests := request.generator.NewIterator().Total()
if len(request.Raw) > 0 {
payloadRequests = payloadRequests * len(request.Raw)
}
if len(request.Path) > 0 {
payloadRequests = payloadRequests * len(request.Path)
}
return payloadRequests
}
if len(request.Raw) > 0 {
requests := len(request.Raw)
if requests == 1 && request.RaceNumberRequests != 0 {
requests *= request.RaceNumberRequests
}
return requests
}
return len(request.Path)
generator := request.newGenerator(false)
return generator.Total()
}

const (
Expand Down
14 changes: 9 additions & 5 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (
"github.com/projectdiscovery/rawhttp"
convUtil "github.com/projectdiscovery/utils/conversion"
"github.com/projectdiscovery/utils/errkit"
errorutil "github.com/projectdiscovery/utils/errors"
httpUtils "github.com/projectdiscovery/utils/http"
"github.com/projectdiscovery/utils/reader"
sliceutil "github.com/projectdiscovery/utils/slice"
Expand Down Expand Up @@ -484,7 +483,6 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if err == types.ErrNoMoreRequests {
return true, nil
}
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return true, err
}
// ideally if http template used a custom port or hostname
Expand Down Expand Up @@ -541,14 +539,19 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if errors.Is(execReqErr, ErrMissingVars) {
return true, nil
}

if execReqErr != nil {
request.markHostError(updatedInput, execReqErr)

// if applicable mark the host as unresponsive
requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL())
reqKitErr := errkit.FromError(execReqErr)
reqKitErr.Msgf("got err while executing %v", generatedHttpRequest.URL())

requestErr = reqKitErr
request.options.Progress.IncrementFailedRequestsBy(1)
} else {
request.options.Progress.IncrementRequests()
}
request.markHostError(updatedInput, execReqErr)

// If this was a match, and we want to stop at first match, skip all further requests.
shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
Expand Down Expand Up @@ -585,6 +588,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
requestErr = gotErr
}
if skip || gotErr != nil {
request.options.Progress.SetRequests(uint64(generator.Remaining() + 1))
break
}
}
Expand Down Expand Up @@ -1212,7 +1216,7 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {

// markHostError checks if the error is a unreponsive host error and marks it
func (request *Request) markHostError(input *contextargs.Context, err error) {
if request.options.HostErrorsCache != nil {
if request.options.HostErrorsCache != nil && err != nil {
request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
}
}
Expand Down
64 changes: 64 additions & 0 deletions pkg/protocols/http/request_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,67 @@ func (r *requestGenerator) hasMarker(request string, mark flowMark) bool {
fo, hasOverrides := parseFlowAnnotations(request)
return hasOverrides && fo == mark
}

// Remaining returns the number of requests that are still left to be
// generated (and therefore to be sent) by this generator.
func (r *requestGenerator) Remaining() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}

remainingInCurrentPass := 0
for i := r.currentIndex; i < len(sequence); i++ {
if !r.hasMarker(sequence[i], Once) {
remainingInCurrentPass++
}
}

if r.payloadIterator == nil {
return remainingInCurrentPass
}

numRemainingPayloadSets := r.payloadIterator.Remaining()
totalValidInSequence := 0
for _, req := range sequence {
if !r.hasMarker(req, Once) {
totalValidInSequence++
}
}

// Total remaining = remaining in current pass + (remaining payload sets * requests per full pass)
return remainingInCurrentPass + numRemainingPayloadSets*totalValidInSequence
}

func (r *requestGenerator) Total() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}

applicableRequests := 0
additionalRequests := 0
for _, request := range sequence {
if !r.hasMarker(request, Once) {
applicableRequests++
} else {
additionalRequests++
}
}

if r.payloadIterator == nil {
return applicableRequests + additionalRequests
}

return (applicableRequests * r.payloadIterator.Total()) + additionalRequests
}
4 changes: 4 additions & 0 deletions pkg/testutils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ func (m *MockOutputWriter) Colorizer() aurora.Aurora {
return m.aurora
}

func (m *MockOutputWriter) ResultCount() int {
return 0
}

// Write writes the event to file and/or screen.
func (m *MockOutputWriter) Write(result *output.ResultEvent) error {
if m.WriteCallback != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/tmplexec/flow/flow_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ type FlowExecutor struct {
// these are keys whose values are meant to be flatten before executing
// a request ex: if dynamic extractor returns ["value"] it will be converted to "value"
flattenKeys []string

executed *mapsutil.SyncLockMap[string, struct{}]
}

// NewFlowExecutor creates a new flow executor from a list of requests
Expand Down Expand Up @@ -98,6 +100,7 @@ func NewFlowExecutor(requests []protocols.Request, ctx *scan.ScanContext, option
results: results,
ctx: ctx,
program: program,
executed: mapsutil.NewSyncLockMap[string, struct{}](),
}
return f, nil
}
Expand Down Expand Up @@ -243,6 +246,7 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {

// pass flow and execute the js vm and handle errors
_, err := runtime.RunProgram(f.program)
f.reconcileProgress()
if err != nil {
ctx.LogError(err)
return errorutil.NewWithErr(err).Msgf("failed to execute flow\n%v\n", f.options.Flow)
Expand All @@ -256,6 +260,18 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
return nil
}

func (f *FlowExecutor) reconcileProgress() {
for proto, list := range f.allProtocols {
for idx, req := range list {
key := requestKey(proto, req, strconv.Itoa(idx+1))
if _, seen := f.executed.Get(key); !seen {
// never executed → pretend it finished so that stats match
f.options.Progress.SetRequests(uint64(req.Requests()))
}
}
}
}

// GetRuntimeErrors returns all runtime errors (i.e errors from all protocol combined)
func (f *FlowExecutor) GetRuntimeErrors() error {
errs := []error{}
Expand Down
Loading
Loading