@@ -13,6 +13,7 @@
 # limitations under the License.

 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union

@@ -55,6 +56,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models.
             (default: False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.

@@ -65,6 +68,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):

@@ -114,6 +118,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)

+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
@@ -128,7 +151,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
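
The added lock_with_timeout context manager is the heart of the change: it tries to take the optimizer's callback mutex for at most lock_timeout seconds and, on timeout, prints a warning and lets the guarded block run anyway instead of deadlocking. Below is a minimal standalone sketch of the same pattern; the names (_lock, LOCK_TIMEOUT) are illustrative and not part of the NeMo or Apex API.

    import contextlib
    import threading

    _lock = threading.Lock()
    LOCK_TIMEOUT = 5.0  # illustrative timeout in seconds

    @contextlib.contextmanager
    def lock_with_timeout():
        # Block for at most LOCK_TIMEOUT seconds while trying to take the lock
        acquired = _lock.acquire(timeout=LOCK_TIMEOUT)
        try:
            # Hand control to the with-block either way; `acquired` reports the outcome
            yield acquired
        finally:
            if acquired:
                # Lock was taken before the timeout, so it must be released
                _lock.release()
            else:
                # Timed out: the with-block already ran without the lock held
                print(f'Failed to acquire lock within {LOCK_TIMEOUT} seconds.')

    with lock_with_timeout() as acquired:
        print('held lock' if acquired else 'proceeded without lock')

Note that a plain threading.Lock and this generator-based context manager expose the same with-statement protocol, which is why the hook in the last hunk can pick either object at runtime (lock = self._lock versus lock = self._lock_with_timeout()) and enter it with a single with lock:.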
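Usage is a single extra constructor argument; the default lock_timeout=None keeps the old behavior of blocking indefinitely on self._lock. A hedged sketch, assuming a NeMo installation with Apex and an initialized torch.distributed process group; the import path and the lr value are illustrative, and the Linear module stands in for a real Megatron model.

    import torch
    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

    model = torch.nn.Linear(16, 16)  # stand-in for a real Megatron model

    optimizer = MegatronDistributedFusedAdam(
        model.parameters(),
        lr=1e-4,            # forwarded to Apex DistributedFusedAdam via **kwargs
        lock_timeout=30.0,  # new in this PR: warn instead of blocking forever
    )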