diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 2a67fcc09..a64e7f6d6 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -66,7 +66,9 @@ @dataclass class EmbeddingCollectionContext(Multistreamable): - sharding_contexts: List[InferSequenceShardingContext | SequenceShardingContext] + sharding_contexts: List[ + Union[InferSequenceShardingContext, SequenceShardingContext] + ] def record_stream(self, stream: torch.Stream) -> None: for ctx in self.sharding_contexts: