77import anyio .lowlevel
88import httpx
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10- from pydantic import BaseModel , RootModel
10+ from pydantic import BaseModel
1111
1212from mcp .shared .exceptions import McpError
1313from mcp .types import (
2121 JSONRPCNotification ,
2222 JSONRPCRequest ,
2323 JSONRPCResponse ,
24+ MessageFrame ,
2425 RequestParams ,
2526 ServerNotification ,
2627 ServerRequest ,
2728 ServerResult ,
2829)
2930
30- RawT = TypeVar ("RawT" )
31-
32-
33- class MessageFrame (RootModel [JSONRPCMessage ], Generic [RawT ]):
34- root : JSONRPCMessage
35- raw : RawT | None = None
36-
37- class Config :
38- arbitrary_types_allowed = True
39-
40-
41- ReadStream = MemoryObjectReceiveStream [MessageFrame [RawT ] | Exception ]
42- ReadStreamWriter = MemoryObjectSendStream [MessageFrame [RawT ] | Exception ]
43- WriteStream = MemoryObjectSendStream [MessageFrame [RawT ]]
44- WriteStreamReader = MemoryObjectReceiveStream [MessageFrame [RawT ]]
31+ ReadStream = MemoryObjectReceiveStream [MessageFrame | Exception ]
32+ ReadStreamWriter = MemoryObjectSendStream [MessageFrame | Exception ]
33+ WriteStream = MemoryObjectSendStream [MessageFrame ]
34+ WriteStreamReader = MemoryObjectReceiveStream [MessageFrame ]
4535
4636SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
4737SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
@@ -242,7 +232,7 @@ async def send_request(
242232 # TODO: Support progress callbacks
243233
244234 await self ._write_stream .send (
245- MessageFrame (JSONRPCMessage (jsonrpc_request ), None )
235+ MessageFrame (root = JSONRPCMessage (jsonrpc_request ), raw = None )
246236 )
247237
248238 try :
@@ -280,15 +270,17 @@ async def send_notification(self, notification: SendNotificationT) -> None:
280270 )
281271
282272 await self ._write_stream .send (
283- MessageFrame (JSONRPCMessage (jsonrpc_notification ))
273+ MessageFrame (root = JSONRPCMessage (jsonrpc_notification ), raw = None )
284274 )
285275
286276 async def _send_response (
287277 self , request_id : RequestId , response : SendResultT | ErrorData
288278 ) -> None :
289279 if isinstance (response , ErrorData ):
290280 jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
291- await self ._write_stream .send (MessageFrame (JSONRPCMessage (jsonrpc_error )))
281+ await self ._write_stream .send (
282+ MessageFrame (root = JSONRPCMessage (jsonrpc_error ), raw = None )
283+ )
292284 else :
293285 jsonrpc_response = JSONRPCResponse (
294286 jsonrpc = "2.0" ,
@@ -298,7 +290,7 @@ async def _send_response(
298290 ),
299291 )
300292 await self ._write_stream .send (
301- MessageFrame (JSONRPCMessage (jsonrpc_response ))
293+ MessageFrame (root = JSONRPCMessage (jsonrpc_response ), raw = None )
302294 )
303295
304296 async def _receive_loop (self ) -> None :
0 commit comments