Skip to content

Commit dee8a13

Browse files
hallloManuel Naujoks
authored andcommitted
Tests, ValueTasks, and dedicated type for caching.
1 parent 027f7ee commit dee8a13

File tree

7 files changed

+316
-14
lines changed

7 files changed

+316
-14
lines changed

src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ public ClientOAuthProvider(
136136
{
137137
ThrowIfNotBearerScheme(scheme);
138138

139-
var token = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false);
139+
var cachedToken = await _tokenCache.GetTokenAsync(cancellationToken).ConfigureAwait(false);
140+
var token = cachedToken?.ForUse();
140141

141142
// Return the token if it's valid
142143
if (token != null && token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
@@ -150,7 +151,7 @@ public ClientOAuthProvider(
150151
var newToken = await RefreshTokenAsync(token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
151152
if (newToken != null)
152153
{
153-
await _tokenCache.StoreTokenAsync(newToken, cancellationToken).ConfigureAwait(false);
154+
await _tokenCache.StoreTokenAsync(newToken.ForCache(), cancellationToken).ConfigureAwait(false);
154155
return newToken.AccessToken;
155156
}
156157
}
@@ -237,7 +238,7 @@ private async Task PerformOAuthAuthorizationAsync(
237238
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token.");
238239
}
239240

240-
await _tokenCache.StoreTokenAsync(token, cancellationToken).ConfigureAwait(false);
241+
await _tokenCache.StoreTokenAsync(token.ForCache(), cancellationToken).ConfigureAwait(false);
241242
LogOAuthAuthorizationCompleted();
242243
}
243244

src/ModelContextProtocol.Core/Authentication/ITokenCache.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ namespace ModelContextProtocol.Authentication;
66
public interface ITokenCache
77
{
88
/// <summary>
9-
/// Cache the token.
9+
/// Cache the token. After a new access token is acquired, this method is invoked to store it.
1010
/// </summary>
11-
Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken);
11+
ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken);
1212

1313
/// <summary>
14-
/// Get the cached token.
14+
/// Get the cached token. This method is invoked for every request.
1515
/// </summary>
16-
Task<TokenContainer?> GetTokenAsync(CancellationToken cancellationToken);
16+
ValueTask<TokenContainerCacheable?> GetTokenAsync(CancellationToken cancellationToken);
1717
}

src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@ namespace ModelContextProtocol.Authentication;
66
/// </summary>
77
internal class InMemoryTokenCache : ITokenCache
88
{
9-
private TokenContainer? _token;
9+
private TokenContainerCacheable? _token;
1010

1111
/// <summary>
1212
/// Cache the token.
1313
/// </summary>
14-
public Task StoreTokenAsync(TokenContainer token, CancellationToken cancellationToken)
14+
public ValueTask StoreTokenAsync(TokenContainerCacheable token, CancellationToken cancellationToken)
1515
{
1616
_token = token;
17-
return Task.CompletedTask;
17+
return default;
1818
}
1919

2020
/// <summary>
2121
/// Get the cached token.
2222
/// </summary>
23-
public Task<TokenContainer?> GetTokenAsync(CancellationToken cancellationToken)
23+
public ValueTask<TokenContainerCacheable?> GetTokenAsync(CancellationToken cancellationToken)
2424
{
25-
return Task.FromResult(_token);
25+
return new ValueTask<TokenContainerCacheable?>(_token);
2626
}
2727
}

src/ModelContextProtocol.Core/Authentication/TokenContainer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace ModelContextProtocol.Authentication;
55
/// <summary>
66
/// Represents a token response from the OAuth server.
77
/// </summary>
8-
public sealed class TokenContainer
8+
internal sealed class TokenContainer
99
{
1010
/// <summary>
1111
/// Gets or sets the access token.
@@ -46,7 +46,7 @@ public sealed class TokenContainer
4646
/// <summary>
4747
/// Gets or sets the timestamp when the token was obtained.
4848
/// </summary>
49-
[JsonPropertyName("obtained_at")]
49+
[JsonIgnore]
5050
public DateTimeOffset ObtainedAt { get; set; }
5151

5252
/// <summary>
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
namespace ModelContextProtocol.Authentication;
2+
3+
/// <summary>
4+
/// Represents a cacheable token representation.
5+
/// </summary>
6+
public class TokenContainerCacheable
7+
{
8+
/// <summary>
9+
/// Gets or sets the access token.
10+
/// </summary>
11+
public string AccessToken { get; set; } = string.Empty;
12+
13+
/// <summary>
14+
/// Gets or sets the refresh token.
15+
/// </summary>
16+
public string? RefreshToken { get; set; }
17+
18+
/// <summary>
19+
/// Gets or sets the number of seconds until the access token expires.
20+
/// </summary>
21+
public int ExpiresIn { get; set; }
22+
23+
/// <summary>
24+
/// Gets or sets the extended expiration time in seconds.
25+
/// </summary>
26+
public int ExtExpiresIn { get; set; }
27+
28+
/// <summary>
29+
/// Gets or sets the token type (typically "Bearer").
30+
/// </summary>
31+
public string TokenType { get; set; } = string.Empty;
32+
33+
/// <summary>
34+
/// Gets or sets the scope of the access token.
35+
/// </summary>
36+
public string Scope { get; set; } = string.Empty;
37+
38+
/// <summary>
39+
/// Gets or sets the timestamp when the token was obtained.
40+
/// </summary>
41+
public DateTimeOffset ObtainedAt { get; set; }
42+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
namespace ModelContextProtocol.Authentication;
2+
3+
internal static class TokenContainerConvert
4+
{
5+
internal static TokenContainer ForUse(this TokenContainerCacheable token) => new()
6+
{
7+
AccessToken = token.AccessToken,
8+
RefreshToken = token.RefreshToken,
9+
ExpiresIn = token.ExpiresIn,
10+
ExtExpiresIn = token.ExtExpiresIn,
11+
TokenType = token.TokenType,
12+
Scope = token.Scope,
13+
ObtainedAt = token.ObtainedAt,
14+
};
15+
16+
internal static TokenContainerCacheable ForCache(this TokenContainer token) => new()
17+
{
18+
AccessToken = token.AccessToken,
19+
RefreshToken = token.RefreshToken,
20+
ExpiresIn = token.ExpiresIn,
21+
ExtExpiresIn = token.ExtExpiresIn,
22+
TokenType = token.TokenType,
23+
Scope = token.Scope,
24+
ObtainedAt = token.ObtainedAt,
25+
};
26+
}
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
using ModelContextProtocol.Client;
2+
using ModelContextProtocol.Protocol;
3+
using ModelContextProtocol.Authentication;
4+
using System.Text.Json;
5+
using Moq;
6+
using Moq.Protected;
7+
using System.Net;
8+
using System.Text.Json.Nodes;
9+
using System.Linq.Expressions;
10+
11+
namespace ModelContextProtocol.Tests.Client;
12+
13+
public class CustomTokenCacheTests
14+
{
15+
[Fact]
16+
public async Task GetTokenAsync_CachedAccessTokenIsUsedForOutgoingRequests()
17+
{
18+
// Arrange
19+
var cachedAccessToken = $"my_access_token_{Guid.NewGuid()}";
20+
21+
var tokenCacheMock = new Mock<ITokenCache>();
22+
MockCachedAccessToken(tokenCacheMock, cachedAccessToken);
23+
24+
var httpMessageHandlerMock = new Mock<HttpMessageHandler>();
25+
MockInitializeResponse(httpMessageHandlerMock);
26+
27+
var httpClientTransport = new HttpClientTransport(
28+
transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object),
29+
httpClient: new HttpClient(httpMessageHandlerMock.Object));
30+
31+
var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken);
32+
33+
// Act
34+
var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) };
35+
await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken);
36+
37+
// Assert
38+
httpMessageHandlerMock
39+
.Protected()
40+
.Verify("SendAsync", Times.AtLeastOnce(), ItExpr.Is<HttpRequestMessage>(req =>
41+
req.RequestUri == new Uri("http://localhost:1337/")
42+
&& req.Headers.Authorization != null
43+
&& req.Headers.Authorization.Scheme == "Bearer"
44+
&& req.Headers.Authorization.Parameter == cachedAccessToken
45+
), ItExpr.IsAny<CancellationToken>());
46+
47+
httpMessageHandlerMock
48+
.Protected()
49+
.Verify("SendAsync", Times.Never(), ItExpr.Is<HttpRequestMessage>(req =>
50+
req.RequestUri == new Uri("http://localhost:1337/")
51+
&& (req.Headers.Authorization == null || req.Headers.Authorization.Parameter != cachedAccessToken)
52+
), ItExpr.IsAny<CancellationToken>());
53+
}
54+
55+
[Fact]
56+
public async Task StoreTokenAsync_NewlyAcquiredAccessTokenIsCached()
57+
{
58+
// Arrange
59+
var tokenCacheMock = new Mock<ITokenCache>();
60+
MockNoAccessTokenUntilStored(tokenCacheMock);
61+
62+
var newAccessToken = $"new_access_token_{Guid.NewGuid()}";
63+
64+
var httpMessageHandlerMock = new Mock<HttpMessageHandler>();
65+
MockUnauthorizedResponse(httpMessageHandlerMock);
66+
MockProtectedResourceMetadataResponse(httpMessageHandlerMock);
67+
MockAuthorizationServerMetadataResponse(httpMessageHandlerMock);
68+
MockAccessTokenResponse(httpMessageHandlerMock, newAccessToken);
69+
MockInitializeResponse(httpMessageHandlerMock);
70+
71+
var httpClientTransport = new HttpClientTransport(
72+
transportOptions: NewHttpClientTransportOptions(tokenCacheMock.Object),
73+
httpClient: new HttpClient(httpMessageHandlerMock.Object));
74+
75+
var connectedTransport = await httpClientTransport.ConnectAsync(cancellationToken: TestContext.Current.CancellationToken);
76+
77+
// Act
78+
var initializeRequest = new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) };
79+
await connectedTransport.SendMessageAsync(initializeRequest, cancellationToken: TestContext.Current.CancellationToken);
80+
81+
// Assert
82+
tokenCacheMock
83+
.Verify(tc => tc.StoreTokenAsync(
84+
It.Is<TokenContainerCacheable>(token => token.AccessToken == newAccessToken),
85+
It.IsAny<CancellationToken>()), Times.Once);
86+
}
87+
88+
static HttpClientTransportOptions NewHttpClientTransportOptions(ITokenCache? tokenCache = null) => new()
89+
{
90+
Name = "MCP Server",
91+
Endpoint = new Uri("http://localhost:1337/"),
92+
TransportMode = HttpTransportMode.StreamableHttp,
93+
OAuth = new()
94+
{
95+
ClientId = "mcp_inspector",
96+
RedirectUri = new Uri("http://localhost:6274/oauth/callback"),
97+
Scopes = ["openid", "profile", "offline_access"],
98+
AuthorizationRedirectDelegate = (authorizationUrl, redirectUri, cancellationToken) => Task.FromResult<string?>($"auth_code_{Guid.NewGuid()}"),
99+
TokenCache = tokenCache,
100+
},
101+
};
102+
103+
static void MockCachedAccessToken(Mock<ITokenCache> tokenCache, string cachedAccessToken)
104+
{
105+
tokenCache
106+
.Setup(tc => tc.GetTokenAsync(It.IsAny<CancellationToken>()))
107+
.ReturnsAsync(new TokenContainerCacheable
108+
{
109+
AccessToken = cachedAccessToken,
110+
ObtainedAt = DateTimeOffset.UtcNow,
111+
ExpiresIn = (int)TimeSpan.FromHours(1).TotalSeconds,
112+
});
113+
}
114+
115+
static void MockNoAccessTokenUntilStored(Mock<ITokenCache> tokenCache)
116+
{
117+
tokenCache
118+
.Setup(tc => tc.StoreTokenAsync(It.IsAny<TokenContainerCacheable>(), It.IsAny<CancellationToken>()))
119+
.Callback<TokenContainerCacheable, CancellationToken>((token, ct) =>
120+
{
121+
// Simulate that the token is now cached
122+
MockCachedAccessToken(tokenCache, token.AccessToken);
123+
})
124+
.Returns(default(ValueTask));
125+
}
126+
127+
static void MockUnauthorizedResponse(Mock<HttpMessageHandler> httpMessageHandler)
128+
{
129+
MockHttpResponse(httpMessageHandler,
130+
request: req => req.RequestUri == new Uri("http://localhost:1337/")
131+
&& req.Method == HttpMethod.Post
132+
&& (req.Headers.Authorization == null || string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter)),
133+
response: new HttpResponseMessage(HttpStatusCode.Unauthorized)
134+
{
135+
Headers = {
136+
{ "WWW-Authenticate", "Bearer realm=\"Bearer\", resource_metadata=\"http://localhost:1337/.well-known/oauth-protected-resource\"" }
137+
},
138+
});
139+
}
140+
141+
static void MockProtectedResourceMetadataResponse(Mock<HttpMessageHandler> httpMessageHandler)
142+
{
143+
MockHttpResponse(httpMessageHandler,
144+
request: req => req.RequestUri == new Uri("http://localhost:1337/.well-known/oauth-protected-resource"),
145+
response: new HttpResponseMessage(HttpStatusCode.OK)
146+
{
147+
Content = ToJsonContent(new
148+
{
149+
resource = "http://localhost:1337/",
150+
authorization_servers = new[] { "http://localhost:1336/" },
151+
})
152+
});
153+
}
154+
155+
static void MockAuthorizationServerMetadataResponse(Mock<HttpMessageHandler> httpMessageHandler)
156+
{
157+
MockHttpResponse(httpMessageHandler,
158+
request: req => req.RequestUri == new Uri("http://localhost:1336/.well-known/openid-configuration"),
159+
response: new HttpResponseMessage(HttpStatusCode.OK)
160+
{
161+
Content = ToJsonContent(new
162+
{
163+
authorization_endpoint = "http://localhost:1336/connect/authorize",
164+
token_endpoint = "http://localhost:1336/connect/token",
165+
})
166+
});
167+
}
168+
169+
static void MockAccessTokenResponse(Mock<HttpMessageHandler> httpMessageHandler, string accessToken)
170+
{
171+
MockHttpResponse(httpMessageHandler,
172+
request: req => req.RequestUri == new Uri("http://localhost:1336/connect/token"),
173+
response: new HttpResponseMessage(HttpStatusCode.OK)
174+
{
175+
Content = ToJsonContent(new
176+
{
177+
access_token = accessToken,
178+
})
179+
});
180+
}
181+
182+
static void MockInitializeResponse(Mock<HttpMessageHandler> httpMessageHandler)
183+
{
184+
MockHttpResponse(httpMessageHandler,
185+
request: req => req.RequestUri == new Uri("http://localhost:1337/")
186+
&& req.Method == HttpMethod.Post
187+
&& req.Headers.Authorization != null
188+
&& req.Headers.Authorization.Scheme == "Bearer"
189+
&& !string.IsNullOrWhiteSpace(req.Headers.Authorization.Parameter),
190+
response: new HttpResponseMessage(HttpStatusCode.OK)
191+
{
192+
Content = ToJsonContent(new JsonRpcResponse
193+
{
194+
Id = new RequestId(1),
195+
Result = ToJson(new InitializeResult
196+
{
197+
ProtocolVersion = "2024-11-05",
198+
Capabilities = new ServerCapabilities
199+
{
200+
Prompts = new PromptsCapability { ListChanged = true },
201+
Resources = new ResourcesCapability { Subscribe = true, ListChanged = true },
202+
Tools = new ToolsCapability { ListChanged = true },
203+
Logging = new LoggingCapability(),
204+
Completions = new CompletionsCapability(),
205+
},
206+
ServerInfo = new Implementation
207+
{
208+
Name = "mcp-test-server",
209+
Version = "1.0.0"
210+
},
211+
Instructions = "This server provides weather information and file system access."
212+
})
213+
}),
214+
});
215+
}
216+
217+
static void MockHttpResponse(Mock<HttpMessageHandler> httpMessageHandler, Expression<Func<HttpRequestMessage, bool>>? request = null, HttpResponseMessage? response = null)
218+
{
219+
httpMessageHandler
220+
.Protected()
221+
.Setup<Task<HttpResponseMessage>>("SendAsync", request != null ? ItExpr.Is(request) : ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
222+
.ReturnsAsync(response ?? new HttpResponseMessage());
223+
}
224+
225+
static StringContent ToJsonContent<T>(T content) => new(
226+
content: JsonSerializer.Serialize(content, McpJsonUtilities.DefaultOptions),
227+
encoding: System.Text.Encoding.UTF8,
228+
mediaType: "application/json");
229+
230+
static JsonNode? ToJson<T>(T content) => JsonSerializer.SerializeToNode(
231+
value: content,
232+
options: McpJsonUtilities.DefaultOptions);
233+
}

0 commit comments

Comments
 (0)