-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add distributed context in pytorch engine to support torchrun (#2615)
- Loading branch information
Showing
13 changed files
with
119 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import threading | ||
from contextlib import contextmanager | ||
from dataclasses import dataclass | ||
|
||
from torch import distributed as dist | ||
|
||
|
||
@dataclass | ||
class DistContext: | ||
rank: int = 0 | ||
world_size: int = 1 | ||
dist_group: dist.ProcessGroup = None | ||
|
||
|
||
DefaultContext = DistContext() | ||
|
||
|
||
class DistManager: | ||
"""distributed context manager.""" | ||
|
||
def __init__(self): | ||
self.t_local = threading.local() | ||
self.t_local.device_context = DefaultContext | ||
|
||
def current_context(self) -> DistContext: | ||
"""get current context.""" | ||
return getattr(self.t_local, 'device_context', DefaultContext) | ||
|
||
def set_context(self, context: DistContext): | ||
"""set current context.""" | ||
self.t_local.device_context = context | ||
|
||
@contextmanager | ||
def context(self, context: DistContext): | ||
"""context manager.""" | ||
origin_context = self.current_context() | ||
self.set_context(context) | ||
yield self | ||
self.set_context(origin_context) | ||
|
||
|
||
_DIST_MANAGER: DistManager = None | ||
|
||
|
||
def get_dist_manager(): | ||
"""get device manager.""" | ||
global _DIST_MANAGER | ||
if _DIST_MANAGER is None: | ||
_DIST_MANAGER = DistManager() | ||
return _DIST_MANAGER | ||
|
||
|
||
def get_world_rank(): | ||
"""get distributed world size and rank.""" | ||
ctx = get_dist_manager().current_context() | ||
world_size = ctx.world_size | ||
rank = ctx.rank | ||
|
||
return world_size, rank | ||
|
||
|
||
def get_process_group(): | ||
"""get process group.""" | ||
ctx = get_dist_manager().current_context() | ||
return ctx.dist_group |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.