diff --git a/crates/goose/src/token_counter.rs b/crates/goose/src/token_counter.rs index 40d1ab09a0ed..0af73b6a968e 100644 --- a/crates/goose/src/token_counter.rs +++ b/crates/goose/src/token_counter.rs @@ -207,6 +207,9 @@ impl TokenCounter { #[cfg(test)] mod tests { use super::*; + use crate::message::MessageContent; + use mcp_core::role::Role; + use serde_json::json; #[test] fn test_add_tokenizer_and_count_tokens() { @@ -235,73 +238,65 @@ mod tests { assert_eq!(count, 3); } - #[cfg(test)] - mod tests { - use super::*; - use crate::message::MessageContent; - use mcp_core::role::Role; - use serde_json::json; - - #[test] - fn test_count_chat_tokens() { - let token_counter = TokenCounter::new(); - - let system_prompt = - "You are a helpful assistant that can answer questions about the weather."; - - let messages = vec![ - Message { - role: Role::User, - created: 0, - content: vec![MessageContent::text( - "What's the weather like in San Francisco?", - )], - }, - Message { - role: Role::Assistant, - created: 1, - content: vec![MessageContent::text( - "Looks like it's 60 degrees Fahrenheit in San Francisco.", - )], - }, - Message { - role: Role::User, - created: 2, - content: vec![MessageContent::text("How about New York?")], - }, - ]; - - let tools = vec![Tool { - name: "get_current_weather".to_string(), - description: "Get the current weather in a given location".to_string(), - input_schema: json!({ - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "description": "The unit of temperature to return", - "enum": ["celsius", "fahrenheit"] - } + #[test] + fn test_count_chat_tokens() { + let token_counter = TokenCounter::new(); + + let system_prompt = + "You are a helpful assistant that can answer questions about the weather."; + + let messages = vec![ + Message { + role: Role::User, + created: 0, + content: vec![MessageContent::text( + "What's the weather like in San Francisco?", + )], + }, + Message { + role: Role::Assistant, + created: 1, + content: vec![MessageContent::text( + "Looks like it's 60 degrees Fahrenheit in San Francisco.", + )], + }, + Message { + role: Role::User, + created: 2, + content: vec![MessageContent::text("How about New York?")], + }, + ]; + + let tools = vec![Tool { + name: "get_current_weather".to_string(), + description: "Get the current weather in a given location".to_string(), + input_schema: json!({ + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" }, - "required": ["location"] - }), - }]; - - let token_count_without_tools = - token_counter.count_chat_tokens(system_prompt, &messages, &vec![], Some("gpt-4o")); - println!("Total tokens without tools: {}", token_count_without_tools); - - let token_count_with_tools = - token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o")); - println!("Total tokens with tools: {}", token_count_with_tools); - - // The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/ - // The token count for messages with tools is taken from tiktoken github repo example (notebook) - assert_eq!(token_count_without_tools, 56); - assert_eq!(token_count_with_tools, 124); - } + "unit": { + "type": "string", + "description": "The unit of temperature to return", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + }), + }]; + + let token_count_without_tools = + token_counter.count_chat_tokens(system_prompt, &messages, &[], Some("gpt-4o")); + println!("Total tokens without tools: {}", token_count_without_tools); + + let token_count_with_tools = + token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o")); + println!("Total tokens with tools: {}", token_count_with_tools); + + // The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/ + // The token count for messages with tools is taken from tiktoken github repo example (notebook) + assert_eq!(token_count_without_tools, 56); + assert_eq!(token_count_with_tools, 124); } } diff --git a/crates/mcp-client/src/stdio_transport.rs b/crates/mcp-client/src/stdio_transport.rs index 67f0fbf12053..3b965ab83a41 100644 --- a/crates/mcp-client/src/stdio_transport.rs +++ b/crates/mcp-client/src/stdio_transport.rs @@ -119,59 +119,9 @@ impl Transport for StdioTransport { #[cfg(test)] mod tests { use super::*; - use serde_json::json; use std::time::Duration; use tokio::time::timeout; - #[tokio::test] - async fn test_stdio_transport() { - let transport = StdioTransport { - params: StdioServerParams { - command: "tee".to_string(), // tee will echo back what it receives - args: vec![], - env: None, - }, - }; - - let (mut rx, tx) = transport.connect().await.unwrap(); - - // Create test messages - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(1), - method: "ping".to_string(), - params: None, - }); - - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: Some(2), - result: Some(json!({})), - error: None, - }); - - // Send messages - tx.send(request.clone()).await.unwrap(); - tx.send(response.clone()).await.unwrap(); - - // Receive and verify messages - let mut read_messages = Vec::new(); - - // Use timeout to avoid hanging if messages aren't received - for _ in 0..2 { - match timeout(Duration::from_secs(1), rx.recv()).await { - Ok(Some(Ok(msg))) => read_messages.push(msg), - Ok(Some(Err(e))) => panic!("Received error: {}", e), - Ok(None) => break, - Err(_) => panic!("Timeout waiting for message"), - } - } - - assert_eq!(read_messages.len(), 2, "Expected 2 messages"); - assert_eq!(read_messages[0], request); - assert_eq!(read_messages[1], response); - } - #[tokio::test] async fn test_process_termination() { let transport = StdioTransport { diff --git a/crates/mcp-core/src/handler.rs b/crates/mcp-core/src/handler.rs index 167a7d56a106..407d91e64cb1 100644 --- a/crates/mcp-core/src/handler.rs +++ b/crates/mcp-core/src/handler.rs @@ -13,21 +13,31 @@ pub enum ToolError { SerializationError(#[from] serde_json::Error), #[error("Schema error: {0}")] SchemaError(String), + #[error("Tool not found: {0}")] + NotFound(String), +} + +#[derive(Error, Debug)] +pub enum ResourceError { + #[error("Execution failed: {0}")] + ExecutionError(String), + #[error("Resource not found: {0}")] + NotFound(String), } pub type Result = std::result::Result; /// Trait for implementing MCP tools #[async_trait] -pub trait Tool: Send + Sync + 'static { +pub trait ToolHandler: Send + Sync + 'static { /// The name of the tool - fn name() -> &'static str; + fn name(&self) -> &'static str; /// A description of what the tool does - fn description() -> &'static str; + fn description(&self) -> &'static str; /// JSON schema describing the tool's parameters - fn schema() -> Value; + fn schema(&self) -> Value; /// Execute the tool with the given parameters async fn call(&self, params: Value) -> Result; @@ -35,7 +45,7 @@ pub trait Tool: Send + Sync + 'static { /// Trait for implementing MCP resources #[async_trait] -pub trait Resource: Send + Sync + 'static { +pub trait ResourceTemplateHandler: Send + Sync + 'static { /// The URL template for this resource fn template() -> &'static str; diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs index 259050a20354..e178c41177c3 100644 --- a/crates/mcp-core/src/protocol.rs +++ b/crates/mcp-core/src/protocol.rs @@ -6,16 +6,21 @@ use serde_json::Value; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct JsonRpcRequest { pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] pub id: Option, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct JsonRpcResponse { pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } @@ -23,18 +28,20 @@ pub struct JsonRpcResponse { pub struct JsonRpcNotification { pub jsonrpc: String, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct JsonRpcError { pub jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] pub id: Option, pub error: ErrorData, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(untagged)] +#[serde(untagged, try_from = "JsonRpcRaw")] pub enum JsonRpcMessage { Request(JsonRpcRequest), Response(JsonRpcResponse), @@ -42,6 +49,62 @@ pub enum JsonRpcMessage { Error(JsonRpcError), } +#[derive(Debug, Serialize, Deserialize)] +struct JsonRpcRaw { + jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +impl TryFrom for JsonRpcMessage { + type Error = String; + + fn try_from(raw: JsonRpcRaw) -> Result>::Error> { + // If it has an error field, it's an error response + if raw.error.is_some() { + return Ok(JsonRpcMessage::Error(JsonRpcError { + jsonrpc: raw.jsonrpc, + id: raw.id, + error: raw.error.unwrap(), + })); + } + + // If it has a result field, it's a response + if raw.result.is_some() { + return Ok(JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: raw.jsonrpc, + id: raw.id, + result: raw.result, + error: None, + })); + } + + // If the method starts with "notifications/", it's a notification + if raw.method.starts_with("notifications/") { + return Ok(JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: raw.jsonrpc, + method: raw.method, + params: raw.params, + })); + } + + // Otherwise it's a request + Ok(JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: raw.jsonrpc, + id: raw.id, + method: raw.method, + params: raw.params, + })) + } +} + // Standard JSON-RPC error codes pub const PARSE_ERROR: i32 = -32700; pub const INVALID_REQUEST: i32 = -32600; @@ -80,8 +143,11 @@ pub struct Implementation { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ServerCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] pub prompts: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub resources: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option, // Add other capabilities as needed } diff --git a/crates/mcp-core/src/tool.rs b/crates/mcp-core/src/tool.rs index 6401b9632671..2497abe64b1f 100644 --- a/crates/mcp-core/src/tool.rs +++ b/crates/mcp-core/src/tool.rs @@ -5,6 +5,7 @@ use serde_json::Value; /// A tool that can be used by a model. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct Tool { /// The name of the tool pub name: String, @@ -31,6 +32,7 @@ impl Tool { /// A tool call request that a system can execute #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct ToolCall { /// The name of the tool to execute pub name: String, diff --git a/crates/mcp-macros/Cargo.toml b/crates/mcp-macros/Cargo.toml new file mode 100644 index 000000000000..65450f7d2b00 --- /dev/null +++ b/crates/mcp-macros/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "mcp-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = { version = "2.0", features = ["full", "extra-traits"] } +quote = "1.0" +proc-macro2 = "1.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +mcp-core = { path = "../mcp-core" } +async-trait = "0.1" +schemars = "0.8" +convert_case = "0.6.0" + +[dev-dependencies] +tokio = { version = "1.0", features = ["full"] } +async-trait = "0.1" +serde_json = "1.0" +schemars = "0.8" diff --git a/crates/mcp-macros/examples/calculator.rs b/crates/mcp-macros/examples/calculator.rs new file mode 100644 index 000000000000..74be963d64db --- /dev/null +++ b/crates/mcp-macros/examples/calculator.rs @@ -0,0 +1,53 @@ +use mcp_core::handler::{Result, ToolError, ToolHandler}; +use mcp_macros::tool; + +#[tokio::main] +async fn main() -> std::result::Result<(), Box> { + // Create an instance of our tool + let calculator = Calculator; + + // Print tool information + println!("Tool name: {}", calculator.name()); + println!("Tool description: {}", calculator.description()); + println!("Tool schema: {}", calculator.schema()); + + // Test the tool with some sample input + let input = serde_json::json!({ + "x": 5, + "y": 3, + "operation": "multiply" + }); + + let result = calculator.call(input).await?; + println!("Result: {}", result); + + Ok(()) +} + +#[tool( + name = "calculator", + description = "Perform basic arithmetic operations", + params( + x = "First number in the calculation", + y = "Second number in the calculation", + operation = "The operation to perform (add, subtract, multiply, divide)" + ) +)] +async fn calculator(x: i32, y: i32, operation: String) -> Result { + match operation.as_str() { + "add" => Ok(x + y), + "subtract" => Ok(x - y), + "multiply" => Ok(x * y), + "divide" => { + if y == 0 { + Err(ToolError::ExecutionError("Division by zero".into())) + } else { + Ok(x / y) + } + } + _ => Err(ToolError::InvalidParameters(format!( + "Unknown operation: {}", + operation + ))), + } +} diff --git a/crates/mcp-macros/src/lib.rs b/crates/mcp-macros/src/lib.rs new file mode 100644 index 000000000000..d777d9705e95 --- /dev/null +++ b/crates/mcp-macros/src/lib.rs @@ -0,0 +1,152 @@ +use convert_case::{Case, Casing}; +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use std::collections::HashMap; +use syn::{ + parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit, + FnArg, ItemFn, Lit, Meta, Pat, PatType, Token, +}; + +struct MacroArgs { + name: Option, + description: Option, + param_descriptions: HashMap, +} + +impl Parse for MacroArgs { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut description = None; + let mut param_descriptions = HashMap::new(); + + let meta_list: Punctuated = Punctuated::parse_terminated(input)?; + + for meta in meta_list { + match meta { + Meta::NameValue(nv) => { + let ident = nv.path.get_ident().unwrap().to_string(); + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = nv.value + { + match ident.as_str() { + "name" => name = Some(lit_str.value()), + "description" => description = Some(lit_str.value()), + _ => {} + } + } + } + Meta::List(list) if list.path.is_ident("params") => { + let nested: Punctuated = + list.parse_args_with(Punctuated::parse_terminated)?; + + for meta in nested { + if let Meta::NameValue(nv) = meta { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = nv.value + { + let param_name = nv.path.get_ident().unwrap().to_string(); + param_descriptions.insert(param_name, lit_str.value()); + } + } + } + } + _ => {} + } + } + + Ok(MacroArgs { + name, + description, + param_descriptions, + }) + } +} + +#[proc_macro_attribute] +pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as MacroArgs); + let input_fn = parse_macro_input!(input as ItemFn); + + // Extract function details + let fn_name = &input_fn.sig.ident; + let fn_name_str = fn_name.to_string(); + + // Generate PascalCase struct name from the function name + let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) }); + + // Use provided name or function name as default + let tool_name = args.name.unwrap_or(fn_name_str); + let tool_description = args.description.unwrap_or_default(); + + // Extract parameter names, types, and descriptions + let mut param_defs = Vec::new(); + let mut param_names = Vec::new(); + + for arg in input_fn.sig.inputs.iter() { + if let FnArg::Typed(PatType { pat, ty, .. }) = arg { + if let Pat::Ident(param_ident) = &**pat { + let param_name = ¶m_ident.ident; + let param_name_str = param_name.to_string(); + let description = args + .param_descriptions + .get(¶m_name_str) + .map(|s| s.as_str()) + .unwrap_or(""); + + param_names.push(param_name); + param_defs.push(quote! { + #[schemars(description = #description)] + #param_name: #ty + }); + } + } + } + + // Generate the implementation + let params_struct_name = format_ident!("{}Parameters", struct_name); + let expanded = quote! { + #[derive(serde::Deserialize, schemars::JsonSchema)] + struct #params_struct_name { + #(#param_defs,)* + } + + #input_fn + + #[derive(Default)] + struct #struct_name; + + #[async_trait::async_trait] + impl mcp_core::handler::ToolHandler for #struct_name { + fn name(&self) -> &'static str { + #tool_name + } + + fn description(&self) -> &'static str { + #tool_description + } + + fn schema(&self) -> serde_json::Value { + mcp_core::handler::generate_schema::<#params_struct_name>() + .expect("Failed to generate schema") + } + + async fn call(&self, params: serde_json::Value) -> mcp_core::handler::Result { + let params: #params_struct_name = serde_json::from_value(params) + .map_err(|e| mcp_core::handler::ToolError::InvalidParameters(e.to_string()))?; + + // Extract parameters and call the function + let result = #fn_name(#(params.#param_names,)*).await + .map_err(|e| mcp_core::handler::ToolError::ExecutionError(e.to_string()))?; + + serde_json::to_value(result) + .map_err(mcp_core::handler::ToolError::SerializationError) + } + } + }; + + TokenStream::from(expanded) +} diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml new file mode 100644 index 000000000000..3657fc582ab5 --- /dev/null +++ b/crates/mcp-server/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "mcp-server" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.94" +thiserror = "1.0" +mcp-core = { path = "../mcp-core" } +mcp-macros = { path = "../mcp-macros" } +serde = { version = "1.0.216", features = ["derive"] } +serde_json = "1.0.133" +schemars = "0.8" +tokio = { version = "1", features = ["full"] } +tower = { version = "0.4", features = ["timeout"] } +tower-service = "0.3" +futures = "0.3" +pin-project = "1.1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-appender = "0.2" +async-trait = "0.1" diff --git a/crates/mcp-server/README.md b/crates/mcp-server/README.md new file mode 100644 index 000000000000..1e4f06176553 --- /dev/null +++ b/crates/mcp-server/README.md @@ -0,0 +1,7 @@ +### Test with MCP Inspector + +```bash +npx @modelcontextprotocol/inspector cargo run -p mcp-server +``` + +Then visit the Inspector in the browser window and test the different endpoints. \ No newline at end of file diff --git a/crates/mcp-server/src/errors.rs b/crates/mcp-server/src/errors.rs new file mode 100644 index 000000000000..f49eab058993 --- /dev/null +++ b/crates/mcp-server/src/errors.rs @@ -0,0 +1,96 @@ +use thiserror::Error; + +pub type BoxError = Box; + +#[derive(Error, Debug)] +pub enum TransportError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON serialization error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Invalid UTF-8 sequence: {0}")] + Utf8(#[from] std::string::FromUtf8Error), + + #[error("Protocol error: {0}")] + Protocol(String), + + #[error("Invalid message format: {0}")] + InvalidMessage(String), +} + +#[derive(Error, Debug)] +pub enum ServerError { + #[error("Transport error: {0}")] + Transport(#[from] TransportError), + + #[error("Service error: {0}")] + Service(String), + + #[error("Internal error: {0}")] + Internal(String), + + #[error("Request timed out")] + Timeout(#[from] tower::timeout::error::Elapsed), +} + +#[derive(Error, Debug)] +pub enum RouterError { + #[error("Method not found: {0}")] + MethodNotFound(String), + + #[error("Invalid parameters: {0}")] + InvalidParams(String), + + #[error("Internal error: {0}")] + Internal(String), + + #[error("Tool not found: {0}")] + ToolNotFound(String), + + #[error("Resource not found: {0}")] + ResourceNotFound(String), +} + +impl From for mcp_core::protocol::ErrorData { + fn from(err: RouterError) -> Self { + use mcp_core::protocol::*; + match err { + RouterError::MethodNotFound(msg) => ErrorData { + code: METHOD_NOT_FOUND, + message: msg, + data: None, + }, + RouterError::InvalidParams(msg) => ErrorData { + code: INVALID_PARAMS, + message: msg, + data: None, + }, + RouterError::Internal(msg) => ErrorData { + code: INTERNAL_ERROR, + message: msg, + data: None, + }, + RouterError::ToolNotFound(msg) => ErrorData { + code: INVALID_REQUEST, + message: msg, + data: None, + }, + RouterError::ResourceNotFound(msg) => ErrorData { + code: INVALID_REQUEST, + message: msg, + data: None, + }, + } + } +} + +impl From for RouterError { + fn from(err: mcp_core::handler::ResourceError) -> Self { + match err { + mcp_core::handler::ResourceError::NotFound(msg) => RouterError::ResourceNotFound(msg), + _ => RouterError::Internal("Unknown resource error".to_string()), + } + } +} diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs new file mode 100644 index 000000000000..b054a98c59a4 --- /dev/null +++ b/crates/mcp-server/src/lib.rs @@ -0,0 +1,237 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{Future, Stream}; +use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; +use pin_project::pin_project; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +use tower_service::Service; + +mod errors; +pub use errors::{BoxError, RouterError, ServerError, TransportError}; + +pub mod router; +pub use router::Router; + +/// A transport layer that handles JSON-RPC messages over byte +#[pin_project] +pub struct ByteTransport { + #[pin] + reader: R, + #[pin] + writer: W, +} + +impl ByteTransport +where + R: AsyncRead, + W: AsyncWrite, +{ + pub fn new(reader: R, writer: W) -> Self { + Self { reader, writer } + } +} + +impl Stream for ByteTransport +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + let mut buf = Vec::new(); + let mut reader = BufReader::new(&mut this.reader); + + let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf)); + match read_future.as_mut().poll(cx) { + Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF + Poll::Ready(Ok(_)) => { + // Convert to UTF-8 string + let line = match String::from_utf8(buf) { + Ok(s) => s, + Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))), + }; + + // Parse JSON and validate message format + match serde_json::from_str::(&line) { + Ok(value) => { + // Validate basic JSON-RPC structure + if !value.is_object() { + return Poll::Ready(Some(Err(TransportError::InvalidMessage( + "Message must be a JSON object".into(), + )))); + } + + let obj = value.as_object().unwrap(); // Safe due to check above + + // Check jsonrpc version field + if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" { + return Poll::Ready(Some(Err(TransportError::InvalidMessage( + "Missing or invalid jsonrpc version".into(), + )))); + } + + tracing::info!( + json = %line, + "incoming message" + ); + // Now try to parse as proper message + match serde_json::from_value::(value) { + Ok(msg) => Poll::Ready(Some(Ok(msg))), + Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), + } + } + Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), + } + } + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))), + Poll::Pending => Poll::Pending, + } + } +} + +impl ByteTransport +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + pub async fn write_message(&mut self, msg: JsonRpcMessage) -> Result<(), std::io::Error> { + let json = serde_json::to_string(&msg)?; + Pin::new(&mut self.writer) + .write_all(json.as_bytes()) + .await?; + Pin::new(&mut self.writer).write_all(b"\n").await?; + Pin::new(&mut self.writer).flush().await?; + Ok(()) + } +} + +/// The main server type that processes incoming requests +pub struct Server { + service: S, +} + +impl Server +where + S: Service + Send, + S::Error: Into, + S::Future: Send, +{ + pub fn new(service: S) -> Self { + Self { service } + } + + // TODO transport trait instead of byte transport if we implement others + pub async fn run(self, mut transport: ByteTransport) -> Result<(), ServerError> + where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, + { + use futures::StreamExt; + let mut service = self.service; + + tracing::info!("Server started"); + while let Some(msg_result) = transport.next().await { + let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered(); + match msg_result { + Ok(msg) => { + match msg { + JsonRpcMessage::Request(request) => { + // Serialize request for logging + let id = request.id; + let request_json = serde_json::to_string(&request) + .unwrap_or_else(|_| "Failed to serialize request".to_string()); + + tracing::info!( + request_id = ?id, + method = ?request.method, + json = %request_json, + "Received request" + ); + + // Process the request using our service + let response = match service.call(request).await { + Ok(resp) => resp, + Err(e) => { + let error_msg = e.into().to_string(); + tracing::error!(error = %error_msg, "Request processing failed"); + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(mcp_core::protocol::ErrorData { + code: mcp_core::protocol::INTERNAL_ERROR, + message: error_msg, + data: None, + }), + } + } + }; + + // Serialize response for logging + let response_json = serde_json::to_string(&response) + .unwrap_or_else(|_| "Failed to serialize response".to_string()); + + tracing::info!( + response_id = ?response.id, + json = %response_json, + "Sending response" + ); + // Send the response back + if let Err(e) = transport + .write_message(JsonRpcMessage::Response(response)) + .await + { + return Err(ServerError::Transport(TransportError::Io(e))); + } + } + JsonRpcMessage::Response(_) + | JsonRpcMessage::Notification(_) + | JsonRpcMessage::Error(_) => { + // Ignore responses and notifications for now + continue; + } + } + } + Err(e) => { + // Convert transport error to JSON-RPC error response + let error = match e { + TransportError::Json(_) | TransportError::InvalidMessage(_) => { + mcp_core::protocol::ErrorData { + code: mcp_core::protocol::PARSE_ERROR, + message: e.to_string(), + data: None, + } + } + TransportError::Protocol(_) => mcp_core::protocol::ErrorData { + code: mcp_core::protocol::INVALID_REQUEST, + message: e.to_string(), + data: None, + }, + _ => mcp_core::protocol::ErrorData { + code: mcp_core::protocol::INTERNAL_ERROR, + message: e.to_string(), + data: None, + }, + }; + + let error_response = JsonRpcMessage::Error(JsonRpcError { + jsonrpc: "2.0".to_string(), + id: None, + error, + }); + + if let Err(e) = transport.write_message(error_response).await { + return Err(ServerError::Transport(TransportError::Io(e))); + } + } + } + } + + Ok(()) + } +} diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs new file mode 100644 index 000000000000..b8911b3690ca --- /dev/null +++ b/crates/mcp-server/src/main.rs @@ -0,0 +1,166 @@ +use anyhow::Result; +use mcp_core::handler::ResourceError; +use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; +use mcp_server::router::{CapabilitiesBuilder, RouterService}; +use mcp_server::{ByteTransport, Router, Server}; +use serde_json::Value; +use std::{future::Future, pin::Pin, sync::Arc}; +use tokio::{ + io::{stdin, stdout}, + sync::Mutex, +}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{self, EnvFilter}; + +// A simple counter service that demonstrates the Router trait +#[derive(Clone)] +struct CounterRouter { + counter: Arc>, +} + +impl CounterRouter { + fn new() -> Self { + Self { + counter: Arc::new(Mutex::new(0)), + } + } + + async fn increment(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter += 1; + Ok(*counter) + } + + async fn decrement(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter -= 1; + Ok(*counter) + } + + async fn get_value(&self) -> Result { + let counter = self.counter.lock().await; + Ok(*counter) + } +} + +impl Router for CounterRouter { + fn capabilities(&self) -> ServerCapabilities { + CapabilitiesBuilder::new().with_tools(true).build() + } + + fn list_tools(&self) -> Vec { + vec![ + Tool::new( + "increment".to_string(), + "Increment the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "decrement".to_string(), + "Decrement the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "get_value".to_string(), + "Get the current counter value".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + ] + } + + fn call_tool( + &self, + tool_name: &str, + _arguments: Value, + ) -> Pin> + Send + 'static>> { + let this = self.clone(); + let tool_name = tool_name.to_string(); + + Box::pin(async move { + match tool_name.as_str() { + "increment" => { + let value = this.increment().await?; + Ok(Value::Number(value.into())) + } + "decrement" => { + let value = this.decrement().await?; + Ok(Value::Number(value.into())) + } + "get_value" => { + let value = this.get_value().await?; + Ok(Value::Number(value.into())) + } + _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), + } + }) + } + + fn list_resources(&self) -> Vec { + vec![Resource::new( + "memo://insights", + Some("text/plain".to_string()), + Some("memo-resource".to_string()), + ) + .unwrap()] + } + + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>> { + let uri = uri.to_string(); + Box::pin(async move { + match uri.as_str() { + "memo://insights" => { + let memo = + "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; + Ok(memo.to_string()) + } + _ => Err(ResourceError::NotFound(format!( + "Resource {} not found", + uri + ))), + } + }) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Set up file appender for logging + let file_appender = RollingFileAppender::new(Rotation::DAILY, "logs", "mcp-server.log"); + + // Initialize the tracing subscriber with file and stdout logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::INFO.into())) + .with_writer(file_appender) + .with_target(false) + .with_thread_ids(true) + .with_file(true) + .with_line_number(true) + .init(); + + tracing::info!("Starting MCP server"); + + // Create an instance of our counter router + let router = RouterService(CounterRouter::new()); + + // Create and run the server + let server = Server::new(router); + let transport = ByteTransport::new(stdin(), stdout()); + + tracing::info!("Server initialized and ready to handle requests"); + Ok(server.run(transport).await?) +} diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs new file mode 100644 index 000000000000..d2d8a9ddf679 --- /dev/null +++ b/crates/mcp-server/src/router.rs @@ -0,0 +1,270 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use mcp_core::{ + content::Content, + handler::{ResourceError, ToolError}, + protocol::{ + CallToolResult, Implementation, InitializeResult, JsonRpcRequest, JsonRpcResponse, + ListResourcesResult, ListToolsResult, PromptsCapability, ReadResourceResult, + ResourcesCapability, ServerCapabilities, ToolsCapability, + }, + ResourceContents, +}; +use serde_json::Value; +use tower_service::Service; + +use crate::{BoxError, RouterError}; + +/// Builder for configuring and constructing capabilities +pub struct CapabilitiesBuilder { + tools: Option, + prompts: Option, + resources: Option, +} + +impl Default for CapabilitiesBuilder { + fn default() -> Self { + Self::new() + } +} + +impl CapabilitiesBuilder { + pub fn new() -> Self { + Self { + tools: None, + prompts: None, + resources: None, + } + } + + /// Add multiple tools to the router + pub fn with_tools(mut self, list_changed: bool) -> Self { + self.tools = Some(ToolsCapability { + list_changed: Some(list_changed), + }); + self + } + + /// Enable prompts capability + pub fn with_prompts(mut self, list_changed: bool) -> Self { + self.prompts = Some(PromptsCapability { + list_changed: Some(list_changed), + }); + self + } + + /// Enable resources capability + pub fn with_resources(mut self, subscribe: bool, list_changed: bool) -> Self { + self.resources = Some(ResourcesCapability { + subscribe: Some(subscribe), + list_changed: Some(list_changed), + }); + self + } + + /// Build the router with automatic capability inference + pub fn build(self) -> ServerCapabilities { + // Create capabilities based on what's configured + ServerCapabilities { + tools: self.tools, + prompts: self.prompts, + resources: self.resources, + } + } +} + +pub trait Router: Send + Sync + 'static { + fn capabilities(&self) -> ServerCapabilities; + fn list_tools(&self) -> Vec; + fn call_tool( + &self, + tool_name: &str, + arguments: Value, + ) -> Pin> + Send + 'static>>; + fn list_resources(&self) -> Vec; + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>>; + + // Helper method to create base response + fn create_response(&self, id: Option) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: None, + } + } + + fn handle_initialize( + &self, + req: JsonRpcRequest, + ) -> impl Future> + Send { + async move { + let result = InitializeResult { + protocol_version: "2024-11-05".to_string(), + capabilities: self.capabilities().clone(), + server_info: Implementation { + name: "mcp-server".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + }; + + let mut response = self.create_response(req.id); + response.result = + Some(serde_json::to_value(result).map_err(|e| { + RouterError::Internal(format!("JSON serialization error: {}", e)) + })?); + + Ok(response) + } + } + + fn handle_tools_list( + &self, + req: JsonRpcRequest, + ) -> impl Future> + Send { + async move { + let tools = self.list_tools(); + + let result = ListToolsResult { tools }; + let mut response = self.create_response(req.id); + response.result = + Some(serde_json::to_value(result).map_err(|e| { + RouterError::Internal(format!("JSON serialization error: {}", e)) + })?); + + Ok(response) + } + } + + fn handle_tools_call( + &self, + req: JsonRpcRequest, + ) -> impl Future> + Send { + async move { + let params = req + .params + .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; + + let name = params + .get("name") + .and_then(Value::as_str) + .ok_or_else(|| RouterError::InvalidParams("Missing tool name".into()))?; + + let arguments = params.get("arguments").cloned().unwrap_or(Value::Null); + + let result = match self.call_tool(name, arguments).await { + Ok(result) => CallToolResult { + content: vec![Content::text(result.to_string())], + is_error: false, + }, + Err(err) => CallToolResult { + content: vec![Content::text(err.to_string())], + is_error: true, + }, + }; + + let mut response = self.create_response(req.id); + response.result = + Some(serde_json::to_value(result).map_err(|e| { + RouterError::Internal(format!("JSON serialization error: {}", e)) + })?); + + Ok(response) + } + } + + fn handle_resources_list( + &self, + req: JsonRpcRequest, + ) -> impl Future> + Send { + async move { + let resources = self.list_resources(); + + let result = ListResourcesResult { resources }; + let mut response = self.create_response(req.id); + response.result = + Some(serde_json::to_value(result).map_err(|e| { + RouterError::Internal(format!("JSON serialization error: {}", e)) + })?); + + Ok(response) + } + } + + fn handle_resources_read( + &self, + req: JsonRpcRequest, + ) -> impl Future> + Send { + async move { + let params = req + .params + .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; + + let uri = params + .get("uri") + .and_then(Value::as_str) + .ok_or_else(|| RouterError::InvalidParams("Missing resource URI".into()))?; + + let contents = self.read_resource(uri).await.map_err(RouterError::from)?; + + let result = ReadResourceResult { + contents: vec![ResourceContents::TextResourceContents { + uri: uri.to_string(), + mime_type: Some("text/plain".to_string()), + text: contents, + }], + }; + + let mut response = self.create_response(req.id); + response.result = + Some(serde_json::to_value(result).map_err(|e| { + RouterError::Internal(format!("JSON serialization error: {}", e)) + })?); + + Ok(response) + } + } +} + +pub struct RouterService(pub T); + +impl Service for RouterService +where + T: Router + Clone + Send + Sync + 'static, +{ + type Response = JsonRpcResponse; + type Error = BoxError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: JsonRpcRequest) -> Self::Future { + let this = self.0.clone(); + + Box::pin(async move { + let result = match req.method.as_str() { + "initialize" => this.handle_initialize(req).await, + "tools/list" => this.handle_tools_list(req).await, + "tools/call" => this.handle_tools_call(req).await, + "resources/list" => this.handle_resources_list(req).await, + "resources/read" => this.handle_resources_read(req).await, + _ => { + let mut response = this.create_response(req.id); + response.error = Some(RouterError::MethodNotFound(req.method).into()); + Ok(response) + } + }; + + result.map_err(BoxError::from) + }) + } +}