Jit pre save hook #38186

Merged: 11 commits, Jan 11, 2022

Collaborator

@mingxu1067 mingxu1067 commented Dec 16, 2021

PR types

New features

PR changes

APIs

Describe

Pre-saving Hooks in paddle.jit.save

  • Added a pre-save hook mechanism that lets users and developers register functions to preprocess an nn.Layer before jit.save runs.
  • Hook registration is thread-safe, and each function is registered at most once (no duplicate hooks).
  • Added three new APIs
  1. register_save_pre_hook:

Registers a function in _save_pre_hooks.
Args:

  • func (Python function): The function to register as a hook; it is executed before paddle.jit.save.

Return:

  • HookRemoveHelper: A handle used to remove the registered hook.
  2. clear_save_pre_hooks:

Clear all hooks.

  • Registered functions should follow this interface:

def func(layer, input_spec, configs)

  • layer (nn.Layer): The layer passed to jit.save.
  • input_spec (list of InputSpec): The input_spec passed to jit.save.
  • configs (dict): The kwargs passed to jit.save, which would be collected into configs.
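
As an illustration of this interface, a hook might look like the minimal sketch below. The hook name switch_to_eval_hook and the stand-in FakeLayer class are assumptions made for this example and are not part of the PR; a real hook would receive the nn.Layer that was passed to jit.save.

```python
def switch_to_eval_hook(layer, input_spec, configs):
    """Hypothetical pre-save hook: put the layer in eval mode before export."""
    # The hook receives the layer, the input_spec, and the collected kwargs
    # that were passed to jit.save. This sketch only touches the layer.
    layer.eval()


class FakeLayer:
    """Stand-in for nn.Layer so the sketch runs without Paddle installed."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False


fake_layer = FakeLayer()
switch_to_eval_hook(fake_layer, input_spec=None, configs={})
print(fake_layer.training)  # False
```

With Paddle installed, the same function could presumably be passed to register_save_pre_hook unchanged, since it only relies on the three-argument interface described above.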

Usage Example:

  _counter = 0
  def fake_func(layer, input_spec, configs):
      global _counter
      _counter += 1

  remove_handler = register_save_pre_hook(fake_func)
  # len(paddle.fluid.dygraph.jit._save_pre_hooks) == 1
  
  remove_handler = register_save_pre_hook(fake_func)
  # len(paddle.fluid.dygraph.jit._save_pre_hooks) == 1
  # The duplicate registration is ignored, so the hook is stored only once.
  
  paddle.jit.save(my_layer)
  # _counter == 1
  
  paddle.jit.save(my_layer)
  # _counter == 2

  remove_handler.remove()
  # len(paddle.fluid.dygraph.jit._save_pre_hooks) == 0

  paddle.jit.save(my_layer)
  # _counter == 2

  remove_handler = register_save_pre_hook(fake_func)
  clear_save_pre_hooks()
  # len(paddle.fluid.dygraph.jit._save_pre_hooks) == 0

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See Paddle CI Manual for details.

remove_save_pre_hook(self._hook)


def register_save_pre_hook(hook):
Contributor

We should add API docs here, as is done for the other APIs.

Collaborator Author

Done, added API docs to the public functions.

return HookRemoveHelper(hook)


def remove_save_pre_hook(hook):
Contributor

If remove_save_pre_hook is not meant to be a public API, we recommend naming it with a leading underscore: _remove_save_pre_hook.

Collaborator Author

Done, renamed remove_save_pre_hook to _remove_save_pre_hook.

_save_pre_hooks_lock.release()


def clear_save_pre_hooks():
Contributor

In what cases is clear_save_pre_hooks used? Isn't HookRemoveHelper.remove enough?

Collaborator Author

Not currently, but it is a convenient API that lets users clear all hooks in one call.

Contributor

If there is no clear usage scenario yet, we recommend keeping it as an internal method first and promoting it to a public API later if necessary.

Collaborator Author

Done, renamed clear_save_pre_hooks with a leading underscore and made it internal-only for now.

_save_pre_hooks_lock.release()


def run_save_pre_hooks(func):
Contributor

Same as for _remove_save_pre_hook.

Collaborator Author

Done

@paddle-bot-old

Sorry to inform you that it has been more than 7 days since the CIs for 4677f6e passed. To prevent PR conflicts, you need to re-run all CIs manually.

Contributor

@chenwhql chenwhql left a comment

LGTM

@YuanRisheng YuanRisheng merged commit e91f7c0 into PaddlePaddle:develop Jan 11, 2022
@mingxu1067 mingxu1067 deleted the jit_pre_save_hook branch January 12, 2022 01:51