File tree Expand file tree Collapse file tree 9 files changed +90
-75
lines changed Expand file tree Collapse file tree 9 files changed +90
-75
lines changed Original file line number Diff line number Diff line change @@ -26,6 +26,7 @@ nanoid = "^2.0.0"
2626pydantic = {git = " https://github.com/pydantic/pydantic.git" , rev = " f5d6acfe19fca38fad802458dab2b4c859182d7b" }
2727websockets = " ^12.0"
2828pydantic-core = " ^2.20.1"
29+ msgpack-types = " ^0.3.0"
2930
3031[tool .poetry .group .dev .dependencies ]
3132pytest = " ^7.4.0"
Original file line number Diff line number Diff line change @@ -25,7 +25,7 @@ def __init__(
2525 client_id : str ,
2626 server_id : str ,
2727 transport_options : TransportOptions ,
28- handshake_metadata : Optional [ HandshakeType ] = None ,
28+ handshake_metadata : HandshakeType ,
2929 ) -> None :
3030 self ._client_id = client_id
3131 self ._server_id = server_id
Original file line number Diff line number Diff line change @@ -51,7 +51,7 @@ def __init__(
5151 client_id : str ,
5252 server_id : str ,
5353 transport_options : TransportOptions ,
54- handshake_metadata : Optional [ HandshakeType ] = None ,
54+ handshake_metadata : HandshakeType ,
5555 ):
5656 super ().__init__ (
5757 transport_id = client_id ,
@@ -226,7 +226,7 @@ async def websocket_closed_callback() -> None:
226226 try :
227227 await send_transport_message (
228228 TransportMessage (
229- from_ = transport_id ,
229+ from_ = transport_id , # type: ignore
230230 to = to_id ,
231231 streamId = stream_id ,
232232 controlFlags = 0 ,
@@ -276,7 +276,7 @@ async def _establish_handshake(
276276 transport_id : str ,
277277 to_id : str ,
278278 session_id : str ,
279- handshake_metadata : Optional [ HandshakeType ] ,
279+ handshake_metadata : HandshakeType ,
280280 websocket : WebSocketCommonProtocol ,
281281 old_session : Optional [ClientSession ],
282282 ) -> Tuple [
Original file line number Diff line number Diff line change @@ -60,7 +60,7 @@ class RiverService(BaseModel):
6060
6161class RiverSchema (BaseModel ):
6262 services : Dict [str , RiverService ]
63- handshakeSchema : RiverConcreteType
63+ handshakeSchema : Optional [ RiverConcreteType ]
6464
6565
6666RiverSchemaFile = RootModel [RiverSchema ]
@@ -266,10 +266,13 @@ def generate_river_client_module(
266266 "" ,
267267 ]
268268
269- (handshake_type , handshake_chunks ) = encode_type (
270- schema_root .handshakeSchema , "HandshakeSchema"
271- )
272- chunks .extend (handshake_chunks )
269+ if schema_root .handshakeSchema is not None :
270+ (handshake_type , handshake_chunks ) = encode_type (
271+ schema_root .handshakeSchema , "HandshakeSchema"
272+ )
273+ chunks .extend (handshake_chunks )
274+ else :
275+ handshake_type = "Literal[None]"
273276
274277 for schema_name , schema in schema_root .services .items ():
275278 current_chunks : List [str ] = [
Original file line number Diff line number Diff line change 11import logging
22from typing import Any , Callable , Coroutine
33
4- import msgpack # type: ignore
4+ import msgpack
55import websockets
66from pydantic import ValidationError
77from pydantic_core import ValidationError as PydanticCoreValidationError
@@ -43,12 +43,11 @@ async def send_transport_message(
4343) -> None :
4444 logger .debug ("sending a message %r to ws %s" , msg , ws )
4545 try :
46- await ws .send (
47- prefix_bytes
48- + msgpack .packb (
49- msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
50- )
46+ packed = msgpack .packb (
47+ msg .model_dump (by_alias = True , exclude_none = True ), datetime = True
5148 )
49+ assert isinstance (packed , bytes )
50+ await ws .send (prefix_bytes + packed )
5251 except websockets .exceptions .ConnectionClosed as e :
5352 await websocket_closed_callback ()
5453 raise WebsocketClosedException ("Websocket closed during send message" ) from e
Original file line number Diff line number Diff line change @@ -87,7 +87,7 @@ async def _send_handshake_response(
8787 response_message = TransportMessage (
8888 streamId = request_message .streamId ,
8989 id = nanoid .generate (),
90- from_ = request_message .to ,
90+ from_ = request_message .to , # type: ignore
9191 to = request_message .from_ ,
9292 seq = 0 ,
9393 ack = 0 ,
Original file line number Diff line number Diff line change @@ -373,7 +373,7 @@ async def send_message(
373373 msg = TransportMessage (
374374 streamId = stream_id ,
375375 id = nanoid .generate (),
376- from_ = self ._transport_id ,
376+ from_ = self ._transport_id , # type: ignore
377377 to = self ._to_id ,
378378 seq = await self ._seq_manager .get_seq_and_increment (),
379379 ack = await self ._seq_manager .get_ack (),
Original file line number Diff line number Diff line change 11import asyncio
22import logging
33from collections .abc import AsyncIterator
4- from typing import Any , AsyncGenerator , NoReturn
4+ from typing import Any , AsyncGenerator , Literal
55
66import nanoid # type: ignore
77import pytest
@@ -36,7 +36,7 @@ def transport_message(
3636) -> TransportMessage :
3737 return TransportMessage (
3838 id = str (nanoid .generate ()),
39- from_ = from_ ,
39+ from_ = from_ , # type: ignore
4040 to = to ,
4141 streamId = streamId ,
4242 seq = seq ,
@@ -139,11 +139,12 @@ async def client(
139139) -> AsyncGenerator [Client , None ]:
140140 try :
141141 async with serve (server .serve , "localhost" , 8765 ):
142- client : Client [NoReturn ] = Client (
142+ client : Client [Literal [ None ] ] = Client (
143143 "ws://localhost:8765" ,
144144 client_id = "test_client" ,
145145 server_id = "test_server" ,
146146 transport_options = transport_options ,
147+ handshake_metadata = None ,
147148 )
148149 try :
149150 yield client
You can’t perform that action at this time.
0 commit comments