diff --git a/pkg/protocols/common/protocolstate/memguardian.go b/pkg/protocols/common/protocolstate/memguardian.go index 8a0d9699a0..2f31f4ca7d 100644 --- a/pkg/protocols/common/protocolstate/memguardian.go +++ b/pkg/protocols/common/protocolstate/memguardian.go @@ -28,12 +28,17 @@ func StartActiveMemGuardian(ctx context.Context) { memTimer = time.NewTicker(memguardian.DefaultInterval) ctx, cancelFunc = context.WithCancel(ctx) - go func() { + + ticker := memTimer + go func(t *time.Ticker) { + if t == nil { + return + } for { select { case <-ctx.Done(): return - case <-memTimer.C: + case <-t.C: if IsLowOnMemory() { _ = GlobalGuardBytesBufferAlloc() } else { @@ -41,7 +46,7 @@ func StartActiveMemGuardian(ctx context.Context) { } } } - }() + }(ticker) } func StopActiveMemGuardian() { @@ -52,9 +57,13 @@ func StopActiveMemGuardian() { return } + if cancelFunc != nil { + cancelFunc() + cancelFunc = nil + } if memTimer != nil { memTimer.Stop() - cancelFunc() + memTimer = nil } } diff --git a/pkg/protocols/common/protocolstate/memguardian_test.go b/pkg/protocols/common/protocolstate/memguardian_test.go new file mode 100644 index 0000000000..7306b81e23 --- /dev/null +++ b/pkg/protocols/common/protocolstate/memguardian_test.go @@ -0,0 +1,123 @@ +package protocolstate + +import ( + "context" + "testing" + "time" + + "github.com/projectdiscovery/utils/memguardian" + "github.com/stretchr/testify/require" + "github.com/tarunKoyalwar/goleak" +) + +// TestMemGuardianGoroutineLeak tests that MemGuardian properly cleans up goroutines +func TestMemGuardianGoroutineLeak(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"), + goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"), + goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"), + ) + + // Initialize memguardian if not already initialized + if memguardian.DefaultMemGuardian == nil { + var err error + memguardian.DefaultMemGuardian, err = memguardian.New() + require.NoError(t, err, "Failed to initialize memguardian") + } + + t.Run("StartAndStopMemGuardian", func(t *testing.T) { + // Test that starting and stopping memguardian doesn't leak goroutines + ctx := context.Background() + + // Start MemGuardian + StartActiveMemGuardian(ctx) + require.NotNil(t, memTimer, "memTimer should be initialized") + require.NotNil(t, cancelFunc, "cancelFunc should be initialized") + + // Give it a moment to start + time.Sleep(10 * time.Millisecond) + + // Stop MemGuardian + StopActiveMemGuardian() + + // Give goroutine time to exit + time.Sleep(20 * time.Millisecond) + + // Verify cleanup + require.Nil(t, memTimer, "memTimer should be nil after stop") + require.Nil(t, cancelFunc, "cancelFunc should be nil after stop") + }) + + t.Run("MultipleStartStop", func(t *testing.T) { + // Test multiple start/stop cycles + for i := 0; i < 3; i++ { + ctx := context.Background() + StartActiveMemGuardian(ctx) + time.Sleep(5 * time.Millisecond) + StopActiveMemGuardian() + time.Sleep(10 * time.Millisecond) + } + }) + + t.Run("ContextCancellation", func(t *testing.T) { + // Test that context cancellation properly stops the goroutine + ctx, cancel := context.WithCancel(context.Background()) + + StartActiveMemGuardian(ctx) + require.NotNil(t, memTimer, "memTimer should be initialized") + + // Cancel context to trigger goroutine exit + cancel() + + // Give it time to process cancellation + time.Sleep(20 * time.Millisecond) + + // Clean up + StopActiveMemGuardian() + time.Sleep(10 * time.Millisecond) + }) + + t.Run("IdempotentStart", func(t *testing.T) { + // Test that multiple starts don't create multiple goroutines + ctx := context.Background() + + StartActiveMemGuardian(ctx) + firstTimer := memTimer + + // Start again - should be idempotent + StartActiveMemGuardian(ctx) + require.Equal(t, firstTimer, memTimer, "memTimer should be the same") + require.NotNil(t, cancelFunc, "cancelFunc should still be set") + + StopActiveMemGuardian() + time.Sleep(10 * time.Millisecond) + }) +} + +// TestMemGuardianReset tests resetting global state +func TestMemGuardianReset(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"), + goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"), + goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"), + ) + + // Ensure clean state + StopActiveMemGuardian() + time.Sleep(20 * time.Millisecond) // Allow any existing goroutines to exit + + // Test that we can start after stop + ctx := context.Background() + StartActiveMemGuardian(ctx) + + // Verify it started + require.NotNil(t, memTimer, "memTimer should be initialized after restart") + + // Clean up + StopActiveMemGuardian() + time.Sleep(10 * time.Millisecond) // Allow cleanup +} diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index c6e3de4d97..37c78eacdd 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -292,6 +292,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV select { case <-input.Context().Done(): + close(tasks) + workersWg.Wait() return input.Context().Err() default: } @@ -299,6 +301,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV // resize check point - nop if there are no changes if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency { if err := spmHandler.Resize(input.Context(), request.options.Options.PayloadConcurrency); err != nil { + close(tasks) + workersWg.Wait() return err } // if payload concurrency increased, add more workers @@ -322,6 +326,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV break } request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total())) + close(tasks) + workersWg.Wait() return err } if input.MetaInput.Input == "" { @@ -331,6 +337,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV if request.isUnresponsiveAddress(updatedInput) { // skip on unresponsive host no need to continue spmHandler.Cancel() + close(tasks) + workersWg.Wait() return nil } select { diff --git a/pkg/protocols/http/request_test.go b/pkg/protocols/http/request_test.go index b547d91b5e..9eb7b100e4 100644 --- a/pkg/protocols/http/request_test.go +++ b/pkg/protocols/http/request_test.go @@ -7,8 +7,10 @@ import ( "net/http/httptest" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" + "github.com/tarunKoyalwar/goleak" "github.com/projectdiscovery/nuclei/v3/pkg/model" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity" @@ -371,3 +373,156 @@ func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) { require.NoError(t, err) require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive") } + +// TestExecuteParallelHTTP_GoroutineLeaks uses goleak to detect goroutine leaks in all HTTP parallel execution scenarios +func TestExecuteParallelHTTP_GoroutineLeaks(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyContainingPkg("go.opencensus.io/stats/view"), + goleak.IgnoreAnyContainingPkg("github.com/syndtr/goleveldb"), + goleak.IgnoreAnyContainingPkg("github.com/go-rod/rod"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/server"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/interactsh/pkg/client"), + goleak.IgnoreAnyContainingPkg("github.com/projectdiscovery/ratelimit"), + goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb/util.(*BufferPool).drain"), + goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).compactionError"), + goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mpoolDrain"), + goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).tCompaction"), + goleak.IgnoreAnyFunction("github.com/syndtr/goleveldb/leveldb.(*DB).mCompaction"), + ) + + options := testutils.DefaultOptions + testutils.Init(options) + defer testutils.Cleanup(options) + + // Test Case 1: Normal execution with StopAtFirstMatch + t.Run("StopAtFirstMatch", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Millisecond) + _, _ = fmt.Fprintf(w, "test response") + })) + defer ts.Close() + + req := &Request{ + ID: "parallel-stop-first-match", + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/test?param={{payload}}"}, + Threads: 4, + Payloads: map[string]interface{}{ + "payload": []string{"1", "2", "3", "4", "5", "6", "7", "8"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"test response"}, + }}, + }, + StopAtFirstMatch: true, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: "parallel-stop-first-match", + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + + err := req.Compile(executerOpts) + require.NoError(t, err) + + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {}) + require.NoError(t, err) + }) + + // Test Case 2: Unresponsive host scenario + t.Run("UnresponsiveHost", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "response") + })) + defer ts.Close() + + req := &Request{ + ID: "parallel-unresponsive", + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/test?param={{payload}}"}, + Threads: 3, + Payloads: map[string]interface{}{ + "payload": []string{"1", "2", "3", "4", "5"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"response"}, + }}, + }, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: "parallel-unresponsive", + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + executerOpts.HostErrorsCache = &fakeHostErrorsCache{} + + err := req.Compile(executerOpts) + require.NoError(t, err) + + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {}) + require.NoError(t, err) + }) + + // Test Case 3: Context cancellation scenario + t.Run("ContextCancellation", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + _, _ = fmt.Fprintf(w, "response") + })) + defer ts.Close() + + req := &Request{ + ID: "parallel-context-cancel", + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/test?param={{payload}}"}, + Threads: 3, + Payloads: map[string]interface{}{ + "payload": []string{"1", "2", "3", "4", "5"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"response"}, + }}, + }, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: "parallel-context-cancel", + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + + err := req.Compile(executerOpts) + require.NoError(t, err) + + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + + ctx, cancel := context.WithCancel(context.Background()) + ctxArgs := contextargs.NewWithInput(ctx, ts.URL) + + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {}) + require.Error(t, err) + require.Equal(t, context.Canceled, err) + }) +} diff --git a/pkg/testutils/testutils.go b/pkg/testutils/testutils.go index 5f791c2c12..521654c452 100644 --- a/pkg/testutils/testutils.go +++ b/pkg/testutils/testutils.go @@ -33,6 +33,11 @@ func Init(options *types.Options) { _ = protocolinit.Init(options) } +// Cleanup cleans up the protocols and their configurations +func Cleanup(options *types.Options) { + protocolstate.Close(options.ExecutionId) +} + // DefaultOptions is the default options structure for nuclei during mocking. var DefaultOptions = &types.Options{ Metrics: false,