Skip to content

Commit dd78d22

Browse files
author
Jaemin Choi
committed
Replace parent's _lock
1 parent 81f4902 commit dd78d22

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

nemo/core/optim/distributed_adam.py

+7 -9
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import collections
1616
import contextlib
1717
import itertools
18+
import threading
1819
from typing import Callable, Dict, Iterable, Optional, Union
1920

2021
import torch
@@ -156,20 +157,20 @@ def __init__(
156157
# Construct distributed optimizer
157158
super().__init__(param_groups, **kwargs)
158159

159-
# Create mutex with timeout
160-
self._lock_with_timeout = None
160+
# Replace lock if timeout is provided
161161
if lock_timeout is not None:
162+
self._lock_with_timeout: threading.Lock = threading.Lock()
162163

163164
@contextlib.contextmanager
164165
def lock_with_timeout():
165-
result = self._lock.acquire(timeout=lock_timeout)
166+
result = self._lock_with_timeout.acquire(timeout=lock_timeout)
166167
try:
167168
yield result
168169
finally:
169170
if result:
170-
self._lock.release()
171+
self._lock_with_timeout.release()
171172

172-
self._lock_with_timeout = lock_with_timeout
173+
self._lock = lock_with_timeout
173174

174175
def _broadcast_params(self) -> None:
175176
# Assume params have already been synchronized
@@ -185,10 +186,7 @@ def hook(*unused):
185186
'before the forward pass (e.g. by calling data_ptr) '
186187
'or run DistributedFusedAdam with overlap_param_sync=False.'
187188
)
188-
lock = self._lock
189-
if self._lock_with_timeout is not None:
190-
lock = self._lock_with_timeout()
191-
with lock:
189+
with self._lock:
192190
need_to_initialize = 'fragments' not in self.state[param]
193191
if need_to_initialize:
194192
self._init_param_state(param, param_group_id, param_id)

0 commit comments

Comments
 (0)