diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py index 85d05e4113..c243e654d8 100644 --- a/torchtnt/utils/distributed.py +++ b/torchtnt/utils/distributed.py @@ -18,6 +18,8 @@ from torch.distributed.elastic.utils.distributed import get_free_port from typing_extensions import Literal +T = TypeVar("T") + class PGWrapper: """ @@ -54,14 +56,14 @@ def barrier(self) -> None: else: dist.barrier(group=self.pg) - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None: + def broadcast_object_list( + self, obj_list: Union[List[T], List[None]], src: int = 0 + ) -> None: if self.pg is None: return dist.broadcast_object_list(obj_list, src=src, group=self.pg) - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - def all_gather_object(self, obj_list: List[Any], obj: Any) -> None: + def all_gather_object(self, obj_list: List[Union[None, T]], obj: T) -> None: if self.pg is None: obj_list[0] = obj return @@ -69,10 +71,8 @@ def all_gather_object(self, obj_list: List[Any], obj: Any) -> None: def scatter_object_list( self, - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - output_list: List[Any], - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - input_list: Optional[List[Any]], + output_list: List[Optional[T]], + input_list: Optional[Union[List[T], List[None]]], src: int = 0, ) -> None: rank = self.get_rank() diff --git a/torchtnt/utils/timer.py b/torchtnt/utils/timer.py index 7668b96429..77035b904c 100644 --- a/torchtnt/utils/timer.py +++ b/torchtnt/utils/timer.py @@ -22,6 +22,7 @@ Sequence, Tuple, TypeVar, + Union, ) import numpy as np @@ -364,7 +365,9 @@ def _sync_durations( pg_wrapper = PGWrapper(pg) world_size = pg_wrapper.get_world_size() - outputs = [None] * world_size + # the below is a workaround for pyre since it doesn't infer List[None] as List[Optional[T]] + none_list: List[Optional[Dict[str, List[float]]]] = [None] + outputs = none_list * world_size pg_wrapper.all_gather_object(outputs, recorded_durations) ret = defaultdict(list) for output in outputs: