diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 19f00ca4..efd13e17 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -19,6 +19,7 @@ micronaut = "4.9.11" micronaut-platform = "4.9.3" micronaut-docs = "2.0.0" +micronaut-security = "4.15.0" micronaut-serde = "2.15.1" micronaut-json-schema = "1.7.1" micronaut-logging = "1.7.0" @@ -33,7 +34,7 @@ micronaut-gradle-plugin = "4.5.4" [libraries] micronaut-core = { module = 'io.micronaut:micronaut-core-bom', version.ref = 'micronaut' } - +micronaut-security = { module = 'io.micronaut.security:micronaut-security-bom', version.ref = 'micronaut-security' } micronaut-json-schema = { module = 'io.micronaut.jsonschema:micronaut-json-schema-bom', version.ref = 'micronaut-json-schema' } micronaut-langchain4j = { module = 'io.micronaut.langchain4j:micronaut-langchain4j-bom', version.ref = 'micronaut-langchain4j' } micronaut-serde = { module = 'io.micronaut.serde:micronaut-serde-bom', version.ref = 'micronaut-serde' } diff --git a/micronaut-mcp-server-java-sdk/build.gradle.kts b/micronaut-mcp-server-java-sdk/build.gradle.kts index bd971303..02ff44f6 100644 --- a/micronaut-mcp-server-java-sdk/build.gradle.kts +++ b/micronaut-mcp-server-java-sdk/build.gradle.kts @@ -14,6 +14,7 @@ dependencies { annotationProcessor(mnJsonSchema.micronaut.json.schema.processor) implementation(mnJsonSchema.micronaut.json.schema.annotations) implementation(mnJsonSchema.micronaut.json.schema.validation) + testImplementation(mnSecurity.micronaut.security) api(mnJsonSchema.micronaut.json.schema.utils) api(mnValidation.validation) compileOnly(mn.micronaut.http.server) diff --git a/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/DefaultMcpTransportContextExtractor.java b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/DefaultMcpTransportContextExtractor.java similarity index 63% rename from micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/DefaultMcpTransportContextExtractor.java rename to micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/DefaultMcpTransportContextExtractor.java index 1f95ba26..738f6f36 100644 --- a/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/DefaultMcpTransportContextExtractor.java +++ b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/DefaultMcpTransportContextExtractor.java @@ -13,16 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.micronaut.mcp.server; +package io.micronaut.mcp.server.context; import io.micronaut.core.annotation.Internal; +import io.micronaut.core.util.LocaleResolver; +import io.micronaut.http.HttpAttributes; import io.micronaut.http.HttpHeaders; import io.micronaut.http.HttpRequest; +import io.micronaut.http.server.util.HttpHostResolver; import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.ProtocolVersions; import jakarta.inject.Singleton; +import java.security.Principal; import java.util.HashMap; import java.util.Map; @@ -32,17 +36,32 @@ @Internal @Singleton final class DefaultMcpTransportContextExtractor implements McpTransportContextExtractor> { + private final HttpHostResolver hostResolver; + private final LocaleResolver> localeResolver; + + DefaultMcpTransportContextExtractor(HttpHostResolver hostResolver, + LocaleResolver> localeResolver) { + this.hostResolver = hostResolver; + this.localeResolver = localeResolver; + } + @Override public McpTransportContext extract(HttpRequest request) { - return McpTransportContext.create(metadata(request)); + return new MicronautMcpTransportContextAdapter(McpTransportContext.create(metadata(request))); } private Map metadata(HttpRequest request) { - return metadata(request.getHeaders()); + Map m = new HashMap<>(metadata(request.getHeaders())); + m.put(HttpHeaders.HOST, hostResolver.resolve(request)); + localeResolver.resolve(request) + .ifPresent(locale -> m.put(HttpHeaders.ACCEPT_LANGUAGE, locale)); + request.getAttribute(HttpAttributes.PRINCIPAL.toString(), Principal.class) + .ifPresent(auth -> m.put(HttpAttributes.PRINCIPAL.toString(), auth)); + return m; } private Map metadata(HttpHeaders headers) { - Map metadata = new HashMap<>(); + Map metadata = new HashMap<>(3); metadata.put(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION, headers.get(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION, String.class) .orElse(ProtocolVersions.MCP_2025_03_26)); diff --git a/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContext.java b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContext.java new file mode 100644 index 00000000..cad6f3cf --- /dev/null +++ b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContext.java @@ -0,0 +1,70 @@ +/* + * Copyright 2017-2025 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.mcp.server.context; + +import io.micronaut.core.annotation.Nullable; +import io.modelcontextprotocol.common.McpTransportContext; + +import java.security.Principal; +import java.util.Locale; + +/** + * Extension of {@link McpTransportContext} with convenience methods to access common transport metadata in a Micronaut context. + */ +public interface MicronautMcpTransportContext extends McpTransportContext { + /** + * + * @return The Locale of the request, if available + */ + @Nullable + Locale locale(); + + /** + * + * @return The server host if available + */ + @Nullable + String host(); + + + /** + * + * @return The authenticated principal if available + */ + @Nullable + Principal principal(); + + /** + * + * @return the last event ID if available + */ + @Nullable + String lastEventId(); + + /** + * + * @return the session ID if available + */ + @Nullable + String sessionId(); + + /** + * + * @return the MCP Protocol version + */ + @Nullable + String protocolVersion(); +} diff --git a/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextAdapter.java b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextAdapter.java new file mode 100644 index 00000000..035c1d88 --- /dev/null +++ b/micronaut-mcp-server-java-sdk/src/main/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextAdapter.java @@ -0,0 +1,91 @@ +/* + * Copyright 2017-2025 original authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.micronaut.mcp.server.context; + +import io.micronaut.core.annotation.Internal; +import io.micronaut.core.annotation.Nullable; +import io.micronaut.http.HttpAttributes; +import io.micronaut.http.HttpHeaders; +import io.modelcontextprotocol.common.McpTransportContext; + +import java.security.Principal; +import java.util.Locale; + +@Internal +final class MicronautMcpTransportContextAdapter implements MicronautMcpTransportContext { + private final McpTransportContext delegate; + + MicronautMcpTransportContextAdapter(McpTransportContext context) { + this.delegate = context; + } + + @Override + public Object get(String key) { + return delegate.get(key); + } + + @Nullable + @Override + public Locale locale() { + Object obj = get(HttpHeaders.ACCEPT_LANGUAGE); + if (obj instanceof Locale locale) { + return locale; + } + return null; + } + + @Nullable + @Override + public String host() { + return getString(HttpHeaders.HOST); + } + + @Nullable + @Override + public Principal principal() { + Object obj = get(HttpAttributes.PRINCIPAL.toString()); + if (obj instanceof Principal principal) { + return principal; + } + return null; + } + + @Nullable + @Override + public String lastEventId() { + return getString(io.modelcontextprotocol.spec.HttpHeaders.LAST_EVENT_ID); + } + + @Nullable + @Override + public String sessionId() { + return getString(io.modelcontextprotocol.spec.HttpHeaders.MCP_SESSION_ID); + } + + @Nullable + @Override + public String protocolVersion() { + return getString(io.modelcontextprotocol.spec.HttpHeaders.PROTOCOL_VERSION); + } + + private String getString(String key) { + Object obj = get(key); + if (obj instanceof String str) { + return str; + } + return null; + } +} diff --git a/micronaut-mcp-server-java-sdk/src/test/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextTest.java b/micronaut-mcp-server-java-sdk/src/test/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextTest.java new file mode 100644 index 00000000..e4077a30 --- /dev/null +++ b/micronaut-mcp-server-java-sdk/src/test/java/io/micronaut/mcp/server/context/MicronautMcpTransportContextTest.java @@ -0,0 +1,385 @@ +package io.micronaut.mcp.server.context; + +import io.micronaut.context.annotation.Factory; +import io.micronaut.runtime.server.EmbeddedServer; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.inject.Inject; +import org.json.JSONException; +import io.micronaut.http.client.BlockingHttpClient; +import io.micronaut.http.client.HttpClient; +import io.micronaut.http.client.annotation.Client; +import io.micronaut.context.annotation.Property; +import io.micronaut.context.annotation.Requires; +import io.micronaut.core.async.publisher.Publishers; +import io.micronaut.http.HttpRequest; +import io.micronaut.http.HttpResponse; +import io.micronaut.http.HttpStatus; +import io.micronaut.security.authentication.Authentication; +import io.micronaut.security.filters.AuthenticationFetcher; +import io.micronaut.test.extensions.junit5.annotation.MicronautTest; +import jakarta.inject.Singleton; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import org.skyscreamer.jsonassert.JSONAssert; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Property(name = "micronaut.mcp.server.info.name", value = "mcp-server") +@Property(name = "micronaut.mcp.server.info.version", value = "0.0.1") +@Property(name = "micronaut.mcp.server.transport", value = "HTTP") +@Property(name = "spec.name", value = "MicronautMcpTransportContextTest") +@Property(name = "micronaut.server.locale-resolution.fixed", value = "es_ES") +@MicronautTest +class MicronautMcpTransportContextTest { + + @Inject + EmbeddedServer embeddedServer; + + @Test + void lastEventIdInContext(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "lastEventId", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }""").header("Last-Event-ID", "4578"); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = String.format(""" + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "4578" + } + ], + "isError": false + } +}""", embeddedServer.getPort()); + JSONAssert.assertEquals(expected, responseJson, true); + } + + @Test + void sessionIdInContext(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "sessionId", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }""").header("Mcp-Session-Id", "123456789"); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = String.format(""" + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "123456789" + } + ], + "isError": false + } +}""", embeddedServer.getPort()); + JSONAssert.assertEquals(expected, responseJson, true); + } + + @Test + void protocolVersionInContext(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "protocolVersion", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }""").header("MCP-Protocol-Version", "2025-06-18"); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = String.format(""" + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "2025-06-18" + } + ], + "isError": false + } +}""", embeddedServer.getPort()); + JSONAssert.assertEquals(expected, responseJson, true); + } + + @Test + void hostTool(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "host", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }"""); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = String.format(""" + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "http://localhost:%s" + } + ], + "isError": false + } +}""", embeddedServer.getPort()); + JSONAssert.assertEquals(expected, responseJson, true); + } + + @Test + void localeTool(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "locale", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }"""); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = """ + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "es-ES" + } + ], + "isError": false + } +}"""; + JSONAssert.assertEquals(expected, responseJson, true); + } + + @Test + void userTool(@Client("/") HttpClient httpClient) throws JSONException { + BlockingHttpClient client = httpClient.toBlocking(); + HttpRequest req = HttpRequest.POST("/mcp", """ + { + "method": "tools/call", + "params": { + "name": "user", + "arguments": {} + }, + "jsonrpc": "2.0", + "id": 20 + }"""); + HttpResponse response = assertDoesNotThrow(() -> client.exchange(req, String.class)); + assertEquals(HttpStatus.OK, response.getStatus()); + String responseJson = response.body(); + String expected = """ + + { + "jsonrpc": "2.0", + "id": 20, + "result": { + "content": [ + { + "type": "text", + "text": "user: sdelamo role: [ROLE_USER]" + } + ], + "isError": false + } +}"""; + JSONAssert.assertEquals(expected, responseJson, true); + } + + + @Requires(property = "spec.name", value = "MicronautMcpTransportContextTest") + @Factory + static class ToolsFactory { + @Singleton + McpStatelessServerFeatures.SyncToolSpecification hostTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("host") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + return McpSchema.CallToolResult.builder() + .addTextContent(context.host()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + + @Singleton + McpStatelessServerFeatures.SyncToolSpecification localeTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("locale") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + return McpSchema.CallToolResult.builder() + .addTextContent(context.locale().toLanguageTag()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + + @Singleton + McpStatelessServerFeatures.SyncToolSpecification protocolVersionTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("protocolVersion") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + return McpSchema.CallToolResult.builder() + .addTextContent(context.protocolVersion()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + + @Singleton + McpStatelessServerFeatures.SyncToolSpecification lastEventIdTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("lastEventId") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + return McpSchema.CallToolResult.builder() + .addTextContent(context.lastEventId()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + + @Singleton + McpStatelessServerFeatures.SyncToolSpecification sessionIdTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("sessionId") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + return McpSchema.CallToolResult.builder() + .addTextContent(context.sessionId()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + + @Singleton + McpStatelessServerFeatures.SyncToolSpecification userTool() { + return McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("user") + .build()) + .callHandler((exchange, req) -> { + if (exchange instanceof MicronautMcpTransportContext context) { + if (context.principal() instanceof Authentication authentication) { + return McpSchema.CallToolResult.builder() + .addTextContent("user: " + authentication.getName() + " role: " + authentication.getRoles()) + .build(); + } else { + return McpSchema.CallToolResult.builder() + .addTextContent("user: " + context.principal().getName()) + .build(); + } + } else { + return McpSchema.CallToolResult.builder() + .isError(true) + .build(); + } + }) + .build(); + } + } + + @Requires(property = "spec.name", value = "MicronautMcpTransportContextTest") + @Singleton + static class TestAuthenticationFetcher implements AuthenticationFetcher> { + @Override + public Publisher fetchAuthentication(HttpRequest request) { + return Publishers.just(Authentication.build("sdelamo", List.of("ROLE_USER"))); + } + } +} diff --git a/micronaut-mcp-server-java-sdk/src/test/resources/application-test.properties b/micronaut-mcp-server-java-sdk/src/test/resources/application-test.properties new file mode 100644 index 00000000..8bbc689d --- /dev/null +++ b/micronaut-mcp-server-java-sdk/src/test/resources/application-test.properties @@ -0,0 +1,2 @@ +micronaut.security.intercept-url-map[0].pattern=/mcp +micronaut.security.intercept-url-map[0].access[0]=isAnonymous() diff --git a/settings.gradle b/settings.gradle index ad1a1fa8..081ddea9 100644 --- a/settings.gradle +++ b/settings.gradle @@ -28,6 +28,7 @@ micronautBuild { useStandardizedProjectNames.set(true) importMicronautCatalog() importMicronautCatalog("micronaut-serde") + importMicronautCatalog("micronaut-security") importMicronautCatalog("micronaut-json-schema") importMicronautCatalog("micronaut-validation") importMicronautCatalog("micronaut-langchain4j") diff --git a/src/main/docs/guide/server/context.adoc b/src/main/docs/guide/server/context.adoc new file mode 100644 index 00000000..b931ab31 --- /dev/null +++ b/src/main/docs/guide/server/context.adoc @@ -0,0 +1 @@ +Micronaut MCP ships api:mcp.server.context.MicronautMcpTransportContext[], an extension to `io.modelcontextprotocol.common.McpTransportContext`, which allows you to access concepts such as the authenticated user, locale, host, etc. diff --git a/src/main/docs/guide/toc.yml b/src/main/docs/guide/toc.yml index fbb35e70..e304c221 100644 --- a/src/main/docs/guide/toc.yml +++ b/src/main/docs/guide/toc.yml @@ -14,6 +14,7 @@ server: serverInstance: Server Instance primitivesClassesPerServerType: Primitive types per Transport serverCapabilities: Server Capabilities + context: MCP Transport Context primitives: title: Primitives tools: