Skip to content

Commit 7d0613f

Browse files
committed
feat: add thread-safe SetExpectedState for cross-request OAuth flows
Enables OAuth state management when initialization and callback steps are handled by different OAuthHandler instances, such as in web servers where separate HTTP request handlers process the auth flow stages. - Add SetExpectedState method for explicit state configuration - Add mutex protection for thread-safe expectedState access - Add comprehensive test for cross-request scenario validation
1 parent 7c38b56 commit 7d0613f

File tree

2 files changed

+123
-7
lines changed

2 files changed

+123
-7
lines changed

client/transport/oauth.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ type OAuthHandler struct {
115115
metadataFetchErr error
116116
metadataOnce sync.Once
117117
baseURL string
118-
expectedState string // Expected state value for CSRF protection
118+
119+
mu sync.RWMutex // Protects expectedState
120+
expectedState string // Expected state value for CSRF protection
119121
}
120122

121123
// NewOAuthHandler creates a new OAuth handler
@@ -263,9 +265,27 @@ func (h *OAuthHandler) SetBaseURL(baseURL string) {
263265

264266
// GetExpectedState returns the expected state value (for testing purposes)
265267
func (h *OAuthHandler) GetExpectedState() string {
268+
h.mu.RLock()
269+
defer h.mu.RUnlock()
266270
return h.expectedState
267271
}
268272

273+
// SetExpectedState sets the expected state value.
274+
//
275+
// This can be useful if you cannot maintain an OAuthHandler
276+
// instance throughout the authentication flow; for example, if
277+
// the initialization and callback steps are handled in different
278+
// requests.
279+
//
280+
// In such cases, this should be called with the state value generated
281+
// during the initial authentication request (e.g. by GenerateState)
282+
// and included in the authorization URL.
283+
func (h *OAuthHandler) SetExpectedState(expectedState string) {
284+
h.mu.Lock()
285+
defer h.mu.Unlock()
286+
h.expectedState = expectedState
287+
}
288+
269289
// OAuthError represents a standard OAuth 2.0 error response
270290
type OAuthError struct {
271291
ErrorCode string `json:"error"`
@@ -547,18 +567,21 @@ var ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack"
547567
// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token
548568
func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error {
549569
// Validate the state parameter to prevent CSRF attacks
550-
if h.expectedState == "" {
570+
h.mu.Lock()
571+
expectedState := h.expectedState
572+
if expectedState == "" {
573+
h.mu.Unlock()
551574
return errors.New("no expected state found, authorization flow may not have been initiated properly")
552575
}
553576

554-
if state != h.expectedState {
577+
if state != expectedState {
578+
h.mu.Unlock()
555579
return ErrInvalidState
556580
}
557581

558582
// Clear the expected state after validation
559-
defer func() {
560-
h.expectedState = ""
561-
}()
583+
h.expectedState = ""
584+
h.mu.Unlock()
562585

563586
metadata, err := h.getServerMetadata(ctx)
564587
if err != nil {
@@ -629,7 +652,7 @@ func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChall
629652
}
630653

631654
// Store the state for later validation
632-
h.expectedState = state
655+
h.SetExpectedState(state)
633656

634657
params := url.Values{}
635658
params.Set("response_type", "code")

client/transport/oauth_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,96 @@ func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T)
300300
t.Errorf("Got ErrInvalidState when expected a different error for empty expected state")
301301
}
302302
}
303+
304+
func TestOAuthHandler_SetExpectedState_CrossRequestScenario(t *testing.T) {
305+
// Simulate the scenario where different OAuthHandler instances are used
306+
// for initialization and callback steps (different HTTP request handlers)
307+
308+
config := OAuthConfig{
309+
ClientID: "test-client",
310+
RedirectURI: "http://localhost:8085/callback",
311+
Scopes: []string{"mcp.read", "mcp.write"},
312+
TokenStore: NewMemoryTokenStore(),
313+
AuthServerMetadataURL: "http://example.com/.well-known/oauth-authorization-server",
314+
PKCEEnabled: true,
315+
}
316+
317+
// Step 1: First handler instance (initialization request)
318+
// This simulates the handler that generates the authorization URL
319+
handler1 := NewOAuthHandler(config)
320+
321+
// Mock the server metadata for the first handler
322+
handler1.serverMetadata = &AuthServerMetadata{
323+
Issuer: "http://example.com",
324+
AuthorizationEndpoint: "http://example.com/authorize",
325+
TokenEndpoint: "http://example.com/token",
326+
}
327+
328+
// Generate state and get authorization URL (this would typically be done in the init handler)
329+
testState := "generated-state-value-123"
330+
_, err := handler1.GetAuthorizationURL(context.Background(), testState, "test-code-challenge")
331+
if err != nil {
332+
// We expect this to fail since we're not actually connecting to a server,
333+
// but it should still store the expected state
334+
if !strings.Contains(err.Error(), "connection") && !strings.Contains(err.Error(), "dial") {
335+
t.Errorf("Expected connection error, got: %v", err)
336+
}
337+
}
338+
339+
// Verify the state was stored in the first handler
340+
if handler1.GetExpectedState() != testState {
341+
t.Errorf("Expected state %s to be stored in first handler, got %s", testState, handler1.GetExpectedState())
342+
}
343+
344+
// Step 2: Second handler instance (callback request)
345+
// This simulates a completely separate handler instance that would be created
346+
// in a different HTTP request handler for processing the OAuth callback
347+
handler2 := NewOAuthHandler(config)
348+
349+
// Mock the server metadata for the second handler
350+
handler2.serverMetadata = &AuthServerMetadata{
351+
Issuer: "http://example.com",
352+
AuthorizationEndpoint: "http://example.com/authorize",
353+
TokenEndpoint: "http://example.com/token",
354+
}
355+
356+
// Initially, the second handler has no expected state
357+
if handler2.GetExpectedState() != "" {
358+
t.Errorf("Expected second handler to have empty state initially, got %s", handler2.GetExpectedState())
359+
}
360+
361+
// Step 3: Transfer the state from the first handler to the second
362+
// This is the key functionality being tested - setting the expected state
363+
// in a different handler instance
364+
handler2.SetExpectedState(testState)
365+
366+
// Verify the state was transferred correctly
367+
if handler2.GetExpectedState() != testState {
368+
t.Errorf("Expected state %s to be set in second handler, got %s", testState, handler2.GetExpectedState())
369+
}
370+
371+
// Step 4: Test that state validation works correctly in the second handler
372+
373+
// Test with correct state - should pass validation but fail at token exchange
374+
// (since we're not actually running a real OAuth server)
375+
err = handler2.ProcessAuthorizationResponse(context.Background(), "test-code", testState, "test-code-verifier")
376+
if err == nil {
377+
t.Errorf("Expected error due to token exchange failure, got nil")
378+
}
379+
// Should NOT be ErrInvalidState since the state matches
380+
if errors.Is(err, ErrInvalidState) {
381+
t.Errorf("Got ErrInvalidState with matching state, should have failed at token exchange instead")
382+
}
383+
384+
// Verify state was cleared after processing (even though token exchange failed)
385+
if handler2.GetExpectedState() != "" {
386+
t.Errorf("Expected state to be cleared after processing, got %s", handler2.GetExpectedState())
387+
}
388+
389+
// Step 5: Test with wrong state after resetting
390+
handler2.SetExpectedState("different-state-value")
391+
err = handler2.ProcessAuthorizationResponse(context.Background(), "test-code", testState, "test-code-verifier")
392+
if !errors.Is(err, ErrInvalidState) {
393+
t.Errorf("Expected ErrInvalidState with wrong state, got %v", err)
394+
}
395+
}

0 commit comments

Comments
 (0)