diff --git a/cli/src/client/connection.ts b/cli/src/client/connection.ts index 931f803da..06914e400 100644 --- a/cli/src/client/connection.ts +++ b/cli/src/client/connection.ts @@ -1,16 +1,14 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; -import { McpResponse } from "./types.js"; +import type { + LoggingLevel, + EmptyResult, +} from "@modelcontextprotocol/sdk/types.js"; +import { LoggingLevelSchema } from "@modelcontextprotocol/sdk/types.js"; -export const validLogLevels = [ - "trace", - "debug", - "info", - "warn", - "error", -] as const; - -export type LogLevel = (typeof validLogLevels)[number]; +// Extract valid log levels directly from the SDK's Zod schema to avoid drift +// This ensures CLI validation stays in sync with what the SDK accepts +export const validLogLevels = LoggingLevelSchema.options; export async function connect( client: Client, @@ -38,10 +36,10 @@ export async function disconnect(transport: Transport): Promise { // Set logging level export async function setLoggingLevel( client: Client, - level: LogLevel, -): Promise { + level: LoggingLevel, +): Promise { try { - const response = await client.setLoggingLevel(level as any); + const response = await client.setLoggingLevel(level); return response; } catch (error) { throw new Error( diff --git a/cli/src/client/prompts.ts b/cli/src/client/prompts.ts index 0b237496d..73b010705 100644 --- a/cli/src/client/prompts.ts +++ b/cli/src/client/prompts.ts @@ -1,8 +1,11 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { McpResponse } from "./types.js"; +import { + ListPromptsResult, + GetPromptResult, +} from "@modelcontextprotocol/sdk/types.js"; // List available prompts -export async function listPrompts(client: Client): Promise { +export async function listPrompts(client: Client): Promise { try { const response = await client.listPrompts(); return response; @@ -18,7 +21,7 @@ export async function getPrompt( client: Client, name: string, args?: Record, -): Promise { +): Promise { try { const response = await client.getPrompt({ name, diff --git a/cli/src/client/resources.ts b/cli/src/client/resources.ts index bf33d64d2..b3c41dba7 100644 --- a/cli/src/client/resources.ts +++ b/cli/src/client/resources.ts @@ -1,8 +1,14 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { McpResponse } from "./types.js"; +import { + ListResourcesResult, + ReadResourceResult, + ListResourceTemplatesResult, +} from "@modelcontextprotocol/sdk/types.js"; // List available resources -export async function listResources(client: Client): Promise { +export async function listResources( + client: Client, +): Promise { try { const response = await client.listResources(); return response; @@ -17,7 +23,7 @@ export async function listResources(client: Client): Promise { export async function readResource( client: Client, uri: string, -): Promise { +): Promise { try { const response = await client.readResource({ uri }); return response; @@ -31,7 +37,7 @@ export async function readResource( // List resource templates export async function listResourceTemplates( client: Client, -): Promise { +): Promise { try { const response = await client.listResourceTemplates(); return response; diff --git a/cli/src/client/tools.ts b/cli/src/client/tools.ts index acdb48710..2d15bd9e0 100644 --- a/cli/src/client/tools.ts +++ b/cli/src/client/tools.ts @@ -1,6 +1,10 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; -import { Tool } from "@modelcontextprotocol/sdk/types.js"; -import { McpResponse } from "./types.js"; +import { + Tool, + ListToolsResult, + CallToolResult, + CompatibilityCallToolResult, +} from "@modelcontextprotocol/sdk/types.js"; type JsonSchemaType = { type: "string" | "number" | "integer" | "boolean" | "array" | "object"; @@ -9,7 +13,7 @@ type JsonSchemaType = { items?: JsonSchemaType; }; -export async function listTools(client: Client): Promise { +export async function listTools(client: Client): Promise { try { const response = await client.listTools(); return response; @@ -69,10 +73,10 @@ export async function callTool( client: Client, name: string, args: Record, -): Promise { +): Promise { try { const toolsResponse = await listTools(client); - const tools = toolsResponse.tools as Tool[]; + const tools = toolsResponse.tools; const tool = tools.find((t) => t.name === name); let convertedArgs: Record = args; diff --git a/cli/src/client/types.ts b/cli/src/client/types.ts index bbbe1bf4f..984dade05 100644 --- a/cli/src/client/types.ts +++ b/cli/src/client/types.ts @@ -1 +1,23 @@ -export type McpResponse = Record; +import type { + ListToolsResult, + CallToolResult, + CompatibilityCallToolResult, + ListPromptsResult, + GetPromptResult, + ListResourcesResult, + ReadResourceResult, + ListResourceTemplatesResult, + EmptyResult, +} from "@modelcontextprotocol/sdk/types.js"; + +// Union type for all possible MCP response types +export type McpResponse = + | ListToolsResult + | CallToolResult + | CompatibilityCallToolResult + | ListPromptsResult + | GetPromptResult + | ListResourcesResult + | ReadResourceResult + | ListResourceTemplatesResult + | EmptyResult; diff --git a/cli/src/index.ts b/cli/src/index.ts index 5d5dcf8b9..67bea0876 100644 --- a/cli/src/index.ts +++ b/cli/src/index.ts @@ -11,12 +11,12 @@ import { listResources, listResourceTemplates, listTools, - LogLevel, McpResponse, readResource, setLoggingLevel, validLogLevels, } from "./client/index.js"; +import type { LoggingLevel } from "@modelcontextprotocol/sdk/types.js"; import { handleError } from "./error-handler.js"; import { createTransport, TransportOptions } from "./transport.js"; @@ -26,7 +26,7 @@ type Args = { promptName?: string; promptArgs?: Record; uri?: string; - logLevel?: LogLevel; + logLevel?: LoggingLevel; toolName?: string; toolArg?: Record; transport?: "sse" | "stdio" | "http"; @@ -232,13 +232,13 @@ function parseArgs(): Args { "--log-level ", "Logging level (for logging/setLevel method)", (value: string) => { - if (!validLogLevels.includes(value as any)) { + if (!validLogLevels.includes(value as LoggingLevel)) { throw new Error( `Invalid log level: ${value}. Valid levels are: ${validLogLevels.join(", ")}`, ); } - return value as LogLevel; + return value as LoggingLevel; }, ) //