Skip to content

Commit

Permalink
🐛 Fix async pending error
Browse files Browse the repository at this point in the history
  • Loading branch information
holegots committed May 28, 2024
1 parent 30deec7 commit b1268ff
Showing 1 changed file with 77 additions and 90 deletions.
167 changes: 77 additions & 90 deletions thriftpy2/contrib/aio/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@
from thriftpy2.transport._ssl import (
create_thriftpy_context,
RESTRICTED_SERVER_CIPHERS,
DEFAULT_CIPHERS
DEFAULT_CIPHERS,
)

MAC_OR_BSD = sys.platform == 'darwin' or sys.platform.startswith('freebsd')
MAC_OR_BSD = sys.platform == "darwin" or sys.platform.startswith("freebsd")


class TAsyncSocket(object):
"""Socket implementation for client side."""

def __init__(self, host=None, port=None, unix_socket=None,
sock=None, socket_family=socket.AF_INET,
socket_timeout=3000, connect_timeout=None,
ssl_context=None, validate=True,
cafile=None, capath=None, certfile=None, keyfile=None,
ciphers=DEFAULT_CIPHERS):
def __init__(
self,
host=None,
port=None,
unix_socket=None,
sock=None,
socket_family=socket.AF_INET,
socket_timeout=3000,
connect_timeout=None,
ssl_context=None,
validate=True,
cafile=None,
capath=None,
certfile=None,
keyfile=None,
ciphers=DEFAULT_CIPHERS,
):
"""Initialize a TSocket
TSocket can be initialized in 3 ways:
Expand Down Expand Up @@ -81,20 +92,21 @@ def __init__(self, host=None, port=None, unix_socket=None,

self.socket_family = socket_family
self.socket_timeout = socket_timeout / 1000 if socket_timeout else None
self.connect_timeout = connect_timeout / 1000 if connect_timeout \
else self.socket_timeout
self.connect_timeout = (
connect_timeout / 1000 if connect_timeout else self.socket_timeout
)

if ssl_context:
self.ssl_context = ssl_context
self.server_hostname = host
elif certfile or keyfile:
self.server_hostname = host
self.ssl_context = create_thriftpy_context(server_side=False,
ciphers=ciphers)
self.ssl_context = create_thriftpy_context(
server_side=False, ciphers=ciphers
)

if cafile or capath:
self.ssl_context.load_verify_locations(cafile=cafile,
capath=capath)
self.ssl_context.load_verify_locations(cafile=cafile, capath=capath)

if certfile:
self.ssl_context.load_cert_chain(certfile, keyfile=keyfile)
Expand All @@ -106,85 +118,52 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.ssl_context = None
self.server_hostname = None

def _init_sock(self):
if self.unix_socket:
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
_sock = socket.socket(self.socket_family, socket.SOCK_STREAM)
_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

# socket options
linger = struct.pack('ii', 0, 0)
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
_sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

self.raw_sock = _sock

def set_handle(self, sock):
self.raw_sock = sock

def set_timeout(self, ms):
"""Backward compat api, will bind the timeout to both connect_timeout
and socket_timeout.
"""
self.socket_timeout = ms / 1000 if (ms and ms > 0) else None
self.connect_timeout = self.socket_timeout

if self.raw_sock is not None:
self.raw_sock.settimeout(self.socket_timeout)

def is_open(self):
return bool(self.raw_sock)

async def open(self):
self._init_sock()

addr = self.unix_socket or (self.host, self.port)

try:
if self.connect_timeout:
self.raw_sock.settimeout(self.connect_timeout)

self.raw_sock.connect(addr)

if self.socket_timeout:
self.raw_sock.settimeout(self.socket_timeout)

kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context}
if self.server_hostname:
kwargs['server_hostname'] = self.server_hostname

self.reader, self.writer = await asyncio.wait_for(
self.sock_factory(**kwargs),
self.socket_timeout
)
if self.unix_socket:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_unix_connection(addr), self.connect_timeout
)
else:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(self.host, self.port, ssl=self.ssl_context),
self.connect_timeout,
)
sock = self.writer.get_extra_info("socket")
# socket options
linger = struct.pack("ii", 0, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

except (socket.error, OSError):
raise TTransportException(
type=TTransportException.NOT_OPEN,
message="Could not connect to %s" % str(addr))
message="Could not connect to %s" % str(addr),
)

async def read(self, sz):
try:
buff = await asyncio.wait_for(
self.reader.read(sz),
self.connect_timeout
buff = await asyncio.wait_for(self.reader.read(sz), self.connect_timeout)
except asyncio.TimeoutError:
raise TTransportException(
type=TTransportException.TIMED_OUT, message="TSocket read timed out"
)
except asyncio.IncompleteReadError as e:
raise TTransportException(
type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
)
except socket.error as e:
if e.errno == errno.ECONNRESET and MAC_OR_BSD:
# freebsd and Mach don't follow POSIX semantic of recv
# and fail with ECONNRESET if peer performed shutdown.
# See corresponding comment and code in TSocket::read()
# in lib/cpp/src/transport/TSocket.cpp.
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
buff = ""
else:
raise

if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
raise TTransportException(
type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
)
return buff

def write(self, buff):
Expand All @@ -199,7 +178,6 @@ def close(self):

try:
self.writer.close()
self.raw_sock.close()
self.raw_sock = None
except (socket.error, OSError):
pass
Expand All @@ -208,10 +186,19 @@ def close(self):
class TAsyncServerSocket(object):
"""Socket implementation for server side."""

def __init__(self, host=None, port=None, unix_socket=None,
socket_family=socket.AF_INET, client_timeout=3000,
backlog=128, ssl_context=None, certfile=None, keyfile=None,
ciphers=RESTRICTED_SERVER_CIPHERS):
def __init__(
self,
host=None,
port=None,
unix_socket=None,
socket_family=socket.AF_INET,
client_timeout=3000,
backlog=128,
ssl_context=None,
certfile=None,
keyfile=None,
ciphers=RESTRICTED_SERVER_CIPHERS,
):
"""Initialize a TServerSocket
TSocket can be initialized in 2 ways:
Expand Down Expand Up @@ -251,10 +238,11 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.ssl_context = ssl_context
elif certfile:
if not os.access(certfile, os.R_OK):
raise IOError('No such certfile found: %s' % certfile)
raise IOError("No such certfile found: %s" % certfile)

self.ssl_context = create_thriftpy_context(server_side=True,
ciphers=ciphers)
self.ssl_context = create_thriftpy_context(
server_side=True, ciphers=ciphers
)
self.ssl_context.load_cert_chain(certfile, keyfile=keyfile)
else:
self.ssl_context = None
Expand Down Expand Up @@ -294,17 +282,15 @@ async def accept(self, callback):
server = await self.sock_factory(
self._create_client_connected_cb(callback),
sock=self.raw_sock,
ssl=self.ssl_context
ssl=self.ssl_context,
)
return server

def _create_client_connected_cb(self, callback):

async def client_connected_cb(reader, writer):
try:
await asyncio.wait_for(
callback(StreamHandler(reader, writer)),
self.client_timeout
callback(StreamHandler(reader, writer)), self.client_timeout
)
except asyncio.exceptions.TimeoutError:
writer.close()
Expand Down Expand Up @@ -337,13 +323,14 @@ async def read(self, sz):
# in lib/cpp/src/transport/TSocket.cpp.
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
buff = ""
else:
raise

if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
raise TTransportException(
type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes"
)
return buff

def write(self, buff):
Expand Down

0 comments on commit b1268ff

Please sign in to comment.