 # limitations under the License.

 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union

@@ -108,6 +109,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models.
             (default: False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.

@@ -118,6 +121,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):

@@ -152,6 +156,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)

+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
@@ -166,7 +189,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
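For reference, here is a minimal standalone sketch of the timeout-guarded mutex pattern this diff introduces. The module-level `_lock` and `lock_timeout` names are local stand-ins for the optimizer's attributes, not part of the actual class:

```python
import contextlib
import threading

_lock = threading.Lock()  # stand-in for the optimizer's callback mutex
lock_timeout = 5.0        # stand-in for the new constructor argument


@contextlib.contextmanager
def lock_with_timeout():
    # threading.Lock.acquire returns False if the timeout expires
    result = _lock.acquire(timeout=lock_timeout)
    try:
        yield result
    finally:
        if result:
            # Acquired the lock before the timeout, so release it
            _lock.release()
        else:
            # Timed out: warn instead of blocking the callback forever
            print(f'Failed to acquire lock within {lock_timeout} seconds.')


# Usage mirrors the hook above: fall back to the bare lock when no
# timeout is configured, otherwise enter the guarded wrapper
with lock_with_timeout():
    ...  # critical section, e.g. lazy parameter-state initialization
```

Note that on timeout the context manager only prints a warning and the guarded block still runs without holding the lock; the yielded `result` value lets a caller detect that case explicitly.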