Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 62 additions & 67 deletions crates/goose/src/token_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ impl TokenCounter {
#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use mcp_core::role::Role;
use serde_json::json;

#[test]
fn test_add_tokenizer_and_count_tokens() {
Expand Down Expand Up @@ -235,73 +238,65 @@ mod tests {
assert_eq!(count, 3);
}

#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use mcp_core::role::Role;
use serde_json::json;

#[test]
fn test_count_chat_tokens() {
let token_counter = TokenCounter::new();

let system_prompt =
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
];

let tools = vec![Tool {
name: "get_current_weather".to_string(),
description: "Get the current weather in a given location".to_string(),
input_schema: json!({
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"description": "The unit of temperature to return",
"enum": ["celsius", "fahrenheit"]
}
#[test]
fn test_count_chat_tokens() {
let token_counter = TokenCounter::new();

let system_prompt =
"You are a helpful assistant that can answer questions about the weather.";

let messages = vec![
Message {
role: Role::User,
created: 0,
content: vec![MessageContent::text(
"What's the weather like in San Francisco?",
)],
},
Message {
role: Role::Assistant,
created: 1,
content: vec![MessageContent::text(
"Looks like it's 60 degrees Fahrenheit in San Francisco.",
)],
},
Message {
role: Role::User,
created: 2,
content: vec![MessageContent::text("How about New York?")],
},
];

let tools = vec![Tool {
name: "get_current_weather".to_string(),
description: "Get the current weather in a given location".to_string(),
input_schema: json!({
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"required": ["location"]
}),
}];

let token_count_without_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &vec![], Some("gpt-4o"));
println!("Total tokens without tools: {}", token_count_without_tools);

let token_count_with_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o"));
println!("Total tokens with tools: {}", token_count_with_tools);

// The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/
// The token count for messages with tools is taken from tiktoken github repo example (notebook)
assert_eq!(token_count_without_tools, 56);
assert_eq!(token_count_with_tools, 124);
}
"unit": {
"type": "string",
"description": "The unit of temperature to return",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}),
}];

let token_count_without_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &[], Some("gpt-4o"));
println!("Total tokens without tools: {}", token_count_without_tools);

let token_count_with_tools =
token_counter.count_chat_tokens(system_prompt, &messages, &tools, Some("gpt-4o"));
println!("Total tokens with tools: {}", token_count_with_tools);

// The token count for messages without tools is calculated using the tokenizer - https://tiktokenizer.vercel.app/
// The token count for messages with tools is taken from tiktoken github repo example (notebook)
assert_eq!(token_count_without_tools, 56);
assert_eq!(token_count_with_tools, 124);
}
}
50 changes: 0 additions & 50 deletions crates/mcp-client/src/stdio_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,59 +119,9 @@ impl Transport for StdioTransport {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::time::Duration;
use tokio::time::timeout;

#[tokio::test]
async fn test_stdio_transport() {
let transport = StdioTransport {
params: StdioServerParams {
command: "tee".to_string(), // tee will echo back what it receives
args: vec![],
env: None,
},
};

let (mut rx, tx) = transport.connect().await.unwrap();

// Create test messages
let request = JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(1),
method: "ping".to_string(),
params: None,
});

let response = JsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: Some(2),
result: Some(json!({})),
error: None,
});

// Send messages
tx.send(request.clone()).await.unwrap();
tx.send(response.clone()).await.unwrap();

// Receive and verify messages
let mut read_messages = Vec::new();

// Use timeout to avoid hanging if messages aren't received
for _ in 0..2 {
match timeout(Duration::from_secs(1), rx.recv()).await {
Ok(Some(Ok(msg))) => read_messages.push(msg),
Ok(Some(Err(e))) => panic!("Received error: {}", e),
Ok(None) => break,
Err(_) => panic!("Timeout waiting for message"),
}
}

assert_eq!(read_messages.len(), 2, "Expected 2 messages");
assert_eq!(read_messages[0], request);
assert_eq!(read_messages[1], response);
}

