From 5b3fb54cf36c74e2d847a313645ba7496e9a534c Mon Sep 17 00:00:00 2001 From: Gal Rotem <galrotem@fb.com> Date: Fri, 1 Sep 2023 15:58:38 -0700 Subject: [PATCH] pyre - eliminate errors from distributed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Get rid of pyre error annotations in distributed util Internal # Context [Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.] # This diff [List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.] # What’s next [If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.] Differential Revision: D48872019 --- torchtnt/utils/distributed.py | 16 ++++++++-------- torchtnt/utils/timer.py | 5 ++++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py index 85d05e4113..c243e654d8 100644 --- a/torchtnt/utils/distributed.py +++ b/torchtnt/utils/distributed.py @@ -18,6 +18,8 @@ from torch.distributed.elastic.utils.distributed import get_free_port from typing_extensions import Literal +T = TypeVar("T") + class PGWrapper: """ @@ -54,14 +56,14 @@ def barrier(self) -> None: else: dist.barrier(group=self.pg) - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None: + def broadcast_object_list( + self, obj_list: Union[List[T], List[None]], src: int = 0 + ) -> None: if self.pg is None: return dist.broadcast_object_list(obj_list, src=src, group=self.pg) - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - def all_gather_object(self, obj_list: List[Any], obj: Any) -> None: + def all_gather_object(self, obj_list: List[Union[None, T]], obj: T) -> None: if self.pg is None: obj_list[0] = obj return @@ -69,10 +71,8 @@ def all_gather_object(self, obj_list: List[Any], obj: Any) -> None: def scatter_object_list( self, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - output_list: List[Any], - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - input_list: Optional[List[Any]], + output_list: List[Optional[T]], + input_list: Optional[Union[List[T], List[None]]], src: int = 0, ) -> None: rank = self.get_rank() diff --git a/torchtnt/utils/timer.py b/torchtnt/utils/timer.py index 7668b96429..77035b904c 100644 --- a/torchtnt/utils/timer.py +++ b/torchtnt/utils/timer.py @@ -22,6 +22,7 @@ Sequence, Tuple, TypeVar, + Union, ) import numpy as np @@ -364,7 +365,9 @@ def _sync_durations( pg_wrapper = PGWrapper(pg) world_size = pg_wrapper.get_world_size() - outputs = [None] * world_size + # the below is a workaround for pyre since it doesn't infer List[None] as List[Optional[T]] + none_list: List[Optional[Dict[str, List[float]]]] = [None] + outputs = none_list * world_size pg_wrapper.all_gather_object(outputs, recorded_durations) ret = defaultdict(list) for output in outputs: