-
Notifications
You must be signed in to change notification settings - Fork 3.3k
fix(authx): prevent concurrent dynamic fetch race #6946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ package authx | |
| import ( | ||
| "fmt" | ||
| "strings" | ||
| "sync/atomic" | ||
| "sync" | ||
|
|
||
| "github.com/projectdiscovery/gologger" | ||
| "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/replacer" | ||
|
|
@@ -14,6 +14,14 @@ import ( | |
|
|
||
| type LazyFetchSecret func(d *Dynamic) error | ||
|
|
||
| // errNotValidated is returned when Fetch is called before Validate. | ||
| var errNotValidated = errkit.New("dynamic secret not validated: call Validate() before Fetch()") | ||
|
|
||
| type fetchState struct { | ||
| once sync.Once | ||
| err error | ||
| } | ||
|
|
||
| var ( | ||
| _ json.Unmarshaler = &Dynamic{} | ||
| ) | ||
|
|
@@ -30,9 +38,9 @@ type Dynamic struct { | |
| Input string `json:"input" yaml:"input"` // (optional) target for the dynamic secret | ||
| Extracted map[string]interface{} `json:"-" yaml:"-"` // extracted values from the dynamic secret | ||
| fetchCallback LazyFetchSecret `json:"-" yaml:"-"` | ||
| fetched *atomic.Bool `json:"-" yaml:"-"` // atomic flag to check if the secret has been fetched | ||
| fetching *atomic.Bool `json:"-" yaml:"-"` // atomic flag to prevent recursive fetch calls | ||
| error error `json:"-" yaml:"-"` // error if any | ||
| // fetchState is shared across value-copies of Dynamic (e.g., inside DynamicAuthStrategy). | ||
| // It must be initialized via Validate() before calling Fetch(). | ||
| fetchState *fetchState `json:"-" yaml:"-"` | ||
| } | ||
|
|
||
| func (d *Dynamic) GetDomainAndDomainRegex() ([]string, []string) { | ||
|
|
@@ -70,8 +78,9 @@ func (d *Dynamic) UnmarshalJSON(data []byte) error { | |
|
|
||
| // Validate validates the dynamic secret | ||
| func (d *Dynamic) Validate() error { | ||
| d.fetched = &atomic.Bool{} | ||
| d.fetching = &atomic.Bool{} | ||
| // NOTE: Validate() must not be called concurrently with Fetch()/GetStrategies(). | ||
| // Re-validating resets fetch state and allows re-fetching. | ||
| d.fetchState = &fetchState{} | ||
| if d.TemplatePath == "" { | ||
| return errkit.New(" template-path is required for dynamic secret") | ||
| } | ||
|
|
@@ -181,18 +190,14 @@ func (d *Dynamic) applyValuesToSecret(secret *Secret) error { | |
| return nil | ||
| } | ||
|
|
||
| // GetStrategy returns the auth strategies for the dynamic secret | ||
| // GetStrategies returns the auth strategies for the dynamic secret | ||
| func (d *Dynamic) GetStrategies() []AuthStrategy { | ||
| if d.fetched.Load() { | ||
| if d.error != nil { | ||
| return nil | ||
| } | ||
| } else { | ||
| // Try to fetch if not already fetched | ||
| _ = d.Fetch(true) | ||
| } | ||
| // Ensure fetch has completed before returning strategies. | ||
| // Fetch errors are treated as non-fatal here so a failed dynamic auth fetch | ||
| // does not terminate the entire scan process. | ||
| _ = d.Fetch(false) | ||
|
|
||
| if d.error != nil { | ||
| if d.fetchState != nil && d.fetchState.err != nil { | ||
| return nil | ||
|
Comment on lines
+196
to
201
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When if d.fetchState != nil && d.fetchState.err != nilshort-circuits to The fix is to use the return value of 🐛 Proposed fix- _ = d.Fetch(false)
-
- if d.fetchState != nil && d.fetchState.err != nil {
+ if err := d.Fetch(false); err != nil {
return nil
}🤖 Prompt for AI Agents |
||
| } | ||
| var strategies []AuthStrategy | ||
|
|
@@ -208,30 +213,31 @@ func (d *Dynamic) GetStrategies() []AuthStrategy { | |
| // Fetch fetches the dynamic secret | ||
| // if isFatal is true, it will stop the execution if the secret could not be fetched | ||
| func (d *Dynamic) Fetch(isFatal bool) error { | ||
| if d.fetched.Load() { | ||
| return d.error | ||
| } | ||
|
|
||
| // Try to set fetching flag atomically | ||
| if !d.fetching.CompareAndSwap(false, true) { | ||
| // Already fetching, return current error | ||
| return d.error | ||
| if d.fetchState == nil { | ||
| if isFatal { | ||
| gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", errNotValidated) | ||
| } | ||
| return errNotValidated | ||
| } | ||
|
|
||
| // We're the only one fetching, call the callback | ||
| d.error = d.fetchCallback(d) | ||
|
|
||
| // Mark as fetched and clear fetching flag | ||
| d.fetched.Store(true) | ||
| d.fetching.Store(false) | ||
| d.fetchState.once.Do(func() { | ||
| if d.fetchCallback == nil { | ||
| d.fetchState.err = errkit.New("dynamic secret fetch callback not set: call SetLazyFetchCallback() before Fetch()") | ||
| return | ||
| } | ||
| d.fetchState.err = d.fetchCallback(d) | ||
| }) | ||
|
|
||
| if d.error != nil && isFatal { | ||
| gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", d.error) | ||
| if d.fetchState.err != nil && isFatal { | ||
| gologger.Fatal().Msgf("Could not fetch dynamic secret: %s\n", d.fetchState.err) | ||
| } | ||
| return d.error | ||
| return d.fetchState.err | ||
| } | ||
|
|
||
| // Error returns the error if any | ||
| func (d *Dynamic) Error() error { | ||
| return d.error | ||
| if d.fetchState == nil { | ||
| return nil | ||
| } | ||
| return d.fetchState.err | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| package authprovider | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "net/http" | ||
| "os" | ||
| "path/filepath" | ||
| "sync" | ||
| "sync/atomic" | ||
| "testing" | ||
| "time" | ||
|
|
||
| "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestFileAuthProviderDynamicSecretConcurrentAccess(t *testing.T) { | ||
| secretFile := filepath.Join(t.TempDir(), "secret.yaml") | ||
| secretData := []byte(`id: test-auth | ||
| info: | ||
| name: test | ||
| author: test | ||
| severity: info | ||
| dynamic: | ||
| - template: auth-template.yaml | ||
| variables: | ||
| - key: username | ||
| value: test | ||
| type: Header | ||
| domains: | ||
| - example.com | ||
| headers: | ||
| - key: Authorization | ||
| value: "Bearer {{token}}" | ||
| `) | ||
| require.NoError(t, os.WriteFile(secretFile, secretData, 0o600)) | ||
|
|
||
| var fetchCalls atomic.Int32 | ||
| provider, err := NewFileAuthProvider(secretFile, func(dynamic *authx.Dynamic) error { | ||
| fetchCalls.Add(1) | ||
| time.Sleep(75 * time.Millisecond) | ||
| dynamic.Extracted = map[string]interface{}{"token": "session-token"} | ||
| return nil | ||
| }) | ||
| require.NoError(t, err) | ||
|
|
||
| const workers = 20 | ||
| barrier := make(chan struct{}) | ||
| errs := make(chan error, workers) | ||
| var wg sync.WaitGroup | ||
| wg.Add(workers) | ||
|
|
||
| for i := 0; i < workers; i++ { | ||
| go func() { | ||
| defer wg.Done() | ||
| <-barrier | ||
|
|
||
| strategies := provider.LookupAddr("example.com") | ||
| if len(strategies) == 0 { | ||
| errs <- fmt.Errorf("no auth strategies found") | ||
| return | ||
| } | ||
|
|
||
| req, reqErr := http.NewRequest(http.MethodGet, "https://example.com", nil) | ||
| if reqErr != nil { | ||
| errs <- reqErr | ||
| return | ||
| } | ||
| for _, strategy := range strategies { | ||
| strategy.Apply(req) | ||
| } | ||
| if got := req.Header.Get("Authorization"); got != "Bearer session-token" { | ||
| errs <- fmt.Errorf("expected Authorization header to be set, got %q", got) | ||
| } | ||
| }() | ||
| } | ||
|
|
||
| close(barrier) | ||
| wg.Wait() | ||
| close(errs) | ||
|
|
||
| for gotErr := range errs { | ||
| require.NoError(t, gotErr) | ||
| } | ||
| require.Equal(t, int32(1), fetchCalls.Load(), "dynamic secret fetch should execute once") | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
errNotValidatedis unexported but surfaced through the exportedFetch()method.When
Fetch(false)is called beforeValidate(), it returnserrNotValidated. Because the sentinel is unexported, callers outside theauthxpackage cannot perform a typed check witherrors.Is(err, errNotValidated)and must either inspect the error message or treat it as a generic non-nil error.If callers in other packages need to programmatically distinguish this condition, consider exporting it as
ErrNotValidated.🤖 Prompt for AI Agents