From 497dbda3860fe9d2a1757a7c16f13a94b23dc117 Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Wed, 17 May 2023 19:43:02 -0700 Subject: [PATCH] Only print timer summary on rank 0 Summary: Since this will be used mostly for local debugging let's only print this on rank 0 Differential Revision: D45975140 fbshipit-source-id: 418de1fa5ea73e15d798c1e083d965d98cb54251 --- examples/auto_unit_example.py | 1 + torchtnt/framework/evaluate.py | 3 ++- torchtnt/framework/fit.py | 3 ++- torchtnt/framework/predict.py | 3 ++- torchtnt/framework/train.py | 3 ++- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/auto_unit_example.py b/examples/auto_unit_example.py index af38825d09..27fbf39ce7 100644 --- a/examples/auto_unit_example.py +++ b/examples/auto_unit_example.py @@ -169,6 +169,7 @@ def main(args: Namespace) -> None: train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, max_epochs=args.max_epochs, + auto_timing=True, ) fit(state, my_unit) diff --git a/torchtnt/framework/evaluate.py b/torchtnt/framework/evaluate.py index 92b6bc29d6..bb5f49892c 100644 --- a/torchtnt/framework/evaluate.py +++ b/torchtnt/framework/evaluate.py @@ -22,6 +22,7 @@ _step_requires_iterator, log_api_usage, ) +from torchtnt.utils.rank_zero_log import rank_zero_info from torchtnt.utils.timer import get_timer_summary, Timer logger: logging.Logger = logging.getLogger(__name__) @@ -120,7 +121,7 @@ def evaluate( _evaluate_impl(state, eval_unit, callbacks) logger.info("Finished evaluation") if state.timer: - logger.info(get_timer_summary(state.timer)) + rank_zero_info(get_timer_summary(state.timer)) except Exception as e: # TODO: log for diagnostics logger.info(e) diff --git a/torchtnt/framework/fit.py b/torchtnt/framework/fit.py index bc75955aa2..f5d281c905 100644 --- a/torchtnt/framework/fit.py +++ b/torchtnt/framework/fit.py @@ -24,6 +24,7 @@ _run_callback_fn, log_api_usage, ) +from torchtnt.utils.rank_zero_log import rank_zero_info from torchtnt.utils.timer import get_timer_summary, Timer logger: logging.Logger = logging.getLogger(__name__) @@ -149,7 +150,7 @@ def fit( state._entry_point = EntryPoint.FIT _fit_impl(state, unit, callbacks) if state.timer: - logger.info(get_timer_summary(state.timer)) + rank_zero_info(get_timer_summary(state.timer)) except Exception as e: # TODO: log for diagnostics logger.info(e) diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index 3a2c305b2c..8db67d9022 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -22,6 +22,7 @@ _step_requires_iterator, log_api_usage, ) +from torchtnt.utils.rank_zero_log import rank_zero_info from torchtnt.utils.timer import get_timer_summary, Timer logger: logging.Logger = logging.getLogger(__name__) @@ -121,7 +122,7 @@ def predict( _predict_impl(state, predict_unit, callbacks) logger.info("Finished predict") if state.timer: - logger.info(get_timer_summary(state.timer)) + rank_zero_info(get_timer_summary(state.timer)) except Exception as e: # TODO: log for diagnostics logger.info(e) diff --git a/torchtnt/framework/train.py b/torchtnt/framework/train.py index 63df26d05b..1b0c8be043 100644 --- a/torchtnt/framework/train.py +++ b/torchtnt/framework/train.py @@ -25,6 +25,7 @@ _step_requires_iterator, log_api_usage, ) +from torchtnt.utils.rank_zero_log import rank_zero_info from torchtnt.utils.timer import get_timer_summary, Timer logger: logging.Logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ def train( _train_impl(state, train_unit, callbacks) logger.info("Finished train") if state.timer: - logger.info(get_timer_summary(state.timer)) + rank_zero_info(get_timer_summary(state.timer)) except Exception as e: # TODO: log for diagnostics logger.info(f"Exception during train\n: {e}")