From 6de617ae34d10ae31a5701244f378e0e32a76dc2 Mon Sep 17 00:00:00 2001 From: CodemodService Bot Date: Wed, 6 Dec 2023 06:51:59 -0800 Subject: [PATCH] torchtnt (#640) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/640 Differential Revision: D51846380 --- examples/auto_unit_example.py | 6 +++++- examples/torchdata_train_example.py | 6 +++++- examples/torchrec/main.py | 6 +++++- examples/train_unit_example.py | 6 +++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/auto_unit_example.py b/examples/auto_unit_example.py index c5934c5129..52ad14adf9 100644 --- a/examples/auto_unit_example.py +++ b/examples/auto_unit_example.py @@ -239,7 +239,7 @@ def get_args() -> Namespace: return parser.parse_args() -if __name__ == "__main__": +def invoke_main() -> None: args: Namespace = get_args() lc = pet.LaunchConfig( min_nodes=1, @@ -253,3 +253,7 @@ def get_args() -> Namespace: ) pet.elastic_launch(lc, entrypoint=main)(args) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/examples/torchdata_train_example.py b/examples/torchdata_train_example.py index b547889134..4dd125760c 100644 --- a/examples/torchdata_train_example.py +++ b/examples/torchdata_train_example.py @@ -202,5 +202,9 @@ def get_args(argv: List[str]) -> Namespace: return parser.parse_args(argv) -if __name__ == "__main__": +def invoke_main() -> None: main(sys.argv[1:]) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/examples/torchrec/main.py b/examples/torchrec/main.py index 4acf7fc701..6cf41dbfad 100644 --- a/examples/torchrec/main.py +++ b/examples/torchrec/main.py @@ -378,7 +378,7 @@ def main(argv: List[str]) -> None: ) -if __name__ == "__main__": +def invoke_main() -> None: lc = launcher.LaunchConfig( min_nodes=MIN_NODES, max_nodes=MAX_NODES, @@ -391,3 +391,7 @@ def main(argv: List[str]) -> None: ) launcher.elastic_launch(config=lc, entrypoint=main)(sys.argv[1:]) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/examples/train_unit_example.py b/examples/train_unit_example.py index aeff258b9f..bfe3f14800 100644 --- a/examples/train_unit_example.py +++ b/examples/train_unit_example.py @@ -173,5 +173,9 @@ def get_args(argv: List[str]) -> Namespace: return parser.parse_args(argv) -if __name__ == "__main__": +def invoke_main() -> None: main(sys.argv[1:]) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover