[cherry-pick] Reduce the performance impact of record events in Python (#42142)

* fix kernel name appearance (#42071)

* Reduce the performance impact of record events in Python (#42040)

* optimize performance

* fix

* improve coverage

* fix

* fix
rainyfly authored Apr 24, 2022
1 parent b543998 commit 338fcc1
Showing 6 changed files with 85 additions and 23 deletions.
25 changes: 15 additions & 10 deletions python/paddle/fluid/dataloader/dataloader_iter.py
@@ -31,6 +31,7 @@
 
 import paddle
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 from .. import core, layers
 from ..framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph
 from ..multiprocess_utils import _set_SIGCHLD_handler, MP_STATUS_CHECK_INTERVAL, CleanupFuncRegistrar
@@ -252,10 +253,11 @@ def _thread_loop(self, legacy_expected_place):
             self._exit_thread_expectedly()
 
     def __next__(self):
-        trace_event = profiler.RecordEvent(
-            name="_DataLoaderIterSingleProcess",
-            event_type=profiler.TracerEventType.Dataloader)
-        trace_event.begin()
+        if in_profiler_mode():
+            trace_event = profiler.RecordEvent(
+                name="_DataLoaderIterSingleProcess",
+                event_type=profiler.TracerEventType.Dataloader)
+            trace_event.begin()
         try:
             benchmark().check_if_need_record(self)
             benchmark().before_reader()
@@ -294,7 +296,8 @@ def __next__(self):
             self._try_shutdown_all()
             six.reraise(*sys.exc_info())
         finally:
-            trace_event.end()
+            if in_profiler_mode():
+                trace_event.end()
 
     def _shutdown_thread(self):
         if self._thread:
@@ -708,10 +711,11 @@ def _shutdown_on_exit(self):
         self._try_shutdown_all(1)
 
     def __next__(self):
-        trace_event = profiler.RecordEvent(
-            name="_DataLoaderIterMultiProcess",
-            event_type=profiler.TracerEventType.Dataloader)
-        trace_event.begin()
+        if in_profiler_mode():
+            trace_event = profiler.RecordEvent(
+                name="_DataLoaderIterMultiProcess",
+                event_type=profiler.TracerEventType.Dataloader)
+            trace_event.begin()
         try:
             benchmark().check_if_need_record(self)
             benchmark().before_reader()
@@ -765,7 +769,8 @@ def __next__(self):
             self._try_shutdown_all()
             six.reraise(*sys.exc_info())
         finally:
-            trace_event.end()
+            if in_profiler_mode():
+                trace_event.end()
 
     # python2 compatibility
     def next(self):
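
Both iterator classes get the same guard: the Dataloader RecordEvent is constructed and begun only when a profiler is active, so un-profiled training no longer pays event-bookkeeping cost on every batch. A minimal standalone sketch of the pattern (the fetch_batch function and event name are illustrative, not the Paddle source):

import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode

def fetch_batch(reader):
    # When no profiler is running, skip RecordEvent construction entirely;
    # this is the per-batch cost the patch removes.
    trace_event = None
    if in_profiler_mode():
        trace_event = profiler.RecordEvent(
            name="fetch_batch",
            event_type=profiler.TracerEventType.Dataloader)
        trace_event.begin()
    try:
        return next(reader)
    finally:
        if trace_event is not None:
            trace_event.end()

Holding the event in a local and checking it for None in the finally block is equivalent to the patch's second in_profiler_mode() call, and also avoids ending an event that was never begun if the profiler stops mid-iteration.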
10 changes: 7 additions & 3 deletions python/paddle/fluid/dygraph/layers.py
@@ -26,6 +26,7 @@
 
 import paddle
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 
 from . import parallel_helper
 from .. import unique_name
@@ -906,8 +907,11 @@ def _dygraph_call_func(self, *inputs, **kwargs):
 
             self._built = True
 
-        with profiler.RecordEvent(self.full_name(),
-                                  profiler.TracerEventType.Forward):
+        if in_profiler_mode():
+            with profiler.RecordEvent(self.full_name(),
+                                      profiler.TracerEventType.Forward):
+                outputs = self.forward(*inputs, **kwargs)
+        else:
             outputs = self.forward(*inputs, **kwargs)
 
         for forward_post_hook in self._forward_post_hooks.values():
@@ -919,7 +923,7 @@ def _dygraph_call_func(self, *inputs, **kwargs):
 
     def __call__(self, *inputs, **kwargs):
         if (not in_declarative_mode()) and (not self._forward_pre_hooks) \
-            and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode():
+            and (not self._forward_post_hooks) and (not self._built) and in_dygraph_mode() and (not in_profiler_mode()):
             self._build_once(*inputs, **kwargs)
             return self.forward(*inputs, **kwargs)
         else:
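
Two things change in this file: the forward pass is wrapped in a Forward RecordEvent only under an active profiler, and the fast path in __call__ now also requires not in_profiler_mode(), so a running profiler forces the full _dygraph_call_func path where the event is recorded. The duplicated forward() call could equally be folded into one call site with contextlib.nullcontext; a sketch of that alternative (not what the commit does):

import contextlib
import paddle.profiler as profiler
from paddle.profiler.utils import in_profiler_mode

def call_forward(layer, *inputs, **kwargs):
    # nullcontext() is a no-op manager, so the un-profiled path stays cheap
    # while keeping a single forward() call site instead of two branches.
    ctx = (profiler.RecordEvent(layer.full_name(),
                                profiler.TracerEventType.Forward)
           if in_profiler_mode() else contextlib.nullcontext())
    with ctx:
        return layer.forward(*inputs, **kwargs)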
11 changes: 7 additions & 4 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -30,6 +30,7 @@
 from paddle.fluid.data_feeder import convert_dtype, _PADDLE_DTYPE_2_NUMPY_DTYPE
 import paddle.utils.deprecated as deprecated
 import paddle.profiler as profiler
+from paddle.profiler.utils import in_profiler_mode
 from paddle import _C_ops
 
 _grad_scalar = None
@@ -247,9 +248,10 @@ def backward(self, grad_tensor=None, retain_graph=False):
         """
         if framework._non_static_mode():
-            record_event = profiler.RecordEvent(
-                "Gradient Backward", profiler.TracerEventType.Backward)
-            record_event.begin()
+            if in_profiler_mode():
+                record_event = profiler.RecordEvent(
+                    "Gradient Backward", profiler.TracerEventType.Backward)
+                record_event.begin()
             if grad_tensor is not None:
                 if framework._in_eager_mode_:
                     assert isinstance(
@@ -288,7 +290,8 @@ def backward(self, grad_tensor=None, retain_graph=False):
                 core.dygraph_run_backward([self], [grad_tensor],
                                           retain_graph,
                                           framework._dygraph_tracer())
-            record_event.end()
+            if in_profiler_mode():
+                record_event.end()
         else:
             raise ValueError(
                 "Variable.backward() is only available in DyGraph mode")
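
With the same guard around backward(), a profiling session still shows the whole backward pass as a single "Gradient Backward" event, while un-profiled training skips the begin()/end() pair. A minimal usage sketch (tensor shapes are illustrative):

import paddle
import paddle.profiler as profiler

x = paddle.randn([4, 8])
x.stop_gradient = False
loss = (x * x).sum()

prof = profiler.Profiler(on_trace_ready=lambda p: None)
prof.start()
loss.backward()  # recorded as a "Gradient Backward" event
prof.step()
prof.stop()
prof.summary()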
36 changes: 36 additions & 0 deletions python/paddle/fluid/tests/unittests/test_newprofiler.py
@@ -134,6 +134,42 @@ def my_sheduler1(num_step):
         prof.export(path='./test_profiler_pb.pb', format='pb')
         prof.summary()
         result = profiler.utils.load_profiler_result('./test_profiler_pb.pb')
+        prof = None
+        dataset = RandomDataset(10 * 4)
+        simple_net = SimpleNet()
+        opt = paddle.optimizer.SGD(learning_rate=1e-3,
+                                   parameters=simple_net.parameters())
+        loader = DataLoader(
+            dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=2)
+        prof = profiler.Profiler(on_trace_ready=lambda prof: None)
+        prof.start()
+        for i, (image, label) in enumerate(loader()):
+            out = simple_net(image)
+            loss = F.cross_entropy(out, label)
+            avg_loss = paddle.mean(loss)
+            avg_loss.backward()
+            opt.minimize(avg_loss)
+            simple_net.clear_gradients()
+            prof.step()
+        prof.stop()
+        prof.summary()
+        prof = None
+        dataset = RandomDataset(10 * 4)
+        simple_net = SimpleNet()
+        loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
+        opt = paddle.optimizer.Adam(
+            learning_rate=1e-3, parameters=simple_net.parameters())
+        prof = profiler.Profiler(on_trace_ready=lambda prof: None)
+        prof.start()
+        for i, (image, label) in enumerate(loader()):
+            out = simple_net(image)
+            loss = F.cross_entropy(out, label)
+            avg_loss = paddle.mean(loss)
+            avg_loss.backward()
+            opt.step()
+            simple_net.clear_gradients()
+            prof.step()
+        prof.stop()
 
 
 class TestNvprof(unittest.TestCase):
8 changes: 5 additions & 3 deletions python/paddle/profiler/profiler_statistic.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import collections
 from enum import Enum
+import re
 
 from paddle.fluid.core import TracerEventType
 
@@ -1317,10 +1318,11 @@ def format_ratio(ratio, indent=0):
     append(header_sep)
     append(row_format.format(*headers))
     append(header_sep)
+    kernel_name_pattern = re.compile('(.+?)(<.*>)(\(.*\))')
     for row_values in all_row_values:
-        indx = row_values[0].find('(')
-        if indx != -1:
-            name = row_values[0][:indx]
+        match = kernel_name_pattern.match(row_values[0])
+        if match:
+            name = match.group(1) + match.group(2)
         else:
             name = row_values[0]
         if len(name) > name_column_width:
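
The statistics printer previously truncated kernel names at the first '(', which mangled templated GPU kernels whose template arguments themselves contain parentheses (the "fix kernel name appearance" commit above). The new pattern keeps the base name plus template arguments and drops only the trailing C++ parameter list. A quick illustration (the kernel name below is made up):

import re

# Same pattern as the patch: non-greedy base name, template arguments,
# then the C++ parameter list.
kernel_name_pattern = re.compile(r'(.+?)(<.*>)(\(.*\))')

full_name = 'void gemm_kernel<float, (int)256>(float const*, float*, int)'
match = kernel_name_pattern.match(full_name)
if match:
    name = match.group(1) + match.group(2)  # keep name + template args
else:
    name = full_name  # non-templated names pass through unchanged

print(name)                      # void gemm_kernel<float, (int)256>
print(full_name.split('(')[0])   # old behavior: 'void gemm_kernel<float, '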
18 changes: 15 additions & 3 deletions python/paddle/profiler/utils.py
@@ -21,6 +21,7 @@
 from paddle.fluid.core import (_RecordEvent, TracerEventType)
 
 _is_profiler_used = False
+_has_optimizer_wrapped = False
 
 _AllowedEventTypeList = [
     TracerEventType.Dataloader, TracerEventType.ProfileStep,
@@ -154,20 +155,31 @@ def load_profiler_result(filename: str):
     return core.load_profiler_result(filename)
 
 
+def in_profiler_mode():
+    return _is_profiler_used == True
+
+
 def wrap_optimizers():
     def optimizer_warpper(func):
         @functools.wraps(func)
         def warpper(*args, **kwargs):
-            with RecordEvent(
-                    'Optimization Step',
-                    event_type=TracerEventType.Optimization):
+            if in_profiler_mode():
+                with RecordEvent(
+                        'Optimization Step',
+                        event_type=TracerEventType.Optimization):
+                    return func(*args, **kwargs)
+            else:
                 return func(*args, **kwargs)
 
         return warpper
 
+    global _has_optimizer_wrapped
+    if _has_optimizer_wrapped == True:
+        return
     import paddle.optimizer as optimizer
     for classname in optimizer.__all__:
         if classname != 'Optimizer':
            classobject = getattr(optimizer, classname)
            if getattr(classobject, 'step', None) != None:
                classobject.step = optimizer_warpper(classobject.step)
+    _has_optimizer_wrapped = True
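
wrap_optimizers() monkey-patches the step method of every class in paddle.optimizer, and the new module-level _has_optimizer_wrapped flag makes that patching idempotent: starting a second Profiler no longer stacks another wrapper (and another RecordEvent) on top of the first. A stripped-down sketch of the same guard with a stand-in optimizer class (names are illustrative, not Paddle's):

import functools

_has_wrapped = False

class FakeOptimizer:
    def step(self):
        return "stepped"

def wrap_step_once(cls):
    # Module-global, idempotent patching, mirroring _has_optimizer_wrapped.
    global _has_wrapped
    if _has_wrapped:
        return
    original = cls.step

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        # The real wrapper opens an 'Optimization Step' RecordEvent here
        # when in_profiler_mode() is true.
        return original(self, *args, **kwargs)

    cls.step = wrapper
    _has_wrapped = True

wrap_step_once(FakeOptimizer)
wrap_step_once(FakeOptimizer)  # no-op thanks to the guard
assert FakeOptimizer().step() == "stepped"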
