-
Notifications
You must be signed in to change notification settings - Fork 5
Close all and counts #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,7 @@ | |
|
|
||
|
|
||
| if not sys.implementation.name == "circuitpython": | ||
| from typing import Optional, Tuple | ||
| from typing import List, Optional, Tuple | ||
|
|
||
| from circuitpython_typing.socket import ( | ||
| CircuitPythonSocketType, | ||
|
|
@@ -71,8 +71,7 @@ class _FakeSSLContext: | |
| def __init__(self, iface: InterfaceType) -> None: | ||
| self._iface = iface | ||
|
|
||
| # pylint: disable=unused-argument | ||
| def wrap_socket( | ||
| def wrap_socket( # pylint: disable=unused-argument | ||
| self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None | ||
| ) -> _FakeSSLSocket: | ||
| """Return the same socket""" | ||
|
|
@@ -99,7 +98,8 @@ def create_fake_ssl_context( | |
| return _FakeSSLContext(iface) | ||
|
|
||
|
|
||
| _global_socketpool = {} | ||
| _global_connection_managers = {} | ||
| _global_socketpools = {} | ||
| _global_ssl_contexts = {} | ||
|
|
||
|
|
||
|
|
@@ -113,7 +113,7 @@ def get_radio_socketpool(radio): | |
| * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) | ||
| """ | ||
| class_name = radio.__class__.__name__ | ||
| if class_name not in _global_socketpool: | ||
| if class_name not in _global_socketpools: | ||
| if class_name == "Radio": | ||
| import ssl # pylint: disable=import-outside-toplevel | ||
|
|
||
|
|
@@ -151,10 +151,10 @@ def get_radio_socketpool(radio): | |
| else: | ||
| raise AttributeError(f"Unsupported radio class: {class_name}") | ||
|
|
||
| _global_socketpool[class_name] = pool | ||
| _global_socketpools[class_name] = pool | ||
| _global_ssl_contexts[class_name] = ssl_context | ||
|
|
||
| return _global_socketpool[class_name] | ||
| return _global_socketpools[class_name] | ||
|
|
||
|
|
||
| def get_radio_ssl_context(radio): | ||
|
|
@@ -183,42 +183,75 @@ def __init__( | |
| ) -> None: | ||
| self._socket_pool = socket_pool | ||
| # Hang onto open sockets so that we can reuse them. | ||
| self._available_socket = {} | ||
| self._open_sockets = {} | ||
|
|
||
| def _free_sockets(self) -> None: | ||
| available_sockets = [] | ||
| for socket, free in self._available_socket.items(): | ||
| if free: | ||
| available_sockets.append(socket) | ||
| self._available_sockets = set() | ||
| self._managed_socket_by_key = {} | ||
| self._managed_socket_by_socket = {} | ||
dhalbert marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def _free_sockets(self, force: bool = False) -> None: | ||
| # cloning lists since items are being removed | ||
| available_sockets = list(self._available_sockets) | ||
| for socket in available_sockets: | ||
| self.close_socket(socket) | ||
| if force: | ||
| open_sockets = list(self._managed_socket_by_key.values()) | ||
| for socket in open_sockets: | ||
| self.close_socket(socket) | ||
|
|
||
| def _get_key_for_socket(self, socket): | ||
| def _get_connected_socket( # pylint: disable=too-many-arguments | ||
| self, | ||
| addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], | ||
| host: str, | ||
| port: int, | ||
| timeout: float, | ||
| is_ssl: bool, | ||
| ssl_context: Optional[SSLContextType] = None, | ||
| ): | ||
| try: | ||
| return next( | ||
| key for key, value in self._open_sockets.items() if value == socket | ||
| ) | ||
| except StopIteration: | ||
| return None | ||
| socket = self._socket_pool.socket(addr_info[0], addr_info[1]) | ||
| except (OSError, RuntimeError) as exc: | ||
| return exc | ||
|
|
||
| if is_ssl: | ||
| socket = ssl_context.wrap_socket(socket, server_hostname=host) | ||
| connect_host = host | ||
| else: | ||
| connect_host = addr_info[-1][0] | ||
| socket.settimeout(timeout) # socket read timeout | ||
|
|
||
| try: | ||
| socket.connect((connect_host, port)) | ||
| except (MemoryError, OSError) as exc: | ||
| socket.close() | ||
| return exc | ||
|
|
||
| return socket | ||
|
|
||
| @property | ||
| def available_socket_count(self) -> int: | ||
| """Get the count of freeable open sockets""" | ||
| return len(self._available_sockets) | ||
|
|
||
| @property | ||
| def managed_socket_count(self) -> int: | ||
| """Get the count of open sockets""" | ||
| return len(self._managed_socket_by_key) | ||
|
|
||
| def close_socket(self, socket: SocketType) -> None: | ||
| """Close a previously opened socket.""" | ||
| if socket not in self._open_sockets.values(): | ||
| if socket not in self._managed_socket_by_key.values(): | ||
| raise RuntimeError("Socket not managed") | ||
| key = self._get_key_for_socket(socket) | ||
| socket.close() | ||
| del self._available_socket[socket] | ||
| del self._open_sockets[key] | ||
| key = self._managed_socket_by_socket.pop(socket) | ||
| del self._managed_socket_by_key[key] | ||
| if socket in self._available_sockets: | ||
| self._available_sockets.remove(socket) | ||
|
|
||
| def free_socket(self, socket: SocketType) -> None: | ||
| """Mark a previously opened socket as available so it can be reused if needed.""" | ||
| if socket not in self._open_sockets.values(): | ||
| if socket not in self._managed_socket_by_key.values(): | ||
| raise RuntimeError("Socket not managed") | ||
| self._available_socket[socket] = True | ||
| self._available_sockets.add(socket) | ||
|
|
||
| # pylint: disable=too-many-branches,too-many-locals,too-many-statements | ||
| def get_socket( | ||
| self, | ||
| host: str, | ||
|
|
@@ -234,10 +267,10 @@ def get_socket( | |
| if session_id: | ||
| session_id = str(session_id) | ||
| key = (host, port, proto, session_id) | ||
| if key in self._open_sockets: | ||
| socket = self._open_sockets[key] | ||
| if self._available_socket[socket]: | ||
| self._available_socket[socket] = False | ||
| if key in self._managed_socket_by_key: | ||
| socket = self._managed_socket_by_key[key] | ||
| if socket in self._available_sockets: | ||
| self._available_sockets.remove(socket) | ||
| return socket | ||
|
|
||
| raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") | ||
|
|
@@ -253,64 +286,63 @@ def get_socket( | |
| host, port, 0, self._socket_pool.SOCK_STREAM | ||
| )[0] | ||
|
|
||
| try_count = 0 | ||
| socket = None | ||
| last_exc = None | ||
| while try_count < 2 and socket is None: | ||
| try_count += 1 | ||
| if try_count > 1: | ||
| if any( | ||
| socket | ||
| for socket, free in self._available_socket.items() | ||
| if free is True | ||
| ): | ||
| self._free_sockets() | ||
| else: | ||
| break | ||
| result = self._get_connected_socket( | ||
| addr_info, host, port, timeout, is_ssl, ssl_context | ||
| ) | ||
| if isinstance(result, Exception): | ||
| # Got an error, if there are any available sockets, free them and try again | ||
| if self.available_socket_count: | ||
| self._free_sockets() | ||
| result = self._get_connected_socket( | ||
| addr_info, host, port, timeout, is_ssl, ssl_context | ||
| ) | ||
| if isinstance(result, Exception): | ||
| raise RuntimeError(f"Error connecting socket: {result}") from result | ||
|
||
|
|
||
| try: | ||
| socket = self._socket_pool.socket(addr_info[0], addr_info[1]) | ||
| except OSError as exc: | ||
| last_exc = exc | ||
| continue | ||
| except RuntimeError as exc: | ||
| last_exc = exc | ||
| continue | ||
|
|
||
| if is_ssl: | ||
| socket = ssl_context.wrap_socket(socket, server_hostname=host) | ||
| connect_host = host | ||
| else: | ||
| connect_host = addr_info[-1][0] | ||
| socket.settimeout(timeout) # socket read timeout | ||
|
|
||
| try: | ||
| socket.connect((connect_host, port)) | ||
| except MemoryError as exc: | ||
| last_exc = exc | ||
| socket.close() | ||
| socket = None | ||
| except OSError as exc: | ||
| last_exc = exc | ||
| socket.close() | ||
| socket = None | ||
|
|
||
| if socket is None: | ||
| raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc | ||
|
|
||
| self._available_socket[socket] = False | ||
| self._open_sockets[key] = socket | ||
| return socket | ||
| self._managed_socket_by_key[key] = result | ||
| self._managed_socket_by_socket[result] = key | ||
| return result | ||
|
|
||
|
|
||
| # global helpers | ||
|
|
||
|
|
||
| _global_connection_manager = {} | ||
| def connection_manager_close_all( | ||
| socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False | ||
| ) -> None: | ||
| """Close all open sockets for pool""" | ||
| if socket_pool: | ||
| socket_pools = [socket_pool] | ||
| else: | ||
| socket_pools = _global_connection_managers.keys() | ||
|
|
||
| for pool in socket_pools: | ||
| connection_manager = _global_connection_managers.get(pool, None) | ||
| if connection_manager is None: | ||
| raise RuntimeError("SocketPool not managed") | ||
|
|
||
| connection_manager._free_sockets(force=True) # pylint: disable=protected-access | ||
|
|
||
| if release_references: | ||
| radio_key = None | ||
| for radio_check, pool_check in _global_socketpools.items(): | ||
| if pool == pool_check: | ||
| radio_key = radio_check | ||
| break | ||
|
|
||
| if radio_key: | ||
| if radio_key in _global_socketpools: | ||
| del _global_socketpools[radio_key] | ||
|
|
||
| if radio_key in _global_ssl_contexts: | ||
| del _global_ssl_contexts[radio_key] | ||
|
|
||
| if pool in _global_connection_managers: | ||
| del _global_connection_managers[pool] | ||
|
|
||
|
|
||
| def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: | ||
| """Get the ConnectionManager singleton for the given pool""" | ||
| if socket_pool not in _global_connection_manager: | ||
| _global_connection_manager[socket_pool] = ConnectionManager(socket_pool) | ||
| return _global_connection_manager[socket_pool] | ||
| if socket_pool not in _global_connection_managers: | ||
| _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) | ||
| return _global_connection_managers[socket_pool] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,14 +24,38 @@ | |
|
|
||
| # get request session | ||
| requests = adafruit_requests.Session(pool, ssl_context) | ||
| connection_manager = adafruit_connection_manager.get_connection_manager(pool) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The output of this is so: |
||
| print("-" * 40) | ||
| print("Nothing yet opened") | ||
| print(f"Open Sockets: {connection_manager.managed_socket_count}") | ||
| print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") | ||
|
|
||
| # make request | ||
| print("-" * 40) | ||
| print(f"Fetching from {TEXT_URL}") | ||
| print(f"Fetching from {TEXT_URL} in a context handler") | ||
| with requests.get(TEXT_URL) as response: | ||
| response_text = response.text | ||
| print(f"Text Response {response_text}") | ||
|
|
||
| print("-" * 40) | ||
| print("1 request, opened and freed") | ||
| print(f"Open Sockets: {connection_manager.managed_socket_count}") | ||
| print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") | ||
|
|
||
| print("-" * 40) | ||
| print(f"Fetching from {TEXT_URL} not in a context handler") | ||
| response = requests.get(TEXT_URL) | ||
| response_text = response.text | ||
| response.close() | ||
|
|
||
| print(f"Text Response {response_text}") | ||
| print("-" * 40) | ||
| print("1 request, opened but not freed") | ||
| print(f"Open Sockets: {connection_manager.managed_socket_count}") | ||
| print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") | ||
|
|
||
| print("-" * 40) | ||
| print("Closing everything in the pool") | ||
| adafruit_connection_manager.connection_manager_close_all(pool) | ||
|
|
||
| print("-" * 40) | ||
| print("Everything closed") | ||
| print(f"Open Sockets: {connection_manager.managed_socket_count}") | ||
| print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") | ||
Uh oh!
There was an error while loading. Please reload this page.