Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions pkg/protocols/common/protocolstate/memguardian.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,25 @@ func StartActiveMemGuardian(ctx context.Context) {

memTimer = time.NewTicker(memguardian.DefaultInterval)
ctx, cancelFunc = context.WithCancel(ctx)
go func() {

ticker := memTimer
go func(t *time.Ticker) {
if t == nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-memTimer.C:
case <-t.C:
if IsLowOnMemory() {
_ = GlobalGuardBytesBufferAlloc()
} else {
GlobalRestoreBytesBufferAlloc()
}
}
}
}()
}(ticker)
}

func StopActiveMemGuardian() {
Expand All @@ -52,9 +57,13 @@ func StopActiveMemGuardian() {
return
}

if cancelFunc != nil {
cancelFunc()
cancelFunc = nil
}
if memTimer != nil {
memTimer.Stop()
cancelFunc()
memTimer = nil
}
}

Expand Down
123 changes: 123 additions & 0 deletions pkg/protocols/common/protocolstate/memguardian_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package protocolstate

import (
"context"
"testing"
"time"

"github.com/projectdiscovery/utils/memguardian"
"github.com/stretchr/testify/require"
"github.com/tarunKoyalwar/goleak"
)

// TestMemGuardianGoroutineLeak tests that MemGuardian properly cleans up goroutines
func TestMemGuardianGoroutineLeak(t *testing.T) {
defer goleak.VerifyNone(t,
goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"),
goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"),
goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"),
)

// Initialize memguardian if not already initialized
if memguardian.DefaultMemGuardian == nil {
var err error
memguardian.DefaultMemGuardian, err = memguardian.New()
require.NoError(t, err, "Failed to initialize memguardian")
}

t.Run("StartAndStopMemGuardian", func(t *testing.T) {
// Test that starting and stopping memguardian doesn't leak goroutines
ctx := context.Background()

// Start MemGuardian
StartActiveMemGuardian(ctx)
require.NotNil(t, memTimer, "memTimer should be initialized")
require.NotNil(t, cancelFunc, "cancelFunc should be initialized")

// Give it a moment to start
time.Sleep(10 * time.Millisecond)

// Stop MemGuardian
StopActiveMemGuardian()

// Give goroutine time to exit
time.Sleep(20 * time.Millisecond)

// Verify cleanup
require.Nil(t, memTimer, "memTimer should be nil after stop")
require.Nil(t, cancelFunc, "cancelFunc should be nil after stop")
})

t.Run("MultipleStartStop", func(t *testing.T) {
// Test multiple start/stop cycles
for i := 0; i < 3; i++ {
ctx := context.Background()
StartActiveMemGuardian(ctx)
time.Sleep(5 * time.Millisecond)
StopActiveMemGuardian()
time.Sleep(10 * time.Millisecond)
}
})

t.Run("ContextCancellation", func(t *testing.T) {
// Test that context cancellation properly stops the goroutine
ctx, cancel := context.WithCancel(context.Background())

StartActiveMemGuardian(ctx)
require.NotNil(t, memTimer, "memTimer should be initialized")

// Cancel context to trigger goroutine exit
cancel()

// Give it time to process cancellation
time.Sleep(20 * time.Millisecond)

// Clean up
StopActiveMemGuardian()
time.Sleep(10 * time.Millisecond)
})

t.Run("IdempotentStart", func(t *testing.T) {
// Test that multiple starts don't create multiple goroutines
ctx := context.Background()

StartActiveMemGuardian(ctx)
firstTimer := memTimer

// Start again - should be idempotent
StartActiveMemGuardian(ctx)
require.Equal(t, firstTimer, memTimer, "memTimer should be the same")
require.NotNil(t, cancelFunc, "cancelFunc should still be set")

StopActiveMemGuardian()
time.Sleep(10 * time.Millisecond)
})
}

// TestMemGuardianReset tests resetting global state
func TestMemGuardianReset(t *testing.T) {
defer goleak.VerifyNone(t,
goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"),
goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"),
goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"),
)

// Ensure clean state
StopActiveMemGuardian()
time.Sleep(20 * time.Millisecond) // Allow any existing goroutines to exit

// Test that we can start after stop
ctx := context.Background()
StartActiveMemGuardian(ctx)

// Verify it started
require.NotNil(t, memTimer, "memTimer should be initialized after restart")

// Clean up
StopActiveMemGuardian()
time.Sleep(10 * time.Millisecond) // Allow cleanup
}
8 changes: 8 additions & 0 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,13 +292,17 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV

select {
case <-input.Context().Done():
close(tasks)
workersWg.Wait()
return input.Context().Err()
default:
}

