From 4677f6e62fc82892bc57092a61a87c41c8942321 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Tue, 21 Dec 2021 05:23:14 +0000 Subject: [PATCH] Move clear and run_pre_save_hooks as internal methonds only. --- python/paddle/fluid/dygraph/jit.py | 51 +++---------------- .../unittests/test_jit_pre_save_hooks.py | 7 ++- python/paddle/jit/__init__.py | 4 +- 3 files changed, 10 insertions(+), 52 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 09bc1833c060c..0bbc86c165be5 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -44,8 +44,7 @@ __all__ = [ 'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level', - 'set_verbosity', 'save', 'load', 'not_to_static', 'register_save_pre_hook', - 'clear_save_pre_hooks' + 'set_verbosity', 'save', 'load', 'not_to_static', 'register_save_pre_hook' ] @@ -567,10 +566,10 @@ def register_save_pre_hook(hook): IMAGE_SIZE = 256 CLASS_NUM = 10 - class LinearNet(nn.Layer): + class LinearNet(paddle.nn.Layer): def __init__(self): super(LinearNet, self).__init__() - self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM) def forward(self, x): return self._linear(x) @@ -599,45 +598,7 @@ def save_pre_hook(layer, input_spec, configs): return HookRemoveHelper(hook) -def clear_save_pre_hooks(): - """ - Clear all save pre-hooks for `paddle.jit.save`. - - Args: - None - - Returns: - None - - Examples: - .. code-block:: python - - import numpy as np - import paddle - - IMAGE_SIZE = 256 - CLASS_NUM = 10 - - class LinearNet(nn.Layer): - def __init__(self): - super(LinearNet, self).__init__() - self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) - - def forward(self, x): - return self._linear(x) - - saving_count = 0 - def save_pre_hook(layer, input_spec, configs): - global saving_count - saving_count += 1 - - remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook) - - layer = LinearNet() - paddle.jit.clear_save_pre_hooks() - paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])]) - # saving_count == 0 - """ +def _clear_save_pre_hooks(): global _save_pre_hooks_lock global _save_pre_hooks _save_pre_hooks_lock.acquire() @@ -654,7 +615,7 @@ def _remove_save_pre_hook(hook): _save_pre_hooks_lock.release() -def run_save_pre_hooks(func): +def _run_save_pre_hooks(func): def wrapper(layer, path, input_spec=None, **configs): global _save_pre_hooks for hook in _save_pre_hooks: @@ -664,7 +625,7 @@ def wrapper(layer, path, input_spec=None, **configs): return wrapper -@run_save_pre_hooks +@_run_save_pre_hooks @switch_to_static_graph def save(layer, path, input_spec=None, **configs): """ diff --git a/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py index 071b3c8927e9a..a7a75fcc37a1b 100644 --- a/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py +++ b/python/paddle/fluid/tests/unittests/test_jit_pre_save_hooks.py @@ -19,8 +19,7 @@ import paddle from paddle.jit import register_save_pre_hook -from paddle.jit import clear_save_pre_hooks -from paddle.fluid.dygraph.jit import run_save_pre_hooks +from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks _counter = 0 @@ -46,13 +45,13 @@ def fake_func(*args, **kwgs): self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) remove_handler = register_save_pre_hook(fake_func) - clear_save_pre_hooks() + _clear_save_pre_hooks() self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0) global _counter _counter = 0 remove_handler = register_save_pre_hook(fake_func) - func_with_hook = run_save_pre_hooks(fake_func) + func_with_hook = _run_save_pre_hooks(fake_func) func_with_hook(None, None) self.assertEqual(_counter, 2) diff --git a/python/paddle/jit/__init__.py b/python/paddle/jit/__init__.py index ab2b5b5a5feca..2b4d1a1dac2d9 100644 --- a/python/paddle/jit/__init__.py +++ b/python/paddle/jit/__init__.py @@ -23,7 +23,6 @@ from ..fluid.dygraph.jit import declarative as to_static # noqa: F401 from ..fluid.dygraph.jit import not_to_static # noqa: F401 from ..fluid.dygraph.jit import register_save_pre_hook # noqa: F401 -from ..fluid.dygraph.jit import clear_save_pre_hooks # noqa: F401 from ..fluid.dygraph import ProgramTranslator # noqa: F401 from ..fluid.dygraph.io import TranslatedLayer # noqa: F401 @@ -39,6 +38,5 @@ 'set_code_level', 'set_verbosity', 'not_to_static', - 'register_save_pre_hook', - 'clear_save_pre_hooks' + 'register_save_pre_hook' ]