Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions lib/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"io"
"sync"

"github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
Expand Down Expand Up @@ -64,6 +65,7 @@ type NucleiEngine struct {
templatesLoaded bool

// unexported core fields
ctx context.Context
interactshClient *interactsh.Client
catalog catalog.Catalog
rateLimiter *ratelimit.Limiter
Expand Down Expand Up @@ -246,9 +248,9 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f
}

filtered := []func(event *output.ResultEvent){}
for _, callback := range callback {
if callback != nil {
filtered = append(filtered, callback)
for _, cb := range callback {
if cb != nil {
filtered = append(filtered, cb)
}
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)
Expand All @@ -258,15 +260,32 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f
return ErrNoTemplatesAvailable
}

_ = e.engine.ExecuteScanWithOpts(ctx, templatesAndWorkflows, e.inputProvider, false)
defer e.engine.WorkPool().Wait()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = e.engine.ExecuteScanWithOpts(ctx, templatesAndWorkflows, e.inputProvider, false)
}()

// wait for context to be cancelled
select {
case <-ctx.Done():
<-wait(&wg) // wait for scan to finish
return ctx.Err()
case <-wait(&wg):
// scan finished
}
return nil
}

// ExecuteWithCallback is same as ExecuteCallbackWithCtx but with default context
// Note this is deprecated and will be removed in future major release
func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.ResultEvent)) error {
return e.ExecuteCallbackWithCtx(context.Background(), callback...)
ctx := context.Background()
if e.ctx != nil {
ctx = e.ctx
}
return e.ExecuteCallbackWithCtx(ctx, callback...)
}

// Options return nuclei Type Options
Expand All @@ -290,6 +309,7 @@ func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*Nucl
e := &NucleiEngine{
opts: types.DefaultOptions(),
mode: singleInstance,
ctx: ctx,
}
for _, option := range options {
if err := option(e); err != nil {
Expand All @@ -306,3 +326,13 @@ func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*Nucl
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return NewNucleiEngineCtx(context.Background(), options...)
}

// wait for a waitgroup to finish
func wait(wg *sync.WaitGroup) <-chan struct{} {
ch := make(chan struct{})
go func() {
defer close(ch)
wg.Wait()
}()
return ch
}
37 changes: 37 additions & 0 deletions lib/sdk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package nuclei_test

import (
"context"
"log"
"testing"
"time"

nuclei "github.com/projectdiscovery/nuclei/v3/lib"
"github.com/stretchr/testify/require"
)

func TestContextCancelNucleiEngine(t *testing.T) {
// create nuclei engine with options
ctx, cancel := context.WithCancel(context.Background())
ne, err := nuclei.NewNucleiEngineCtx(ctx,
nuclei.WithTemplateFilters(nuclei.TemplateFilters{Tags: []string{"oast"}}),
nuclei.EnableStatsWithOpts(nuclei.StatsOptions{MetricServerPort: 0}),
)
require.NoError(t, err, "could not create nuclei engine")

go func() {
time.Sleep(time.Second * 2)
cancel()
log.Println("Test: context cancelled")
}()

// load targets and optionally probe non http/https targets
ne.LoadTargets([]string{"http://honey.scanme.sh"}, false)
// when callback is nil it nuclei will print JSON output to stdout
err = ne.ExecuteWithCallback(nil)
if err != nil {
// we expect a context cancellation error
require.ErrorIs(t, err, context.Canceled, "was expecting context cancellation error")
}
defer ne.Close()
}
Loading