// resize check point - nop if there are no changes
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
if err := spmHandler.Resize(input.Context(), request.options.Options.PayloadConcurrency); err != nil {
close(tasks)
workersWg.Wait()
return err
}
// if payload concurrency increased, add more workers
Expand All @@ -322,6 +326,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
break
}
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
close(tasks)
workersWg.Wait()
return err
}
if input.MetaInput.Input == "" {
Expand All @@ -331,6 +337,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
if request.isUnresponsiveAddress(updatedInput) {
// skip on unresponsive host no need to continue
spmHandler.Cancel()
close(tasks)
workersWg.Wait()
return nil
}
select {
Expand Down
155 changes: 155 additions & 0 deletions pkg/protocols/http/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/tarunKoyalwar/goleak"

"github.com/projectdiscovery/nuclei/v3/pkg/model"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
Expand Down Expand Up @@ -371,3 +373,156 @@ func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive")
}

// TestExecuteParallelHTTP_GoroutineLeaks uses goleak to detect goroutine leaks in all HTTP parallel execution scenarios
func TestExecuteParallelHTTP_GoroutineLeaks(t *testing.T) {
defer goleak.VerifyNone(t,
goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"),
goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"),
goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/client"),
goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"),
goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb/util.(*BufferPool).drain"),
goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).compactionError"),
goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mpoolDrain"),
goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).tCompaction"),
goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mCompaction"),
)

options := testutils.DefaultOptions
testutils.Init(options)
defer testutils.Cleanup(options)

// Test Case 1: Normal execution with StopAtFirstMatch
t.Run("StopAtFirstMatch", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Millisecond)
_, _ = fmt.Fprintf(w, "test response")
}))
defer ts.Close()

req := &Request{
ID: "parallel-stop-first-match",
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
Path: []string{"{{BaseURL}}/test?param={{payload}}"},
Threads: 4,
Payloads: map[string]interface{}{
"payload": []string{"1", "2", "3", "4", "5", "6", "7", "8"},
},
Operators: operators.Operators{
Matchers: []*matchers.Matcher{{
Part: "body",
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
Words: []string{"test response"},
}},
},
StopAtFirstMatch: true,
}

executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
ID: "parallel-stop-first-match",
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
})

err := req.Compile(executerOpts)
require.NoError(t, err)

metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)

err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {})
require.NoError(t, err)
})

// Test Case 2: Unresponsive host scenario
t.Run("UnresponsiveHost", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "response")
}))
defer ts.Close()

req := &Request{
ID: "parallel-unresponsive",
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
Path: []string{"{{BaseURL}}/test?param={{payload}}"},
Threads: 3,
Payloads: map[string]interface{}{
"payload": []string{"1", "2", "3", "4", "5"},
},
Operators: operators.Operators{
Matchers: []*matchers.Matcher{{
Part: "body",
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
Words: []string{"response"},
}},
},
}

executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
ID: "parallel-unresponsive",
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
})
executerOpts.HostErrorsCache = &fakeHostErrorsCache{}

err := req.Compile(executerOpts)
require.NoError(t, err)

metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)

err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {})
require.NoError(t, err)
})

// Test Case 3: Context cancellation scenario
t.Run("ContextCancellation", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(200 * time.Millisecond)
_, _ = fmt.Fprintf(w, "response")
}))
defer ts.Close()

req := &Request{
ID: "parallel-context-cancel",
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
Path: []string{"{{BaseURL}}/test?param={{payload}}"},
Threads: 3,
Payloads: map[string]interface{}{
"payload": []string{"1", "2", "3", "4", "5"},
},
Operators: operators.Operators{
Matchers: []*matchers.Matcher{{
Part: "body",
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
Words: []string{"response"},
}},
},
}

executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
ID: "parallel-context-cancel",
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
})

err := req.Compile(executerOpts)
require.NoError(t, err)

metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)

ctx, cancel := context.WithCancel(context.Background())
ctxArgs := contextargs.NewWithInput(ctx, ts.URL)

go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()

err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {})
require.Error(t, err)
require.Equal(t, context.Canceled, err)
})
}
5 changes: 5 additions & 0 deletions pkg/testutils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ func Init(options *types.Options) {
_ = protocolinit.Init(options)
}

// Cleanup cleans up the protocols and their configurations
func Cleanup(options *types.Options) {
protocolstate.Close(options.ExecutionId)
}

// DefaultOptions is the default options structure for nuclei during mocking.
var DefaultOptions = &types.Options{
Metrics: false,
Expand Down
Loading