
Commit d5d72cd

minitu (Jaemin Choi) authored and committed
Add option for mutex timeout in distributed optimizer backward hook (#9087)
* Tim: Add option for timeout in distopt callback mutex
* Replace parent's _lock
* Revert "Replace parent's _lock" (reverts commit 972d1b6)
* Raise RuntimeError when timeout
* Change RuntimeError to print

Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
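In short, MegatronDistributedFusedAdam gains a lock_timeout constructor argument; when set, the distributed optimizer's backward hook gives up on the callback mutex after that many seconds instead of blocking indefinitely. A minimal usage sketch (the model and learning rate are illustrative, not from this commit, and a distributed setup is assumed to already be initialized):

    import torch
    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam

    # Illustrative model; any iterable of parameters works.
    model = torch.nn.Linear(16, 16)

    # lock_timeout is the new option from this commit; other keyword
    # arguments are forwarded to Apex DistributedFusedAdam.
    optimizer = MegatronDistributedFusedAdam(
        model.parameters(),
        lr=1e-4,              # illustrative value
        lock_timeout=60.0,    # give up on the callback mutex after 60 seconds
    )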
1 parent a8e0ca1 commit d5d72cd

File tree

1 file changed: +27 −1


nemo/core/optim/distributed_adam.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import collections
+import contextlib
 import itertools
 from typing import Callable, Dict, Iterable, Optional, Union


@@ -55,6 +56,8 @@ class MegatronDistributedFusedAdam(DistributedFusedAdam):
             but requires larger memory than distributing within all
             ranks, especially for pure data parallel models.
             (default: False).
+        lock_timeout (float, optional): timeout for callback mutex in
+            seconds.
         **kwargs: keyword arguments to pass to Apex
             DistributedFusedAdam.


@@ -65,6 +68,7 @@ def __init__(
         params: Union[Iterable[torch.nn.Parameter], Iterable[dict]],
         disable_distributed_parameters: bool = False,
         distribute_within_nodes: bool = False,
+        lock_timeout: Optional[float] = None,
         **kwargs,
     ):


@@ -114,6 +118,25 @@ def __init__(
         # Construct distributed optimizer
         super().__init__(param_groups, **kwargs)

+        # Create mutex with timeout
+        self._lock_with_timeout = None
+        if lock_timeout is not None:
+
+            @contextlib.contextmanager
+            def lock_with_timeout():
+                result = self._lock.acquire(timeout=lock_timeout)
+                try:
+                    yield result
+                finally:
+                    if result:
+                        # Acquired lock before timeout
+                        self._lock.release()
+                    else:
+                        # Failed to acquire lock before timeout
+                        print(f'MegatronDistributedFusedAdam: Failed to acquire lock within {lock_timeout} seconds.')
+
+            self._lock_with_timeout = lock_with_timeout
+
     def _broadcast_params(self) -> None:
         # Assume params have already been synchronized
         pass

@@ -128,7 +151,10 @@ def hook(*unused):
                     'before the forward pass (e.g. by calling data_ptr) '
                     'or run DistributedFusedAdam with overlap_param_sync=False.'
                 )
-            with self._lock:
+            lock = self._lock
+            if self._lock_with_timeout is not None:
+                lock = self._lock_with_timeout()
+            with lock:
                 need_to_initialize = 'fragments' not in self.state[param]
                 if need_to_initialize:
                     self._init_param_state(param, param_group_id, param_id)
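The timed mutex in the diff is a small context-manager wrapper around threading.Lock.acquire(timeout=...). Note that the with body still runs even when the acquire times out; per the commit history, a timeout initially raised a RuntimeError and was later relaxed to a print so training can proceed. A self-contained sketch of the same pattern (names are illustrative, not the module's API):

    import contextlib
    import threading

    _lock = threading.Lock()

    @contextlib.contextmanager
    def lock_with_timeout(timeout: float):
        # Try to take the lock for up to `timeout` seconds.
        acquired = _lock.acquire(timeout=timeout)
        try:
            # Yield the result either way, so the caller's block executes
            # whether or not the lock was obtained.
            yield acquired
        finally:
            if acquired:
                # Acquired lock before timeout
                _lock.release()
            else:
                # Failed to acquire lock before timeout
                print(f'Failed to acquire lock within {timeout} seconds.')

    with lock_with_timeout(0.5) as acquired:
        # Runs whether or not the lock was acquired, mirroring the hook.
        pass

Yielding the acquire result (rather than raising) lets callers check it if they care, while the backward hook simply proceeds with a warning after the timeout.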
