From 5b3fb54cf36c74e2d847a313645ba7496e9a534c Mon Sep 17 00:00:00 2001
From: Gal Rotem <galrotem@fb.com>
Date: Fri, 1 Sep 2023 15:58:38 -0700
Subject: [PATCH] pyre - eliminate errors from distributed
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Get rid of pyre error annotations in distributed util

Internal


# Context
[Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.]


# This diff
[List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.]


# What’s next
[If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.]

Differential Revision: D48872019
---
 torchtnt/utils/distributed.py | 16 ++++++++--------
 torchtnt/utils/timer.py       |  5 ++++-
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index 85d05e4113..c243e654d8 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -18,6 +18,8 @@
 from torch.distributed.elastic.utils.distributed import get_free_port
 from typing_extensions import Literal
 
+T = TypeVar("T")
+
 
 class PGWrapper:
     """
@@ -54,14 +56,14 @@ def barrier(self) -> None:
         else:
             dist.barrier(group=self.pg)
 
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None:
+    def broadcast_object_list(
+        self, obj_list: Union[List[T], List[None]], src: int = 0
+    ) -> None:
         if self.pg is None:
             return
         dist.broadcast_object_list(obj_list, src=src, group=self.pg)
 
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    def all_gather_object(self, obj_list: List[Any], obj: Any) -> None:
+    def all_gather_object(self, obj_list: List[Union[None, T]], obj: T) -> None:
         if self.pg is None:
             obj_list[0] = obj
             return
@@ -69,10 +71,8 @@ def all_gather_object(self, obj_list: List[Any], obj: Any) -> None:
 
     def scatter_object_list(
         self,
-        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-        output_list: List[Any],
-        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-        input_list: Optional[List[Any]],
+        output_list: List[Optional[T]],
+        input_list: Optional[Union[List[T], List[None]]],
         src: int = 0,
     ) -> None:
         rank = self.get_rank()
diff --git a/torchtnt/utils/timer.py b/torchtnt/utils/timer.py
index 7668b96429..77035b904c 100644
--- a/torchtnt/utils/timer.py
+++ b/torchtnt/utils/timer.py
@@ -22,6 +22,7 @@
     Sequence,
     Tuple,
     TypeVar,
+    Union,
 )
 
 import numpy as np
@@ -364,7 +365,9 @@ def _sync_durations(
 
     pg_wrapper = PGWrapper(pg)
     world_size = pg_wrapper.get_world_size()
-    outputs = [None] * world_size
+    # the below is a workaround for pyre since it doesn't infer List[None] as List[Optional[T]]
+    none_list: List[Optional[Dict[str, List[float]]]] = [None]
+    outputs = none_list * world_size
     pg_wrapper.all_gather_object(outputs, recorded_durations)
     ret = defaultdict(list)
     for output in outputs: