Add more fine-grained timing to the AutoUnit (pytorch#393)
Summary:
Pull Request resolved: pytorch#393

* Expand `_get_timing_context` to return two context managers (see the sketch after this summary):
  - one for timing with the Timer
  - one for calling `record_function` for the PyTorch Profiler
* Replace calls to `record_function` in the AutoUnit with calls to `_get_timing_context` so that these functions are timed as well
* Make some small changes to the `event_name`s for better readability

Reviewed By: ananthsub

Differential Revision: D45974462

fbshipit-source-id: 7581e027c37c4b527267cdd9d09824b1a620f9b9
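For context, here is a minimal sketch of how a helper like `_get_timing_context` might compose the two context managers described in the summary. The simplified `Timer` class, the exact signature, and the wiring of the `skip_timer` flag are illustrative assumptions inferred from the tests below, not the actual torchtnt implementation:

import contextlib
import time
from typing import Iterator

from torch.profiler import record_function


class Timer:
    """Simplified stand-in for torchtnt's Timer, for illustration only."""

    def __init__(self) -> None:
        self.recorded_durations = {}

    @contextlib.contextmanager
    def time(self, event_name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            self.recorded_durations.setdefault(event_name, []).append(
                time.perf_counter() - start
            )


@contextlib.contextmanager
def _get_timing_context(
    state, event_name: str, skip_timer: bool = False
) -> Iterator[None]:
    # Timer context: entered only when the state has a timer and the caller
    # has not opted out; skip_timer lets an outer span be profiled without
    # double-counting the nested spans it contains.
    timer_context = (
        state.timer.time(event_name)
        if state.timer is not None and not skip_timer
        else contextlib.nullcontext()
    )
    # Profiler context: always entered, so the event is visible in PyTorch
    # Profiler traces even when auto timing is disabled.
    with timer_context, record_function(event_name):
        yield

As the tests below exercise, passing `skip_timer=True` keeps an event out of the Timer's recorded durations while still emitting a profiler range.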
daniellepintz authored and facebook-github-bot committed May 19, 2023
1 parent 5f235a7 commit e05c060
Showing 12 changed files with 204 additions and 116 deletions.
44 changes: 44 additions & 0 deletions tests/framework/test_auto_unit.py
@@ -918,6 +918,50 @@ def test_is_last_batch(self) -> None:
state = init_train_state(dataloader=dataloader, max_epochs=max_epochs)
train(state, my_unit)

def test_auto_unit_timing(self) -> None:
"""
Test auto timing in AutoUnit
"""

input_dim = 2
dataset_len = 10
batch_size = 2
max_steps_per_epoch = 1
max_epochs = 1

dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)

my_module = torch.nn.Linear(2, 2)

state = init_train_state(
dataloader=dataloader,
max_steps_per_epoch=max_steps_per_epoch,
max_epochs=max_epochs,
)
train(state, DummyAutoUnit(module=my_module))
self.assertIsNone(state.timer)

state = init_train_state(
dataloader=dataloader,
max_steps_per_epoch=max_steps_per_epoch,
max_epochs=max_epochs,
auto_timing=True,
)
train(state, DummyAutoUnit(module=my_module))
for k in (
"DummyAutoUnit.on_train_start",
"DummyAutoUnit.on_train_end",
"DummyAutoUnit.compute_loss",
"DummyAutoUnit.next(data_iter)",
"DummyAutoUnit.backward",
):
self.assertTrue(k in state.timer.recorded_durations.keys())

# train_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's train_step
self.assertFalse(
"DummyAutoUnit.train_step" in state.timer.recorded_durations.keys()
)


Batch = Tuple[torch.tensor, torch.tensor]

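The test above asserts that `DummyAutoUnit.train_step` is absent from the timer's recorded durations because its span fully overlaps the nested spans it contains. Here is a hypothetical sketch of how a train step might be wired through the helper sketched earlier; the class, method structure, and loss function are illustrative assumptions, not AutoUnit's actual internals:

import torch


class DummyAutoUnit:
    """Illustrative unit only; real AutoUnit internals differ."""

    def __init__(self, module: torch.nn.Module) -> None:
        self.module = module

    def compute_loss(self, state, batch):
        inputs, targets = batch
        outputs = self.module(inputs)
        return torch.nn.functional.mse_loss(outputs, targets), outputs

    def train_step(self, state, data_iter):
        name = type(self).__name__
        # skip_timer=True: the train_step span fully overlaps the nested spans
        # below, so it is emitted to the profiler via record_function but kept
        # out of the Timer's recorded_durations to avoid double counting.
        with _get_timing_context(state, f"{name}.train_step", skip_timer=True):
            with _get_timing_context(state, f"{name}.next(data_iter)"):
                batch = next(data_iter)
            with _get_timing_context(state, f"{name}.compute_loss"):
                loss, outputs = self.compute_loss(state, batch)
            with _get_timing_context(state, f"{name}.backward"):
                loss.backward()
        return loss, outputs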
16 changes: 8 additions & 8 deletions tests/framework/test_evaluate.py
@@ -194,14 +194,14 @@ def test_evaluate_auto_timing(self) -> None:
auto_timing=True,
)
evaluate(state, DummyEvalUnit(input_dim=input_dim))
-for k in [
-    "eval.DummyEvalUnit.on_eval_start",
-    "eval.DummyEvalUnit.on_eval_epoch_start",
-    "eval.data_iter_next",
-    "eval.DummyEvalUnit.eval_step",
-    "eval.DummyEvalUnit.on_eval_epoch_end",
-    "eval.DummyEvalUnit.on_eval_end",
-]:
+for k in (
+    "DummyEvalUnit.on_eval_start",
+    "DummyEvalUnit.on_eval_epoch_start",
+    "evaluate.next(data_iter)",
+    "DummyEvalUnit.eval_step",
+    "DummyEvalUnit.on_eval_epoch_end",
+    "DummyEvalUnit.on_eval_end",
+):
self.assertTrue(k in state.timer.recorded_durations.keys())


8 changes: 4 additions & 4 deletions tests/framework/test_fit.py
@@ -343,8 +343,8 @@ def test_fit_auto_timing(self) -> None:
auto_timing=True,
)
fit(state, DummyFitUnit(input_dim=input_dim))
-for k in [
-    "train.DummyFitUnit.on_train_start",
-    "train.DummyFitUnit.on_train_end",
-]:
+for k in (
+    "DummyFitUnit.on_train_start",
+    "DummyFitUnit.on_train_end",
+):
self.assertTrue(k in state.timer.recorded_durations.keys())
16 changes: 8 additions & 8 deletions tests/framework/test_predict.py
@@ -199,14 +199,14 @@ def test_predict_auto_timing(self) -> None:
auto_timing=True,
)
predict(state, DummyPredictUnit(input_dim=input_dim))
-for k in [
-    "predict.DummyPredictUnit.on_predict_start",
-    "predict.DummyPredictUnit.on_predict_epoch_start",
-    "predict.data_iter_next",
-    "predict.DummyPredictUnit.predict_step",
-    "predict.DummyPredictUnit.on_predict_epoch_end",
-    "predict.DummyPredictUnit.on_predict_end",
-]:
+for k in (
+    "DummyPredictUnit.on_predict_start",
+    "DummyPredictUnit.on_predict_epoch_start",
+    "predict.next(data_iter)",
+    "DummyPredictUnit.predict_step",
+    "DummyPredictUnit.on_predict_epoch_end",
+    "DummyPredictUnit.on_predict_end",
+):
self.assertTrue(k in state.timer.recorded_durations.keys())


9 changes: 5 additions & 4 deletions tests/framework/test_train.py
@@ -286,10 +286,11 @@ def test_train_auto_timing(self) -> None:
auto_timing=True,
)
train(state, DummyTrainUnit(input_dim=input_dim))
-for k in [
-    "train.DummyTrainUnit.on_train_start",
-    "train.DummyTrainUnit.on_train_end",
-]:
+for k in (
+    "DummyTrainUnit.on_train_start",
+    "DummyTrainUnit.on_train_end",
+    "train.next(data_iter)",
+):
self.assertTrue(k in state.timer.recorded_durations.keys())


45 changes: 24 additions & 21 deletions tests/framework/test_utils.py
@@ -7,7 +7,7 @@

import unittest
from typing import Any, Iterator, Tuple, Union
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

import torch

@@ -19,7 +19,6 @@
if is_torch_version_geq_2_0():
from torch.distributed._composable import fully_shard

-import contextlib
import time

from torch.distributed import launcher
@@ -179,55 +178,47 @@ def test_run_callback_fn_hooks(self) -> None:
ValueError("test"),
)
self.assertEqual(callback.dummy_data, "on_exception")
-self.assertTrue(
-    "callback.DummyCallback.on_exception" in timer.recorded_durations.keys()
-)
+self.assertTrue("DummyCallback.on_exception" in timer.recorded_durations.keys())

_run_callback_fn([callback], "on_train_start", dummy_train_state, train_unit)
self.assertEqual(callback.dummy_data, "on_train_start")
self.assertTrue(
"callback.DummyCallback.on_train_start" in timer.recorded_durations.keys()
"DummyCallback.on_train_start" in timer.recorded_durations.keys()
)

_run_callback_fn(
[callback], "on_train_epoch_start", dummy_train_state, train_unit
)
self.assertEqual(callback.dummy_data, "on_train_epoch_start")
self.assertTrue(
"callback.DummyCallback.on_train_epoch_start"
in timer.recorded_durations.keys()
"DummyCallback.on_train_epoch_start" in timer.recorded_durations.keys()
)

_run_callback_fn(
[callback], "on_train_step_start", dummy_train_state, train_unit
)
self.assertEqual(callback.dummy_data, "on_train_step_start")
self.assertTrue(
"callback.DummyCallback.on_train_step_start"
in timer.recorded_durations.keys()
"DummyCallback.on_train_step_start" in timer.recorded_durations.keys()
)

_run_callback_fn([callback], "on_train_step_end", dummy_train_state, train_unit)
self.assertEqual(callback.dummy_data, "on_train_step_end")
self.assertTrue(
"callback.DummyCallback.on_train_step_end"
in timer.recorded_durations.keys()
"DummyCallback.on_train_step_end" in timer.recorded_durations.keys()
)

_run_callback_fn(
[callback], "on_train_epoch_end", dummy_train_state, train_unit
)
self.assertEqual(callback.dummy_data, "on_train_epoch_end")
self.assertTrue(
"callback.DummyCallback.on_train_epoch_end"
in timer.recorded_durations.keys()
"DummyCallback.on_train_epoch_end" in timer.recorded_durations.keys()
)

_run_callback_fn([callback], "on_train_end", dummy_train_state, train_unit)
self.assertEqual(callback.dummy_data, "on_train_end")
-self.assertTrue(
-    "callback.DummyCallback.on_train_end" in timer.recorded_durations.keys()
-)
+self.assertTrue("DummyCallback.on_train_end" in timer.recorded_durations.keys())

def test_run_callback_fn_exception(self) -> None:
"""
@@ -339,18 +330,30 @@ def test_get_current_progress(self) -> None:
progress.num_steps_completed, train_state.progress.num_steps_completed
)

-def test_get_timing_context(self) -> None:
+@patch("torchtnt.framework.utils.record_function")
+def test_get_timing_context(self, mock_record_function) -> None:
state = MagicMock()
state.timer = None

ctx = _get_timing_context(state, "a")
-self.assertEqual(type(ctx), contextlib.nullcontext)
+with ctx:
+    time.sleep(1)
+mock_record_function.assert_called_with("a")

state.timer = Timer()
-ctx = _get_timing_context(state, "a")
+ctx = _get_timing_context(state, "b")
with ctx:
    time.sleep(1)
-self.assertTrue("a" in state.timer.recorded_durations.keys())
+self.assertTrue("b" in state.timer.recorded_durations.keys())
+mock_record_function.assert_called_with("b")
+
+state.timer = Timer()
+ctx = _get_timing_context(state, "c", skip_timer=True)
+with ctx:
+    time.sleep(1)
+# "c" should not be in the recorded_durations because we set skip_timer to True
+self.assertFalse("c" in state.timer.recorded_durations.keys())
+mock_record_function.assert_called_with("c")

def test_find_optimizers_for_module(self) -> None:
module1 = torch.nn.Linear(10, 10)