Skip to content

Commit

Permalink
pyre - eliminate errors from distributed
Browse files Browse the repository at this point in the history
Summary:
Get rid of pyre error annotations in distributed util

Internal


# Context
[Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.]


# This diff
[List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.]


# What’s next
[If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.]

Differential Revision: D48872019
  • Loading branch information
galrotem authored and facebook-github-bot committed Sep 1, 2023
1 parent 1f703d5 commit 5b3fb54
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
16 changes: 8 additions & 8 deletions torchtnt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from torch.distributed.elastic.utils.distributed import get_free_port
from typing_extensions import Literal

T = TypeVar("T")


class PGWrapper:
"""
Expand Down Expand Up @@ -54,25 +56,23 @@ def barrier(self) -> None:
else:
dist.barrier(group=self.pg)

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None:
def broadcast_object_list(
self, obj_list: Union[List[T], List[None]], src: int = 0
) -> None:
if self.pg is None:
return
dist.broadcast_object_list(obj_list, src=src, group=self.pg)

# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
def all_gather_object(self, obj_list: List[Any], obj: Any) -> None:
def all_gather_object(self, obj_list: List[Union[None, T]], obj: T) -> None:
if self.pg is None:
obj_list[0] = obj
return
dist.all_gather_object(obj_list, obj, group=self.pg)

def scatter_object_list(
self,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
output_list: List[Any],
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
input_list: Optional[List[Any]],
output_list: List[Optional[T]],
input_list: Optional[Union[List[T], List[None]]],
src: int = 0,
) -> None:
rank = self.get_rank()
Expand Down
5 changes: 4 additions & 1 deletion torchtnt/utils/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Sequence,
Tuple,
TypeVar,
Union,
)

import numpy as np
Expand Down Expand Up @@ -364,7 +365,9 @@ def _sync_durations(

pg_wrapper = PGWrapper(pg)
world_size = pg_wrapper.get_world_size()
outputs = [None] * world_size
# the below is a workaround for pyre since it doesn't infer List[None] as List[Optional[T]]
none_list: List[Optional[Dict[str, List[float]]]] = [None]
outputs = none_list * world_size
pg_wrapper.all_gather_object(outputs, recorded_durations)
ret = defaultdict(list)
for output in outputs:
Expand Down

0 comments on commit 5b3fb54

Please sign in to comment.