#[tokio::test]
async fn test_process_termination() {
let transport = StdioTransport {
Expand Down
20 changes: 15 additions & 5 deletions crates/mcp-core/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,39 @@ pub enum ToolError {
SerializationError(#[from] serde_json::Error),
#[error("Schema error: {0}")]
SchemaError(String),
#[error("Tool not found: {0}")]
NotFound(String),
}

#[derive(Error, Debug)]
pub enum ResourceError {
#[error("Execution failed: {0}")]
ExecutionError(String),
#[error("Resource not found: {0}")]
NotFound(String),
}

pub type Result<T> = std::result::Result<T, ToolError>;

/// Trait for implementing MCP tools
#[async_trait]
pub trait Tool: Send + Sync + 'static {
pub trait ToolHandler: Send + Sync + 'static {
/// The name of the tool
fn name() -> &'static str;
fn name(&self) -> &'static str;

/// A description of what the tool does
fn description() -> &'static str;
fn description(&self) -> &'static str;

/// JSON schema describing the tool's parameters
fn schema() -> Value;
fn schema(&self) -> Value;

/// Execute the tool with the given parameters
async fn call(&self, params: Value) -> Result<Value>;
}

/// Trait for implementing MCP resources
#[async_trait]
pub trait Resource: Send + Sync + 'static {
pub trait ResourceTemplateHandler: Send + Sync + 'static {
/// The URL template for this resource
fn template() -> &'static str;

Expand Down
68 changes: 67 additions & 1 deletion crates/mcp-core/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,105 @@ use serde_json::Value;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<ErrorData>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct JsonRpcError {
pub jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<u64>,
pub error: ErrorData,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
#[serde(untagged, try_from = "JsonRpcRaw")]
pub enum JsonRpcMessage {
Request(JsonRpcRequest),
Response(JsonRpcResponse),
Notification(JsonRpcNotification),
Error(JsonRpcError),
}

#[derive(Debug, Serialize, Deserialize)]
struct JsonRpcRaw {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where does this get used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a deseraialization detail, where we use it to disambiguate when there are so many nulls that you can't tell if its a notification or a request

jsonrpc: String,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<u64>,
method: String,
#[serde(skip_serializing_if = "Option::is_none")]
params: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<ErrorData>,
}

impl TryFrom<JsonRpcRaw> for JsonRpcMessage {
type Error = String;

fn try_from(raw: JsonRpcRaw) -> Result<Self, <Self as TryFrom<JsonRpcRaw>>::Error> {
// If it has an error field, it's an error response
if raw.error.is_some() {
return Ok(JsonRpcMessage::Error(JsonRpcError {
jsonrpc: raw.jsonrpc,
id: raw.id,
error: raw.error.unwrap(),
}));
}

// If it has a result field, it's a response
if raw.result.is_some() {
return Ok(JsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: raw.jsonrpc,
id: raw.id,
result: raw.result,
error: None,
}));
}

// If the method starts with "notifications/", it's a notification
if raw.method.starts_with("notifications/") {
return Ok(JsonRpcMessage::Notification(JsonRpcNotification {
jsonrpc: raw.jsonrpc,
method: raw.method,
params: raw.params,
}));
}

// Otherwise it's a request
Ok(JsonRpcMessage::Request(JsonRpcRequest {
jsonrpc: raw.jsonrpc,
id: raw.id,
method: raw.method,
params: raw.params,
}))
}
}

// Standard JSON-RPC error codes
pub const PARSE_ERROR: i32 = -32700;
pub const INVALID_REQUEST: i32 = -32600;
Expand Down Expand Up @@ -80,8 +143,11 @@ pub struct Implementation {

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ServerCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompts: Option<PromptsCapability>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resources: Option<ResourcesCapability>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<ToolsCapability>,
// Add other capabilities as needed
}
Expand Down
2 changes: 2 additions & 0 deletions crates/mcp-core/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde_json::Value;

/// A tool that can be used by a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
/// The name of the tool
pub name: String,
Expand All @@ -31,6 +32,7 @@ impl Tool {

/// A tool call request that a system can execute
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolCall {
/// The name of the tool to execute
pub name: String,
Expand Down
Loading
Loading