Skip to content

Commit

Permalink
Add a stream type factory
Browse files Browse the repository at this point in the history
  • Loading branch information
goodboy committed Sep 8, 2021
1 parent 28d6720 commit 3eb376b
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions tractor/_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import platform
import struct
import typing
from typing import Any, Tuple, Optional
from typing import Any, Tuple, Optional, Type

from tricycle import BufferedReceiveStream
import msgpack
Expand Down Expand Up @@ -55,6 +55,7 @@ async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
unpacker = msgpack.Unpacker(
raw=False,
use_list=False,
strict_map_key=False
)
while True:
try:
Expand Down Expand Up @@ -130,12 +131,12 @@ def __init__(
prefix_size: int = 4,

) -> None:
import msgspec

super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.prefix_size = prefix_size

import msgspec

# TODO: struct aware messaging coders
self.encode = msgspec.Encoder().encode
self.decode = msgspec.Decoder().decode # dict[str, Any])
Expand Down Expand Up @@ -185,7 +186,7 @@ async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
# channel will raise an expected error and bubble up.
log.error(f'`msgspec` failed to decode!?')
log.error('`msgspec` failed to decode!?')
last_decode_failed = True

async def send(self, data: Any) -> None:
Expand All @@ -200,11 +201,21 @@ async def send(self, data: Any) -> None:
return await self.stream.send_all(size + bytes_data)


def get_serializer_stream_type(
name: str,
) -> Type:
return {
'msgpack': MsgpackTCPStream,
'msgspec': MsgspecTCPStream,
}[name]


class Channel:
"""An inter-process channel for communication between (remote) actors.
'''An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``.
"""
'''
def __init__(

self,
Expand All @@ -218,17 +229,17 @@ def __init__(
self._recon_seq = on_reconnect
self._autorecon = auto_reconnect

stream_serializer_type = MsgpackTCPStream

# TODO: maybe expose this through the nursery api?
try:
# if installed load the msgspec transport since it's faster
import msgspec # noqa
stream_serializer_type = MsgspecTCPStream
serializer = 'msgspec'
except ImportError:
pass
serializer = 'msgpack'

self.stream_serializer_type = stream_serializer_type
self.msgstream = stream_serializer_type(stream) if stream else None
self.stream_serializer_type = get_serializer_stream_type(serializer)
self.msgstream = self.stream_serializer_type(
stream) if stream else None

if self.msgstream and destaddr:
raise ValueError(
Expand Down

0 comments on commit 3eb376b

Please sign in to comment.