|
5 | 5 | import os |
6 | 6 | import time |
7 | 7 | import uuid |
| 8 | +from multiprocessing import Manager |
| 9 | +from multiprocessing.managers import SyncManager |
8 | 10 | from typing import Any, Dict, Optional |
9 | 11 |
|
10 | 12 | from .rp_logger import RunPodLogger |
@@ -61,82 +63,149 @@ def __str__(self) -> str: |
61 | 63 | # ---------------------------------------------------------------------------- # |
62 | 64 | # Tracker # |
63 | 65 | # ---------------------------------------------------------------------------- # |
64 | | -class JobsProgress(set): |
65 | | - """Track the state of current jobs in progress.""" |
66 | | - |
67 | | - _instance = None |
| 66 | +class JobsProgress: |
| 67 | + """Track the state of current jobs in progress using shared memory.""" |
| 68 | + |
| 69 | + _instance: Optional['JobsProgress'] = None |
| 70 | + _manager: SyncManager |
| 71 | + _shared_data: Any |
| 72 | + _lock: Any |
68 | 73 |
|
69 | 74 | def __new__(cls): |
70 | | - if JobsProgress._instance is None: |
71 | | - JobsProgress._instance = set.__new__(cls) |
72 | | - return JobsProgress._instance |
| 75 | + if cls._instance is None: |
| 76 | + instance = object.__new__(cls) |
| 77 | + # Initialize instance variables |
| 78 | + instance._manager = Manager() |
| 79 | + instance._shared_data = instance._manager.dict() |
| 80 | + instance._shared_data['jobs'] = instance._manager.list() |
| 81 | + instance._lock = instance._manager.Lock() |
| 82 | + cls._instance = instance |
| 83 | + return cls._instance |
| 84 | + |
| 85 | + def __init__(self): |
| 86 | + # Everything is already initialized in __new__ |
| 87 | + pass |
73 | 88 |
|
74 | 89 | def __repr__(self) -> str: |
75 | 90 | return f"<{self.__class__.__name__}>: {self.get_job_list()}" |
76 | 91 |
|
77 | 92 | def clear(self) -> None: |
78 | | - return super().clear() |
| 93 | + with self._lock: |
| 94 | + self._shared_data['jobs'][:] = [] |
79 | 95 |
|
80 | 96 | def add(self, element: Any): |
81 | 97 | """ |
82 | 98 | Adds a Job object to the set. |
| 99 | + """ |
| 100 | + if isinstance(element, str): |
| 101 | + job_dict = {'id': element} |
| 102 | + elif isinstance(element, dict): |
| 103 | + job_dict = element |
| 104 | + elif hasattr(element, 'id'): |
| 105 | + job_dict = {'id': element.id} |
| 106 | + else: |
| 107 | + raise TypeError("Only Job objects can be added to JobsProgress.") |
83 | 108 |
|
84 | | - If the added element is a string, then `Job(id=element)` is added |
| 109 | + with self._lock: |
| 110 | + # Check if job already exists |
| 111 | + job_list = self._shared_data['jobs'] |
| 112 | + for existing_job in job_list: |
| 113 | + if existing_job['id'] == job_dict['id']: |
| 114 | + return # Job already exists |
| 115 | + |
| 116 | + # Add new job |
| 117 | + job_list.append(job_dict) |
| 118 | + log.debug(f"JobsProgress | Added job: {job_dict['id']}") |
| 119 | + |
| 120 | + def get(self, element: Any) -> Optional[Job]: |
| 121 | + """ |
| 122 | + Retrieves a Job object from the set. |
85 | 123 | |
86 | | - If the added element is a dict, that `Job(**element)` is added |
| 124 | + If the element is a string, searches for Job with that id. |
87 | 125 | """ |
88 | 126 | if isinstance(element, str): |
89 | | - element = Job(id=element) |
90 | | - |
91 | | - if isinstance(element, dict): |
92 | | - element = Job(**element) |
93 | | - |
94 | | - if not isinstance(element, Job): |
95 | | - raise TypeError("Only Job objects can be added to JobsProgress.") |
| 127 | + search_id = element |
| 128 | + elif isinstance(element, Job): |
| 129 | + search_id = element.id |
| 130 | + else: |
| 131 | + raise TypeError("Only Job objects can be retrieved from JobsProgress.") |
96 | 132 |
|
97 | | - return super().add(element) |
| 133 | + with self._lock: |
| 134 | + for job_dict in self._shared_data['jobs']: |
| 135 | + if job_dict['id'] == search_id: |
| 136 | + log.debug(f"JobsProgress | Retrieved job: {job_dict['id']}") |
| 137 | + return Job(**job_dict) |
| 138 | + |
| 139 | + return None |
98 | 140 |
|
99 | 141 | def remove(self, element: Any): |
100 | 142 | """ |
101 | 143 | Removes a Job object from the set. |
102 | | -
|
103 | | - If the element is a string, then `Job(id=element)` is removed |
104 | | - |
105 | | - If the element is a dict, then `Job(**element)` is removed |
106 | 144 | """ |
107 | 145 | if isinstance(element, str): |
108 | | - element = Job(id=element) |
109 | | - |
110 | | - if isinstance(element, dict): |
111 | | - element = Job(**element) |
112 | | - |
113 | | - if not isinstance(element, Job): |
| 146 | + job_id = element |
| 147 | + elif isinstance(element, dict): |
| 148 | + job_id = element.get('id') |
| 149 | + elif hasattr(element, 'id'): |
| 150 | + job_id = element.id |
| 151 | + else: |
114 | 152 | raise TypeError("Only Job objects can be removed from JobsProgress.") |
115 | 153 |
|
116 | | - return super().discard(element) |
117 | | - |
118 | | - def get(self, element: Any) -> Job: |
119 | | - if isinstance(element, str): |
120 | | - element = Job(id=element) |
121 | | - |
122 | | - if not isinstance(element, Job): |
123 | | - raise TypeError("Only Job objects can be retrieved from JobsProgress.") |
124 | | - |
125 | | - for job in self: |
126 | | - if job == element: |
127 | | - return job |
| 154 | + with self._lock: |
| 155 | + job_list = self._shared_data['jobs'] |
| 156 | + # Find and remove the job |
| 157 | + for i, job_dict in enumerate(job_list): |
| 158 | + if job_dict['id'] == job_id: |
| 159 | + del job_list[i] |
| 160 | + log.debug(f"JobsProgress | Removed job: {job_dict['id']}") |
| 161 | + break |
128 | 162 |
|
129 | | - def get_job_list(self) -> str: |
| 163 | + def get_job_list(self) -> Optional[str]: |
130 | 164 | """ |
131 | 165 | Returns the list of job IDs as comma-separated string. |
132 | 166 | """ |
133 | | - if not len(self): |
| 167 | + with self._lock: |
| 168 | + job_list = list(self._shared_data['jobs']) |
| 169 | + |
| 170 | + if not job_list: |
134 | 171 | return None |
135 | 172 |
|
136 | | - return ",".join(str(job) for job in self) |
| 173 | + log.debug(f"JobsProgress | Jobs in progress: {job_list}") |
| 174 | + return ",".join(str(job_dict['id']) for job_dict in job_list) |
137 | 175 |
|
138 | 176 | def get_job_count(self) -> int: |
139 | 177 | """ |
140 | 178 | Returns the number of jobs. |
141 | 179 | """ |
142 | | - return len(self) |
| 180 | + with self._lock: |
| 181 | + return len(self._shared_data['jobs']) |
| 182 | + |
| 183 | + def __iter__(self): |
| 184 | + """Make the class iterable - returns Job objects""" |
| 185 | + with self._lock: |
| 186 | + # Create a snapshot of jobs to avoid holding lock during iteration |
| 187 | + job_dicts = list(self._shared_data['jobs']) |
| 188 | + |
| 189 | + # Return an iterator of Job objects |
| 190 | + return iter(Job(**job_dict) for job_dict in job_dicts) |
| 191 | + |
| 192 | + def __len__(self): |
| 193 | + """Support len() operation""" |
| 194 | + return self.get_job_count() |
| 195 | + |
| 196 | + def __contains__(self, element: Any) -> bool: |
| 197 | + """Support 'in' operator""" |
| 198 | + if isinstance(element, str): |
| 199 | + search_id = element |
| 200 | + elif isinstance(element, Job): |
| 201 | + search_id = element.id |
| 202 | + elif isinstance(element, dict): |
| 203 | + search_id = element.get('id') |
| 204 | + else: |
| 205 | + return False |
| 206 | + |
| 207 | + with self._lock: |
| 208 | + for job_dict in self._shared_data['jobs']: |
| 209 | + if job_dict['id'] == search_id: |
| 210 | + return True |
| 211 | + return False |
0 commit comments