diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 048f3365e1753..66bbdc7fc3750 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -480,7 +480,7 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: ) self.setup_precision_plugin(plugin) - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 01c23504b7773..6fd02142bf410 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, TypeVar, Union import torch from torch.nn import Module @@ -30,6 +30,8 @@ if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer +TBroadcast = TypeVar("TBroadcast") + class TrainingTypePlugin(Plugin, ABC): """A Plugin to change the behaviour of the training, validation and test-loop.""" @@ -88,7 +90,7 @@ def barrier(self, name: Optional[str] = None) -> None: """Forces all possibly joined processes to wait for each other""" @abstractmethod - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: """Broadcasts an object to all processes""" @abstractmethod