Commit 8e7e460

Authored by: github-actions[bot], minitu (Jaemin Choi), michal2409 (Michal Futrega), pablo-garay (Pablo Garay)
Add option for mutex timeout in distributed optimizer backward hook (#9087) (#9091)
* Tim: Add option for timeout in distopt callback mutex

* Replace parent's _lock

* Revert "Replace parent's _lock"

  This reverts commit 972d1b6.

* Raise RuntimeError when timeout

* Change RuntimeError to print

---------

Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Michal Futrega <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
1 parent 3c29fef commit 8e7e460
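
For context, here is a minimal usage sketch (not part of the commit): the model, learning rate, and surrounding training setup are placeholders, and only the lock_timeout argument is what this change adds.

# Hypothetical setup; `model` and `lr` are stand-ins for a real NeMo config.
from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

optimizer = MegatronDistributedFusedAdam(
    model.parameters(),
    lr=1e-4,
    lock_timeout=30.0,  # new option: give up on the callback mutex after 30 s
)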

File tree

1 file changed: +27 −1 lines changed

nemo/core/optim/distributed_adam.py (+27 −1)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union
 
@@ -108,6 +109,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models.
             (default: False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.
 
@@ -118,6 +121,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):
 
@@ -152,6 +156,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)
 
+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass
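
Editor's note: the hunk above hinges on a generator-based context manager wrapping threading.Lock.acquire(timeout=...), which releases the lock only if it was actually acquired. A standalone sketch of that pattern with illustrative names (_lock, LOCK_TIMEOUT), outside the optimizer class:

import contextlib
import threading

_lock = threading.Lock()
LOCK_TIMEOUT = 2.0  # seconds; placeholder value

@contextlib.contextmanager
def lock_with_timeout():
    acquired = _lock.acquire(timeout=LOCK_TIMEOUT)
    try:
        yield acquired  # the with-body runs either way; callers may check the flag
    finally:
        if acquired:
            _lock.release()  # only release a lock we actually hold
        else:
            print(f'Failed to acquire lock within {LOCK_TIMEOUT} seconds.')

with lock_with_timeout() as got_lock:
    if got_lock:
        ...  # critical section

Note that, as in the commit (which deliberately changed a RuntimeError into a print), a timed-out acquire does not abort: the with-body still executes, just without holding the lock.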
@@ -166,7 +189,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
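
Editor's note: this last hunk works because both branches produce something usable in a with statement: self._lock is a plain lock, while self._lock_with_timeout must be called to produce a fresh context manager. A small sketch of that dispatch, with stand-in names:

import threading

def run_critical_section(raw_lock, lock_with_timeout=None):
    # Mirrors the hook: prefer the timeout wrapper when it is configured.
    lock = raw_lock
    if lock_with_timeout is not None:
        lock = lock_with_timeout()  # note the call: it returns a context manager
    with lock:
        ...  # critical section, e.g. lazy parameter-state initialization

run_critical_section(threading.Lock())  # blocking acquire when no wrapper is set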
