3535
3636
3737if not sys .implementation .name == "circuitpython" :
38- from typing import Optional , Tuple
38+ from typing import List , Optional , Tuple
3939
4040 from circuitpython_typing .socket import (
4141 CircuitPythonSocketType ,
@@ -64,15 +64,14 @@ def connect(self, address: Tuple[str, int]) -> None:
6464 try :
6565 return self ._socket .connect (address , self ._mode )
6666 except RuntimeError as error :
67- raise OSError (errno .ENOMEM ) from error
67+ raise OSError (errno .ENOMEM , str ( error ) ) from error
6868
6969
7070class _FakeSSLContext :
7171 def __init__ (self , iface : InterfaceType ) -> None :
7272 self ._iface = iface
7373
74- # pylint: disable=unused-argument
75- def wrap_socket (
74+ def wrap_socket ( # pylint: disable=unused-argument
7675 self , socket : CircuitPythonSocketType , server_hostname : Optional [str ] = None
7776 ) -> _FakeSSLSocket :
7877 """Return the same socket"""
@@ -99,7 +98,8 @@ def create_fake_ssl_context(
9998 return _FakeSSLContext (iface )
10099
101100
102- _global_socketpool = {}
101+ _global_connection_managers = {}
102+ _global_socketpools = {}
103103_global_ssl_contexts = {}
104104
105105
@@ -113,7 +113,7 @@ def get_radio_socketpool(radio):
113113 * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing)
114114 """
115115 class_name = radio .__class__ .__name__
116- if class_name not in _global_socketpool :
116+ if class_name not in _global_socketpools :
117117 if class_name == "Radio" :
118118 import ssl # pylint: disable=import-outside-toplevel
119119
@@ -151,10 +151,10 @@ def get_radio_socketpool(radio):
151151 else :
152152 raise AttributeError (f"Unsupported radio class: { class_name } " )
153153
154- _global_socketpool [class_name ] = pool
154+ _global_socketpools [class_name ] = pool
155155 _global_ssl_contexts [class_name ] = ssl_context
156156
157- return _global_socketpool [class_name ]
157+ return _global_socketpools [class_name ]
158158
159159
160160def get_radio_ssl_context (radio ):
@@ -183,42 +183,75 @@ def __init__(
183183 ) -> None :
184184 self ._socket_pool = socket_pool
185185 # Hang onto open sockets so that we can reuse them.
186- self ._available_socket = {}
187- self ._open_sockets = {}
188-
189- def _free_sockets (self ) -> None :
190- available_sockets = []
191- for socket , free in self ._available_socket .items ():
192- if free :
193- available_sockets .append (socket )
186+ self ._available_sockets = set ()
187+ self ._key_by_managed_socket = {}
188+ self ._managed_socket_by_key = {}
194189
190+ def _free_sockets (self , force : bool = False ) -> None :
191+ # cloning lists since items are being removed
192+ available_sockets = list (self ._available_sockets )
195193 for socket in available_sockets :
196194 self .close_socket (socket )
195+ if force :
196+ open_sockets = list (self ._managed_socket_by_key .values ())
197+ for socket in open_sockets :
198+ self .close_socket (socket )
197199
198- def _get_key_for_socket (self , socket ):
200+ def _get_connected_socket ( # pylint: disable=too-many-arguments
201+ self ,
202+ addr_info : List [Tuple [int , int , int , str , Tuple [str , int ]]],
203+ host : str ,
204+ port : int ,
205+ timeout : float ,
206+ is_ssl : bool ,
207+ ssl_context : Optional [SSLContextType ] = None ,
208+ ):
199209 try :
200- return next (
201- key for key , value in self ._open_sockets .items () if value == socket
202- )
203- except StopIteration :
204- return None
210+ socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
211+ except (OSError , RuntimeError ) as exc :
212+ return exc
213+
214+ if is_ssl :
215+ socket = ssl_context .wrap_socket (socket , server_hostname = host )
216+ connect_host = host
217+ else :
218+ connect_host = addr_info [- 1 ][0 ]
219+ socket .settimeout (timeout ) # socket read timeout
220+
221+ try :
222+ socket .connect ((connect_host , port ))
223+ except (MemoryError , OSError ) as exc :
224+ socket .close ()
225+ return exc
226+
227+ return socket
228+
229+ @property
230+ def available_socket_count (self ) -> int :
231+ """Get the count of freeable open sockets"""
232+ return len (self ._available_sockets )
233+
234+ @property
235+ def managed_socket_count (self ) -> int :
236+ """Get the count of open sockets"""
237+ return len (self ._managed_socket_by_key )
205238
206239 def close_socket (self , socket : SocketType ) -> None :
207240 """Close a previously opened socket."""
208- if socket not in self ._open_sockets .values ():
241+ if socket not in self ._managed_socket_by_key .values ():
209242 raise RuntimeError ("Socket not managed" )
210- key = self ._get_key_for_socket (socket )
211243 socket .close ()
212- del self ._available_socket [socket ]
213- del self ._open_sockets [key ]
244+ key = self ._key_by_managed_socket .pop (socket )
245+ del self ._managed_socket_by_key [key ]
246+ if socket in self ._available_sockets :
247+ self ._available_sockets .remove (socket )
214248
215249 def free_socket (self , socket : SocketType ) -> None :
216250 """Mark a previously opened socket as available so it can be reused if needed."""
217- if socket not in self ._open_sockets .values ():
251+ if socket not in self ._managed_socket_by_key .values ():
218252 raise RuntimeError ("Socket not managed" )
219- self ._available_socket [ socket ] = True
253+ self ._available_sockets . add ( socket )
220254
221- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
222255 def get_socket (
223256 self ,
224257 host : str ,
@@ -234,10 +267,10 @@ def get_socket(
234267 if session_id :
235268 session_id = str (session_id )
236269 key = (host , port , proto , session_id )
237- if key in self ._open_sockets :
238- socket = self ._open_sockets [key ]
239- if self ._available_socket [ socket ] :
240- self ._available_socket [ socket ] = False
270+ if key in self ._managed_socket_by_key :
271+ socket = self ._managed_socket_by_key [key ]
272+ if socket in self ._available_sockets :
273+ self ._available_sockets . remove ( socket )
241274 return socket
242275
243276 raise RuntimeError (f"Socket already connected to { proto } //{ host } :{ port } " )
@@ -253,64 +286,68 @@ def get_socket(
253286 host , port , 0 , self ._socket_pool .SOCK_STREAM
254287 )[0 ]
255288
256- try_count = 0
257- socket = None
258- last_exc = None
259- while try_count < 2 and socket is None :
260- try_count += 1
261- if try_count > 1 :
262- if any (
263- socket
264- for socket , free in self ._available_socket .items ()
265- if free is True
266- ):
267- self ._free_sockets ()
268- else :
269- break
270-
271- try :
272- socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
273- except OSError as exc :
274- last_exc = exc
275- continue
276- except RuntimeError as exc :
277- last_exc = exc
278- continue
279-
280- if is_ssl :
281- socket = ssl_context .wrap_socket (socket , server_hostname = host )
282- connect_host = host
283- else :
284- connect_host = addr_info [- 1 ][0 ]
285- socket .settimeout (timeout ) # socket read timeout
286-
287- try :
288- socket .connect ((connect_host , port ))
289- except MemoryError as exc :
290- last_exc = exc
291- socket .close ()
292- socket = None
293- except OSError as exc :
294- last_exc = exc
295- socket .close ()
296- socket = None
297-
298- if socket is None :
299- raise RuntimeError (f"Error connecting socket: { last_exc } " ) from last_exc
300-
301- self ._available_socket [socket ] = False
302- self ._open_sockets [key ] = socket
303- return socket
289+ first_exception = None
290+ result = self ._get_connected_socket (
291+ addr_info , host , port , timeout , is_ssl , ssl_context
292+ )
293+ if isinstance (result , Exception ):
294+ # Got an error, if there are any available sockets, free them and try again
295+ if self .available_socket_count :
296+ first_exception = result
297+ self ._free_sockets ()
298+ result = self ._get_connected_socket (
299+ addr_info , host , port , timeout , is_ssl , ssl_context
300+ )
301+ if isinstance (result , Exception ):
302+ last_result = f", first error: { first_exception } " if first_exception else ""
303+ raise RuntimeError (
304+ f"Error connecting socket: { result } { last_result } "
305+ ) from result
306+
307+ self ._key_by_managed_socket [result ] = key
308+ self ._managed_socket_by_key [key ] = result
309+ return result
304310
305311
306312# global helpers
307313
308314
309- _global_connection_manager = {}
315+ def connection_manager_close_all (
316+ socket_pool : Optional [SocketpoolModuleType ] = None , release_references : bool = False
317+ ) -> None :
318+ """Close all open sockets for pool"""
319+ if socket_pool :
320+ socket_pools = [socket_pool ]
321+ else :
322+ socket_pools = _global_connection_managers .keys ()
323+
324+ for pool in socket_pools :
325+ connection_manager = _global_connection_managers .get (pool , None )
326+ if connection_manager is None :
327+ raise RuntimeError ("SocketPool not managed" )
328+
329+ connection_manager ._free_sockets (force = True ) # pylint: disable=protected-access
330+
331+ if release_references :
332+ radio_key = None
333+ for radio_check , pool_check in _global_socketpools .items ():
334+ if pool == pool_check :
335+ radio_key = radio_check
336+ break
337+
338+ if radio_key :
339+ if radio_key in _global_socketpools :
340+ del _global_socketpools [radio_key ]
341+
342+ if radio_key in _global_ssl_contexts :
343+ del _global_ssl_contexts [radio_key ]
344+
345+ if pool in _global_connection_managers :
346+ del _global_connection_managers [pool ]
310347
311348
312349def get_connection_manager (socket_pool : SocketpoolModuleType ) -> ConnectionManager :
313350 """Get the ConnectionManager singleton for the given pool"""
314- if socket_pool not in _global_connection_manager :
315- _global_connection_manager [socket_pool ] = ConnectionManager (socket_pool )
316- return _global_connection_manager [socket_pool ]
351+ if socket_pool not in _global_connection_managers :
352+ _global_connection_managers [socket_pool ] = ConnectionManager (socket_pool )
353+ return _global_connection_managers [socket_pool ]
0 commit comments