diff --git a/pkg/authprovider/authx/dynamic.go b/pkg/authprovider/authx/dynamic.go index 28ec59298a..0258576649 100644 --- a/pkg/authprovider/authx/dynamic.go +++ b/pkg/authprovider/authx/dynamic.go @@ -3,7 +3,7 @@ package authx import ( "fmt" "strings" - "sync/atomic" + "sync" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/replacer" @@ -14,6 +14,14 @@ import ( type LazyFetchSecret func(d *Dynamic) error +// errNotValidated is returned when Fetch is called before Validate. +var errNotValidated = errkit.New("dynamic secret not validated: call Validate() before Fetch()") + +type fetchState struct { + once sync.Once + err error +} + var ( _ json.Unmarshaler = &Dynamic{} ) @@ -30,9 +38,9 @@ type Dynamic struct { Input string `json:"input" yaml:"input"` // (optional) target for the dynamic secret Extracted map[string]interface{} `json:"-" yaml:"-"` // extracted values from the dynamic secret fetchCallback LazyFetchSecret `json:"-" yaml:"-"` - fetched *atomic.Bool `json:"-" yaml:"-"` // atomic flag to check if the secret has been fetched - fetching *atomic.Bool `json:"-" yaml:"-"` // atomic flag to prevent recursive fetch calls - error error `json:"-" yaml:"-"` // error if any + // fetchState is shared across value-copies of Dynamic (e.g., inside DynamicAuthStrategy). + // It must be initialized via Validate() before calling Fetch(). + fetchState *fetchState `json:"-" yaml:"-"` } func (d *Dynamic) GetDomainAndDomainRegex() ([]string, []string) { @@ -70,8 +78,9 @@ func (d *Dynamic) UnmarshalJSON(data []byte) error { // Validate validates the dynamic secret func (d *Dynamic) Validate() error { - d.fetched = &atomic.Bool{} - d.fetching = &atomic.Bool{} + // NOTE: Validate() must not be called concurrently with Fetch()/GetStrategies(). + // Re-validating resets fetch state and allows re-fetching. + d.fetchState = &fetchState{} if d.TemplatePath == "" { return errkit.New(" template-path is required for dynamic secret") } @@ -181,18 +190,14 @@ func (d *Dynamic) applyValuesToSecret(secret *Secret) error { return nil } -// GetStrategy returns the auth strategies for the dynamic secret +// GetStrategies returns the auth strategies for the dynamic secret func (d *Dynamic) GetStrategies() []AuthStrategy { - if d.fetched.Load() { - if d.error != nil { - return nil - } - } else { - // Try to fetch if not already fetched - _ = d.Fetch(true) - } + // Ensure fetch has completed before returning strategies. + // Fetch errors are treated as non-fatal here so a failed dynamic auth fetch + // does not terminate the entire scan process. + _ = d.Fetch(false) - if d.error != nil { + if d.fetchState != nil && d.fetchState.err != nil { return nil } var strategies []AuthStrategy @@ -208,30 +213,31 @@ func (d *Dynamic) GetStrategies() []AuthStrategy { // Fetch fetches the dynamic secret // if isFatal is true, it will stop the execution if the secret could not be fetched func (d *Dynamic) Fetch(isFatal bool) error { - if d.fetched.Load() { - return d.error - } - - // Try to set fetching flag atomically - if !d.fetching.CompareAndSwap(false, true) { - // Already fetching, return current error - return d.error + if d.fetchState == nil { + if isFatal { + gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", errNotValidated) + } + return errNotValidated } - // We're the only one fetching, call the callback - d.error = d.fetchCallback(d) - - // Mark as fetched and clear fetching flag - d.fetched.Store(true) - d.fetching.Store(false) + d.fetchState.once.Do(func() { + if d.fetchCallback == nil { + d.fetchState.err = errkit.New("dynamic secret fetch callback not set: call SetLazyFetchCallback() before Fetch()") + return + } + d.fetchState.err = d.fetchCallback(d) + }) - if d.error != nil && isFatal { - gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", d.error) + if d.fetchState.err != nil && isFatal { + gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", d.fetchState.err) } - return d.error + return d.fetchState.err } // Error returns the error if any func (d *Dynamic) Error() error { - return d.error + if d.fetchState == nil { + return nil + } + return d.fetchState.err } diff --git a/pkg/authprovider/authx/dynamic_test.go b/pkg/authprovider/authx/dynamic_test.go index ffa38ea83c..d6843b3346 100644 --- a/pkg/authprovider/authx/dynamic_test.go +++ b/pkg/authprovider/authx/dynamic_test.go @@ -1,7 +1,11 @@ package authx import ( + "errors" + "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -123,3 +127,100 @@ func TestDynamicUnmarshalJSON(t *testing.T) { require.NoError(t, err) }) } + +func TestDynamicFetchConcurrent(t *testing.T) { + t.Run("all-waiters-block-until-done", func(t *testing.T) { + const numGoroutines = 10 + wantErr := errors.New("auth fetch failed") + fetchStarted := make(chan struct{}) + fetchUnblock := make(chan struct{}) + + d := &Dynamic{ + TemplatePath: "test-template.yaml", + Variables: []KV{{Key: "username", Value: "test"}}, + } + require.NoError(t, d.Validate()) + d.SetLazyFetchCallback(func(_ *Dynamic) error { + close(fetchStarted) + <-fetchUnblock + return wantErr + }) + + results := make([]error, numGoroutines) + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = d.Fetch(false) + }(i) + } + + select { + case <-fetchStarted: + case <-time.After(5 * time.Second): + t.Fatal("fetch callback never started") + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + t.Fatal("fetch callers returned before fetch completed") + case <-time.After(25 * time.Millisecond): + } + + close(fetchUnblock) + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("fetch callers did not complete in time") + } + + for _, err := range results { + require.ErrorIs(t, err, wantErr) + } + }) + + t.Run("fetch-callback-runs-once", func(t *testing.T) { + const numGoroutines = 20 + var callCount atomic.Int32 + errs := make(chan error, numGoroutines) + barrier := make(chan struct{}) + + d := &Dynamic{ + TemplatePath: "test-template.yaml", + Variables: []KV{{Key: "username", Value: "test"}}, + } + require.NoError(t, d.Validate()) + d.SetLazyFetchCallback(func(dynamic *Dynamic) error { + callCount.Add(1) + time.Sleep(20 * time.Millisecond) + dynamic.Extracted = map[string]interface{}{"token": "secret-token"} + return nil + }) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + <-barrier + errs <- d.Fetch(false) + }() + } + close(barrier) + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + require.Equal(t, int32(1), callCount.Load(), "fetch callback must be called exactly once") + }) +} diff --git a/pkg/authprovider/file_test.go b/pkg/authprovider/file_test.go new file mode 100644 index 0000000000..99a41f8b80 --- /dev/null +++ b/pkg/authprovider/file_test.go @@ -0,0 +1,86 @@ +package authprovider + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" + "github.com/stretchr/testify/require" +) + +func TestFileAuthProviderDynamicSecretConcurrentAccess(t *testing.T) { + secretFile := filepath.Join(t.TempDir(), "secret.yaml") + secretData := []byte(`id: test-auth +info: + name: test + author: test + severity: info +dynamic: + - template: auth-template.yaml + variables: + - key: username + value: test + type: Header + domains: + - example.com + headers: + - key: Authorization + value: "Bearer {{token}}" +`) + require.NoError(t, os.WriteFile(secretFile, secretData, 0o600)) + + var fetchCalls atomic.Int32 + provider, err := NewFileAuthProvider(secretFile, func(dynamic *authx.Dynamic) error { + fetchCalls.Add(1) + time.Sleep(75 * time.Millisecond) + dynamic.Extracted = map[string]interface{}{"token": "session-token"} + return nil + }) + require.NoError(t, err) + + const workers = 20 + barrier := make(chan struct{}) + errs := make(chan error, workers) + var wg sync.WaitGroup + wg.Add(workers) + + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + <-barrier + + strategies := provider.LookupAddr("example.com") + if len(strategies) == 0 { + errs <- fmt.Errorf("no auth strategies found") + return + } + + req, reqErr := http.NewRequest(http.MethodGet, "https://example.com", nil) + if reqErr != nil { + errs <- reqErr + return + } + for _, strategy := range strategies { + strategy.Apply(req) + } + if got := req.Header.Get("Authorization"); got != "Bearer session-token" { + errs <- fmt.Errorf("expected Authorization header to be set, got %q", got) + } + }() + } + + close(barrier) + wg.Wait() + close(errs) + + for gotErr := range errs { + require.NoError(t, gotErr) + } + require.Equal(t, int32(1), fetchCalls.Load(), "dynamic secret fetch should execute once") +}