Skip to content

Commit 7062bf6

Browse files
chore/schema handshake metadata pt2 (#74)
Why === Follow-up to #72 to account for `schema.json` files without metadata. What changed ============ If the `handshakeSchema` field is defined, then the parameter is required. Otherwise, the parameter is `Literal[None]`, which matches the previous behavior of the default. Due to the metadata field now being required, I had to remove the `= None` default parameter. I think this is alright. It'll mean that #73 will likely need to change to `handshake_metadata_factory: HandshakeType | Callable[[], Awaitable[HandshakeType]]` to avoid `async def stub() -> None: return None` just to satisfy the async requirement. It's somewhat challenging to capture the exact semantics we want here, but I think that's alright. Test plan ========= Do typechecks pass?
1 parent d5aabb4 commit 7062bf6

File tree

9 files changed

+90
-75
lines changed

9 files changed

+90
-75
lines changed

poetry.lock

Lines changed: 66 additions & 55 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ nanoid = "^2.0.0"
2626
pydantic = {git = "https://github.com/pydantic/pydantic.git", rev = "f5d6acfe19fca38fad802458dab2b4c859182d7b"}
2727
websockets = "^12.0"
2828
pydantic-core = "^2.20.1"
29+
msgpack-types = "^0.3.0"
2930

3031
[tool.poetry.group.dev.dependencies]
3132
pytest = "^7.4.0"

replit_river/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
client_id: str,
2626
server_id: str,
2727
transport_options: TransportOptions,
28-
handshake_metadata: Optional[HandshakeType] = None,
28+
handshake_metadata: HandshakeType,
2929
) -> None:
3030
self._client_id = client_id
3131
self._server_id = server_id

replit_river/client_transport.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
client_id: str,
5252
server_id: str,
5353
transport_options: TransportOptions,
54-
handshake_metadata: Optional[HandshakeType] = None,
54+
handshake_metadata: HandshakeType,
5555
):
5656
super().__init__(
5757
transport_id=client_id,
@@ -226,7 +226,7 @@ async def websocket_closed_callback() -> None:
226226
try:
227227
await send_transport_message(
228228
TransportMessage(
229-
from_=transport_id,
229+
from_=transport_id, # type: ignore
230230
to=to_id,
231231
streamId=stream_id,
232232
controlFlags=0,
@@ -276,7 +276,7 @@ async def _establish_handshake(
276276
transport_id: str,
277277
to_id: str,
278278
session_id: str,
279-
handshake_metadata: Optional[HandshakeType],
279+
handshake_metadata: HandshakeType,
280280
websocket: WebSocketCommonProtocol,
281281
old_session: Optional[ClientSession],
282282
) -> Tuple[

replit_river/codegen/client.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class RiverService(BaseModel):
6060

6161
class RiverSchema(BaseModel):
6262
services: Dict[str, RiverService]
63-
handshakeSchema: RiverConcreteType
63+
handshakeSchema: Optional[RiverConcreteType]
6464

6565

6666
RiverSchemaFile = RootModel[RiverSchema]
@@ -266,10 +266,13 @@ def generate_river_client_module(
266266
"",
267267
]
268268

269-
(handshake_type, handshake_chunks) = encode_type(
270-
schema_root.handshakeSchema, "HandshakeSchema"
271-
)
272-
chunks.extend(handshake_chunks)
269+
if schema_root.handshakeSchema is not None:
270+
(handshake_type, handshake_chunks) = encode_type(
271+
schema_root.handshakeSchema, "HandshakeSchema"
272+
)
273+
chunks.extend(handshake_chunks)
274+
else:
275+
handshake_type = "Literal[None]"
273276

274277
for schema_name, schema in schema_root.services.items():
275278
current_chunks: List[str] = [

replit_river/messages.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import Any, Callable, Coroutine
33

4-
import msgpack # type: ignore
4+
import msgpack
55
import websockets
66
from pydantic import ValidationError
77
from pydantic_core import ValidationError as PydanticCoreValidationError
@@ -43,12 +43,11 @@ async def send_transport_message(
4343
) -> None:
4444
logger.debug("sending a message %r to ws %s", msg, ws)
4545
try:
46-
await ws.send(
47-
prefix_bytes
48-
+ msgpack.packb(
49-
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
50-
)
46+
packed = msgpack.packb(
47+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
5148
)
49+
assert isinstance(packed, bytes)
50+
await ws.send(prefix_bytes + packed)
5251
except websockets.exceptions.ConnectionClosed as e:
5352
await websocket_closed_callback()
5453
raise WebsocketClosedException("Websocket closed during send message") from e

replit_river/server_transport.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def _send_handshake_response(
8787
response_message = TransportMessage(
8888
streamId=request_message.streamId,
8989
id=nanoid.generate(),
90-
from_=request_message.to,
90+
from_=request_message.to, # type: ignore
9191
to=request_message.from_,
9292
seq=0,
9393
ack=0,

replit_river/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ async def send_message(
373373
msg = TransportMessage(
374374
streamId=stream_id,
375375
id=nanoid.generate(),
376-
from_=self._transport_id,
376+
from_=self._transport_id, # type: ignore
377377
to=self._to_id,
378378
seq=await self._seq_manager.get_seq_and_increment(),
379379
ack=await self._seq_manager.get_ack(),

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import AsyncIterator
4-
from typing import Any, AsyncGenerator, NoReturn
4+
from typing import Any, AsyncGenerator, Literal
55

66
import nanoid # type: ignore
77
import pytest
@@ -36,7 +36,7 @@ def transport_message(
3636
) -> TransportMessage:
3737
return TransportMessage(
3838
id=str(nanoid.generate()),
39-
from_=from_,
39+
from_=from_, # type: ignore
4040
to=to,
4141
streamId=streamId,
4242
seq=seq,
@@ -139,11 +139,12 @@ async def client(
139139
) -> AsyncGenerator[Client, None]:
140140
try:
141141
async with serve(server.serve, "localhost", 8765):
142-
client: Client[NoReturn] = Client(
142+
client: Client[Literal[None]] = Client(
143143
"ws://localhost:8765",
144144
client_id="test_client",
145145
server_id="test_server",
146146
transport_options=transport_options,
147+
handshake_metadata=None,
147148
)
148149
try:
149150
yield client

0 commit comments

Comments
 (0)