1
+ import os
1
2
import asyncio
2
3
from exo .networking .discovery import Discovery
3
- from typing import Dict , List , Callable
4
+ from typing import Dict , List , Callable , Optional
4
5
5
6
from exo .topology .device_capabilities import DeviceCapabilities
6
7
from exo .networking .manual .network_topology_config import NetworkTopology , PeerConfig
@@ -15,28 +16,24 @@ def __init__(
15
16
node_id : str ,
16
17
create_peer_handle : Callable [[str , str , DeviceCapabilities ], PeerHandle ],
17
18
):
18
- self .topology = NetworkTopology .from_path (network_config_path )
19
19
self .network_config_path = network_config_path
20
20
self .node_id = node_id
21
21
self .create_peer_handle = create_peer_handle
22
22
23
- if node_id not in self .topology .peers :
24
- raise ValueError (
25
- f"Node ID { node_id } not found in network config file { network_config_path } . Please run with `node_id` set to one of the keys in the config file: { [k for k , _ in self .topology .peers ]} "
26
- )
27
-
28
23
self .listen_task = None
29
-
24
+ self . cleanup_task = None
30
25
self .known_peers : Dict [str , PeerHandle ] = {}
31
- self .peers_in_network : Dict [str , PeerConfig ] = self .topology .peers
32
- self .peers_in_network .pop (node_id )
26
+
27
+ self ._cached_peers : Dict [str , PeerConfig ] = {}
28
+ self ._last_modified_time : Optional [float ] = None
33
29
34
30
async def start (self ) -> None :
35
31
self .listen_task = asyncio .create_task (self .task_find_peers_from_config ())
32
+ self .cleanup_task = asyncio .create_task (self .task_clean_up_peers_from_config ())
36
33
37
34
async def stop (self ) -> None :
38
- if self .listen_task :
39
- self .listen_task .cancel ()
35
+ if self .listen_task : self . listen_task . cancel ()
36
+ if self . cleanup_task : self .cleanup_task .cancel ()
40
37
41
38
async def discover_peers (self , wait_for_peers : int = 0 ) -> List [PeerHandle ]:
42
39
if wait_for_peers > 0 :
@@ -49,7 +46,7 @@ async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
49
46
async def task_find_peers_from_config (self ):
50
47
if DEBUG_DISCOVERY >= 2 : print ("Starting task to find peers from config..." )
51
48
while True :
52
- for peer_id , peer_config in self .peers_in_network .items ():
49
+ for peer_id , peer_config in self ._get_peers () .items ():
53
50
try :
54
51
if DEBUG_DISCOVERY >= 2 : print (f"Checking peer { peer_id = } at { peer_config .address } :{ peer_config .port } " )
55
52
peer = self .known_peers .get (peer_id )
@@ -72,3 +69,44 @@ async def task_find_peers_from_config(self):
72
69
73
70
if DEBUG_DISCOVERY >= 2 : print (f"Current known peers: { [peer .id () for peer in self .known_peers .values ()]} " )
74
71
72
+ async def task_clean_up_peers_from_config (self ):
73
+ if DEBUG_DISCOVERY >= 2 : print ("Starting task to clean up peers from config..." )
74
+ while True :
75
+ peers_from_config = self ._get_peers ()
76
+ if peers_from_config :
77
+ peers_to_remove = [peer for peer in self .known_peers .keys () if peer not in peers_from_config ]
78
+
79
+ for peer in peers_to_remove :
80
+ if DEBUG_DISCOVERY >= 2 : print (f"{ peer } is no longer found in the config but is currently a known peer. Removing from known peers..." )
81
+ try : del self .known_peers [peer ]
82
+ except KeyError : pass
83
+
84
+ await asyncio .sleep (1.0 )
85
+
86
+ def _get_peers (self ):
87
+ try :
88
+ current_mtime = os .path .getmtime (self .network_config_path )
89
+
90
+ if self ._cached_peers is not None and self ._last_modified_time is not None and current_mtime <= self ._last_modified_time :
91
+ return self ._cached_peers
92
+
93
+ topology = NetworkTopology .from_path (self .network_config_path )
94
+
95
+ if self .node_id not in topology .peers :
96
+ raise ValueError (
97
+ f"Node ID { self .node_id } not found in network config file "
98
+ f"{ self .network_config_path } . Please run with `node_id` set to "
99
+ f"one of the keys in the config file: { [k for k , _ in topology .peers ]} "
100
+ )
101
+
102
+ peers_in_network : Dict [str , PeerConfig ] = topology .peers
103
+ peers_in_network .pop (self .node_id )
104
+
105
+ self ._cached_peers = peers_in_network
106
+ self ._last_modified_time = current_mtime
107
+
108
+ return peers_in_network
109
+
110
+ except Exception as e :
111
+ if DEBUG_DISCOVERY >= 2 : print (f"Error when loading network config file from { self .network_config_path } . Please update the config file in order to successfully discover peers. Exception: { e } " )
112
+ return self ._cached_peers
0 commit comments