Skip to content

Commit

Permalink
Move clear and run_pre_save_hooks as internal methonds only.
Browse files Browse the repository at this point in the history
  • Loading branch information
mingxu1067 committed Dec 21, 2021
1 parent 9b9b1bb commit 4677f6e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 52 deletions.
51 changes: 6 additions & 45 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@

__all__ = [
'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
'set_verbosity', 'save', 'load', 'not_to_static', 'register_save_pre_hook',
'clear_save_pre_hooks'
'set_verbosity', 'save', 'load', 'not_to_static', 'register_save_pre_hook'
]


Expand Down Expand Up @@ -567,10 +566,10 @@ def register_save_pre_hook(hook):
IMAGE_SIZE = 256
CLASS_NUM = 10
class LinearNet(nn.Layer):
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
self._linear = paddle.nn.Linear(IMAGE_SIZE, CLASS_NUM)
def forward(self, x):
return self._linear(x)
Expand Down Expand Up @@ -599,45 +598,7 @@ def save_pre_hook(layer, input_spec, configs):
return HookRemoveHelper(hook)


def clear_save_pre_hooks():
"""
Clear all save pre-hooks for `paddle.jit.save`.
Args:
None
Returns:
None
Examples:
.. code-block:: python
import numpy as np
import paddle
IMAGE_SIZE = 256
CLASS_NUM = 10
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
def forward(self, x):
return self._linear(x)
saving_count = 0
def save_pre_hook(layer, input_spec, configs):
global saving_count
saving_count += 1
remove_handler = paddle.jit.register_save_pre_hook(save_pre_hook)
layer = LinearNet()
paddle.jit.clear_save_pre_hooks()
paddle.jit.save(layer, "/tmp", [paddle.static.InputSpec(shape=[-1, IMAGE_SIZE])])
# saving_count == 0
"""
def _clear_save_pre_hooks():
global _save_pre_hooks_lock
global _save_pre_hooks
_save_pre_hooks_lock.acquire()
Expand All @@ -654,7 +615,7 @@ def _remove_save_pre_hook(hook):
_save_pre_hooks_lock.release()


def run_save_pre_hooks(func):
def _run_save_pre_hooks(func):
def wrapper(layer, path, input_spec=None, **configs):
global _save_pre_hooks
for hook in _save_pre_hooks:
Expand All @@ -664,7 +625,7 @@ def wrapper(layer, path, input_spec=None, **configs):
return wrapper


@run_save_pre_hooks
@_run_save_pre_hooks
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

import paddle
from paddle.jit import register_save_pre_hook
from paddle.jit import clear_save_pre_hooks
from paddle.fluid.dygraph.jit import run_save_pre_hooks
from paddle.fluid.dygraph.jit import _run_save_pre_hooks, _clear_save_pre_hooks

_counter = 0

Expand All @@ -46,13 +45,13 @@ def fake_func(*args, **kwgs):
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)

remove_handler = register_save_pre_hook(fake_func)
clear_save_pre_hooks()
_clear_save_pre_hooks()
self.assertEqual(len(paddle.fluid.dygraph.jit._save_pre_hooks), 0)

global _counter
_counter = 0
remove_handler = register_save_pre_hook(fake_func)
func_with_hook = run_save_pre_hooks(fake_func)
func_with_hook = _run_save_pre_hooks(fake_func)
func_with_hook(None, None)
self.assertEqual(_counter, 2)

Expand Down
4 changes: 1 addition & 3 deletions python/paddle/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from ..fluid.dygraph.jit import declarative as to_static # noqa: F401
from ..fluid.dygraph.jit import not_to_static # noqa: F401
from ..fluid.dygraph.jit import register_save_pre_hook # noqa: F401
from ..fluid.dygraph.jit import clear_save_pre_hooks # noqa: F401
from ..fluid.dygraph import ProgramTranslator # noqa: F401
from ..fluid.dygraph.io import TranslatedLayer # noqa: F401

Expand All @@ -39,6 +38,5 @@
'set_code_level',
'set_verbosity',
'not_to_static',
'register_save_pre_hook',
'clear_save_pre_hooks'
'register_save_pre_hook'
]

0 comments on commit 4677f6e

Please sign in to comment.