Commit efcd761

tchaton, carmocca, and Borda committed
[bug] Fix PyTorch profiler with emit_nvtx (#6260)
* resolve bug
* update changelog
* Update tests/trainer/test_trainer.py
* Update pytorch_lightning/profiler/profilers.py
  Co-authored-by: Jirka Borovec <[email protected]>
* resolve comments
* resolve flake8

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 2c99d70 commit efcd761

File tree

5 files changed: +337 −6 lines changed


Diff for: CHANGELOG.md (+6)
@@ -25,6 +25,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

+- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
+
+
+- Fixed `trainer.test` from `best_path` hanging after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+
+
 ## [1.2.2] - 2021-03-02

 ### Added
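
For context on the first entry, here is a minimal sketch of how the fixed feature is driven. This is illustrative only: `MyModel` stands in for any `LightningModule`, and the import assumes the package-level `pytorch_lightning.profiler` export.

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import PyTorchProfiler

    # emit_nvtx needs a CUDA device and is meant to run under nvprof, e.g.:
    #   nvprof --profile-from-start off -o trace_name.prof -- python train.py
    profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
    trainer = Trainer(profiler=profiler, gpus=1, fast_dev_run=True)
    trainer.fit(MyModel())  # MyModel is a placeholder LightningModule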

Diff for: pytorch_lightning/profiler/profilers.py (−1)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Profiler to check if there are any bottlenecks in your code."""
-
 import cProfile
 import inspect
 import io

Diff for: pytorch_lightning/profiler/pytorch.py (+303, new file)
@@ -0,0 +1,303 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""

import inspect
import logging
import os
from typing import List, Optional

import torch

from pytorch_lightning.profiler.profilers import BaseProfiler
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


class PyTorchProfiler(BaseProfiler):

    PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step")
    AVAILABLE_SORT_KEYS = (
        "cpu_time",
        "cuda_time",
        "cpu_time_total",
        "cuda_time_total",
        "cpu_memory_usage",
        "cuda_memory_usage",
        "self_cpu_memory_usage",
        "self_cuda_memory_usage",
        "count",
    )

    def __init__(
        self,
        output_filename: Optional[str] = None,
        enabled: bool = True,
        use_cuda: bool = False,
        record_shapes: bool = False,
        profile_memory: bool = False,
        group_by_input_shapes: bool = False,
        with_stack: bool = False,
        use_kineto: bool = False,
        use_cpu: bool = True,
        emit_nvtx: bool = False,
        export_to_chrome: bool = False,
        path_to_export_trace: str = None,
        row_limit: int = 20,
        sort_by_key: Optional[str] = None,
        profiled_functions: Optional[List] = None,
        local_rank: Optional[int] = None,
    ):
        """
        This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
        different operators inside your model - both on the CPU and GPU.

        Args:
            output_filename: optionally save profile results to a file instead of printing
                to std out when training is finished. When using ``ddp``,
                each rank will stream the profiled operations to its own file
                with the extension ``_{rank}.txt``

            enabled: Setting this to ``False`` makes this context manager a no-op.

            use_cuda: Enables timing of CUDA events as well, using the cudaEvent API.
                Adds approximately 4us of overhead to each tensor operation.

            record_shapes: If shapes recording is set, information about input dimensions will be collected.

            profile_memory: Whether to report memory usage (Introduced in PyTorch 1.6.0)

            group_by_input_shapes: Include operator input shapes and group calls by shape.

            with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0)

            use_kineto: experimental support for the Kineto profiler (Introduced in PyTorch 1.8.0)

            use_cpu: used together with ``use_kineto=True``, it can be set to ``False`` to lower
                the overhead for GPU-only profiling (Introduced in PyTorch 1.8.0)

            emit_nvtx: Context manager that makes every autograd operation emit an NVTX range.
                Run::

                    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

                To visualize, you can either use::

                    nvvp trace_name.prof
                    torch.autograd.profiler.load_nvprof(path)

            export_to_chrome: Whether to export the sequence of profiled operators for Chrome.
                It will generate a ``.json`` file which can be read by Chrome.

            path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``.
                By default, they will be saved in the directory the script is run from.

            row_limit: Limit the number of rows in a table, ``0`` is a special value that
                removes the limit completely.

            sort_by_key: Key used to sort the profiled table.

            profiled_functions: list of functions to wrap in a profiling context manager.
                Any other function is passed through unprofiled.

            local_rank: When running in a distributed setting, ``local_rank`` is used for each process
                to write to its own file if ``output_filename`` is provided.

        Raises:
            MisconfigurationException:
                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or
                if the log file is not a ``.txt`` file.
            ValueError:
                If you attempt to stop recording an action which was never started.
        """
        self.profiled_actions = {}
        self.enabled = enabled
        self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
        self.use_cuda = use_cuda
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
        self.with_stack = with_stack
        self.group_by_input_shapes = group_by_input_shapes and record_shapes
        self.use_kineto = use_kineto
        self.use_cpu = use_cpu
        self.row_limit = row_limit
        self.emit_nvtx = emit_nvtx
        self.export_to_chrome = export_to_chrome
        self.path_to_export_trace = path_to_export_trace

        if export_to_chrome and path_to_export_trace is None:
            rank_zero_warn(
                "The exported trace will be saved locally as `path_to_export_trace` is empty."
                " Note: Each function will generate its own traced file."
            )

        if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
            raise MisconfigurationException(
                f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
            )

        self.profiled_actions = {}
        self.context_names = {}
        self.running_stack = []
        self.profiler = None

        self.output_fname = output_filename
        self.output_file = None
        if local_rank is not None:
            self.on_train_start(local_rank=local_rank)
            self.on_train_start = super().on_train_start

    def on_train_start(self, local_rank: Optional[str] = None):
        self.local_rank = local_rank

        # when logging to `log.info`, only perform profiling on rank 0
        if local_rank != 0 and self.output_fname is None:
            self.wrap_functions_into_rank_zero_only()

        if self.output_fname:
            if local_rank is not None:
                if '.txt' not in self.output_fname:
                    raise MisconfigurationException("Log file should be .txt file.")

                self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt")

            fs = get_filesystem(self.output_fname)
            self.output_file = fs.open(self.output_fname, "w")

        streaming_out = [self.output_file.write] if self.output_file else [log.info]
        super().__init__(output_streams=streaming_out)

    def wrap_functions_into_rank_zero_only(self):
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)
        self.describe = rank_zero_only(self.describe)

    def start(self, action_name: str) -> None:
        if action_name not in self.profiled_functions:
            return

        if len(self.running_stack) > 0:
            self._stop(self.running_stack[-1])
        self.running_stack.append(action_name)

        self.context_names[action_name] = "/".join(self.running_stack)

        self._start(action_name)

    def _start(self, action_name: str) -> None:
        if self.emit_nvtx:
            self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True)
            self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
        else:
            self._create_profiler(action_name, torch.autograd.profiler.profile)

    def _create_profiler(self, action_name, profiler, enter=True):
        init_args = inspect.signature(profiler.__init__).parameters
        profiler_args = {k: v for k, v in vars(self).items() if k in init_args}
        pr = profiler(**profiler_args)
        if enter:
            out_pr = pr.__enter__()
            if out_pr is not None:
                pr = out_pr
        self.profiler = pr
        return self.profiler

    def _stop(self, action_name: str) -> None:
        if self.profiler is None:
            return

        self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None)

        if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx):
            # when running ``emit_nvtx``, PyTorch requires two context managers,
            # so the parent profiler needs to be closed too.
            self._parent_profiler.__exit__(None, None, None)
            return

        function_events = self.profiler.function_events
        self.profiler = None
        for name in self.running_stack:
            if name not in self.profiled_actions:
                self.profiled_actions[name] = function_events
            else:
                self.profiled_actions[name] += function_events

    def stop(self, action_name: str) -> None:
        if action_name not in self.profiled_functions:
            return

        if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
            raise ValueError(  # pragma: no-cover
                f"Attempting to stop recording an action ({action_name}) which was never started."
            )
        self._stop(action_name)
        self.running_stack.pop()
        # restore running profiler
        if len(self.running_stack) > 0:
            self._start(self.running_stack[-1])

    def summary(self) -> str:
        recorded_stats = {}
        output_string = ''
        local_rank = '0' if self.local_rank is None else self.local_rank

        if not self.enabled:
            return output_string

        for action_name, function_events in self.profiled_actions.items():

            # next line is a workaround for a pytorch issue (fixed on master, still present
            # on 1.7). Without it the code fails with `AssertionError: There is already a CPU
            # parent event for detach`
            function_events.populate_cpu_children = lambda: None

            if self.export_to_chrome:
                filename = f"{action_name}_{local_rank}_trace.json"
                path_to_trace = filename if self.path_to_export_trace is None \
                    else os.path.join(self.path_to_export_trace, filename)
                function_events.export_chrome_trace(path_to_trace)

            if self.emit_nvtx:
                return output_string

            else:
                data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes)
                table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit)
                recorded_stats[action_name] = table

        # log to standard out
        output_string = f"{os.linesep}Profiler Report{os.linesep}"
        for action, stats in recorded_stats.items():
            output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")

        return output_string

    def describe(self):
        """Logs a profile report after the conclusion of the training run."""
        super().describe()
        if self.output_file:
            self.output_file.flush()

    def __del__(self):
        """Close profiler's stream."""
        if self.output_file:
            self.output_file.close()
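
The `_start`/`_stop` pair above mirrors the nesting that plain PyTorch requires for NVTX profiling: `torch.autograd.profiler.emit_nvtx` only annotates autograd ops with NVTX ranges, while `torch.cuda.profiler.profile` is what actually starts and stops the CUDA profiler, which is why `_parent_profiler` must be entered first and exited last. Outside Lightning, the equivalent pattern (run under `nvprof --profile-from-start off`) looks roughly like this sketch with a toy model:

    import torch

    model = torch.nn.Linear(64, 64).cuda()  # toy module, for illustration only
    x = torch.randn(8, 64, device="cuda")

    with torch.cuda.profiler.profile():            # cudaProfilerStart / cudaProfilerStop
        with torch.autograd.profiler.emit_nvtx():  # each autograd op emits an NVTX range
            model(x).sum().backward()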

Diff for: tests/special_tests.sh (+1)

@@ -35,3 +35,4 @@ python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_
 python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
 python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
 python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model
+nvprof --profile-from-start off -o trace_name.prof -- python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_nested_emit_nvtx
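
Once the nvprof-wrapped test above has written `trace_name.prof`, the trace can be inspected from Python as well as with `nvvp`. A short sketch, assuming the trace file sits in the working directory:

    import torch

    # parse the autograd annotations recorded in the nvprof trace
    events = torch.autograd.profiler.load_nvprof("trace_name.prof")
    print(events)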

Diff for: tests/trainer/test_trainer.py (+27 −5)
@@ -220,8 +220,14 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch
 @pytest.mark.parametrize(
     ["accumulate_grad_batches", "limit_train_batches"],
     [
-        ({1: 2, 3: 4}, 1.0),
-        ({1: 2, 3: 4}, 0.5),  # not to be divisible by accumulate_grad_batches on purpose
+        ({
+            1: 2,
+            3: 4
+        }, 1.0),
+        ({
+            1: 2,
+            3: 4
+        }, 0.5),  # not to be divisible by accumulate_grad_batches on purpose
         (3, 1.0),
         (3, 0.8),  # not to be divisible by accumulate_grad_batches on purpose
         (4, 1.0),

@@ -239,9 +245,7 @@ def on_batch_start(self, *_):
     def on_batch_end(self, outputs, batch, batch_idx, *_):
         self.on_train_batch_start_end_dict = self.state_dict()
         for key in self.on_train_batch_start_end_dict.keys():
-            equal = torch.equal(
-                self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key]
-            )
+            equal = torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
             if (batch_idx + 1) == self.trainer.num_training_batches:
                 assert equal
             else:

@@ -1587,6 +1591,22 @@ def test_pytorch_profiler_nested(tmpdir):
         assert pa[n] == expected_[n]


+@RunIf(min_gpus=1, special=True)
+def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
+    """
+    This test checks that emit_nvtx is correctly supported
+    """
+    profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        profiler=profiler,
+        gpus=1,
+    )
+    trainer.fit(model)
+
+
 @pytest.mark.parametrize(
     ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
     [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],

@@ -1738,6 +1758,7 @@ def test_train_loop_system(tmpdir):
     )

     class TestOptimizer(SGD):
+
         def step(self, *args, **kwargs):
             called_methods.append("step")
             return super().step(*args, **kwargs)

@@ -1747,6 +1768,7 @@ def zero_grad(self, *args, **kwargs):
         return super().zero_grad(*args, **kwargs)

     class TestModel(BoringModel):
+
         def configure_optimizers(self):
             return TestOptimizer(self.parameters(), lr=0.1)

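A closing note on `_create_profiler` in pytorch_lightning/profiler/pytorch.py above: it forwards only those attributes of the profiler instance that the target context manager's `__init__` accepts, which is how one call site can construct `torch.autograd.profiler.profile`, `torch.autograd.profiler.emit_nvtx`, and `torch.cuda.profiler.profile` alike. A standalone sketch of that filtering, using a hypothetical stand-in class:

    import inspect

    class FakeProfiler:
        """Hypothetical stand-in for e.g. torch.autograd.profiler.profile."""

        def __init__(self, use_cuda=False, record_shapes=False):
            self.use_cuda = use_cuda
            self.record_shapes = record_shapes

    settings = {"use_cuda": True, "record_shapes": True, "emit_nvtx": True}
    # keep only the keys FakeProfiler.__init__ accepts; emit_nvtx is dropped
    init_args = inspect.signature(FakeProfiler.__init__).parameters
    prof = FakeProfiler(**{k: v for k, v in settings.items() if k in init_args})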