@@ -15,6 +15,7 @@
 import collections
 import contextlib
 import itertools
+import threading
 from typing import Callable, Dict, Iterable, Optional, Union

 import torch
@@ -156,20 +157,20 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)

-        # Create mutex with timeout
-        self._lock_with_timeout = None
+        # Replace lock if timeout is provided
         if lock_timeout is not None:
+            self._lock_with_timeout: threading.Lock = threading.Lock()

             @contextlib.contextmanager
             def lock_with_timeout():
-                result = self._lock.acquire(timeout=lock_timeout)
+                result = self._lock_with_timeout.acquire(timeout=lock_timeout)
                 try:
                     yield result
                 finally:
                     if result:
-                        self._lock.release()
+                        self._lock_with_timeout.release()

-            self._lock_with_timeout = lock_with_timeout
+            self._lock = lock_with_timeout

     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
@@ -185,10 +186,7 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            lock = self._lock
-            if self._lock_with_timeout is not None:
-                lock = self._lock_with_timeout()
-            with lock:
+            with self._lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
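
The change wraps threading.Lock.acquire(timeout=...) in a contextlib.contextmanager so that callers can bound how long they block on the optimizer's internal lock. Below is a minimal standalone sketch of the same pattern; the module-level names (_lock, lock_with_timeout) and the 0.5-second timeout are illustrative assumptions, not code from the Apex source.

import contextlib
import threading

_lock = threading.Lock()

@contextlib.contextmanager
def lock_with_timeout(timeout: float):
    # Try to acquire the lock, giving up after `timeout` seconds.
    acquired = _lock.acquire(timeout=timeout)
    try:
        # Yield the acquisition result so callers can tell whether
        # they actually hold the lock.
        yield acquired
    finally:
        if acquired:
            _lock.release()

# Usage: each call to the decorated function returns a fresh context
# manager, and the yielded flag reports whether the lock was obtained
# before the timeout expired.
with lock_with_timeout(0.5) as acquired:
    if acquired:
        pass  # exclusive work goes here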