[Dy2Stat] Fix losing pre/post hooks from the outermost layer during jit.save #42273

Merged · 3 commits · Apr 29, 2022
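With this change, `paddle.jit.save` also records forward pre/post hooks registered on the outermost layer instead of silently dropping them from the saved inference program. A minimal usage sketch (`SimpleNet` stands for any `paddle.nn.Layer` subclass; the hook functions are illustrative):

import paddle

def scale_input(layer, inputs):
    # forward pre-hook: may replace the inputs by returning a tuple
    return (inputs[0] * 2, )

def double_output(layer, inputs, output):
    # forward post-hook: may replace the output
    return output + output

net = SimpleNet()
net.register_forward_pre_hook(scale_input)
net.register_forward_post_hook(double_output)
net = paddle.jit.to_static(net)
out = net(paddle.randn([4, 10]))
paddle.jit.save(net, "./net_hook")  # both hooks are now traced into the saved program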
94 changes: 83 additions & 11 deletions python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -196,10 +196,11 @@ def from_func_and_args(cls, function_spec, args, kwargs, class_instance):

def __hash__(self):
error_msg = "Arguments to a `@paddle.jit.to_static` decorated function must be hashable Python objects (or nested structures of these types)."
with_hook = self.kwargs.get("with_hook", False)
return hash((id(self.function_spec),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self._spec_names_id, self.class_instance, with_hook))

def __eq__(self, other):
return (type(self) is type(other)) and hash(self) == hash(other)
@@ -413,6 +414,8 @@ def get_concrete_program(self, *args, **kwargs):
Traced ConcreteProgram and executable translated Layer.
"""

with_hook = kwargs.pop("with_hook", False)
# 1. unify args/kwargs and replace Tensor with InputSpec
if len(args) != len(self._function_spec.args_name):
args, kwargs = self._function_spec.unified_args_and_kwargs(args,
@@ -421,9 +424,13 @@
args, kwargs)

# 2. generate cache key
cache_key = CacheKey(
self._function_spec,
input_args_with_spec,
input_kwargs_with_spec,
self._class_instance,
**self._kwargs,
with_hook=with_hook)

# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
@@ -480,11 +487,13 @@ def foo(x, y):
"""
return self.concrete_program_specify_input_spec(input_spec=None)

def concrete_program_specify_input_spec(self,
input_spec=None,
with_hook=False):
"""
Returns the recent ConcreteProgram instance of the decorated function
built with the given input_spec. If self._function_spec already has an
input_spec, the compatibility of the two will be checked. If the given
input_spec is None, this method uses self._function_spec.input_spec.

@@ -516,12 +525,18 @@ def concrete_program_specify_input_spec(self, input_spec=None):
has_input_spec = (desired_input_spec is not None)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(
*desired_input_spec, with_hook=with_hook)
return concrete_program
else:
raise ValueError(
"No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".
format(self._function_spec))
elif with_hook:
cache_key = self._program_cache._recent_cache_key
cache_key.kwargs["with_hook"] = True
concrete_program, _ = self._program_cache[cache_key]
return concrete_program

# If more than one program has been cached, return the most recently converted one by default.
elif cached_program_len > 1:
logging_utils.warn(
@@ -588,6 +603,54 @@ def _verify_init_in_dynamic_mode(class_instance):
class_instance))


class HookHelper(object):
"""
Only used to apply pre/post hooks of the outermost layer while running
jit.save, because hooks registered on sublayers have already been
processed automatically.
"""

def __init__(self, func, class_instance, with_hook=False):
self.func = func
self.class_instance = class_instance
self.with_hook = with_hook
self.need_apply_hook = with_hook and isinstance(
self.class_instance, layers.Layer) and func.__name__ == "forward"

def apply_pre_hooks(self, inputs):
"""
Apply _forward_pre_hooks from the outermost layer.
"""
if not self.need_apply_hook: return inputs

inputs = inputs[1:]
for forward_pre_hook in self.class_instance._forward_pre_hooks.values():
hook_result = forward_pre_hook(self.class_instance, inputs)
if hook_result is not None:
if not isinstance(hook_result, tuple):
hook_result = (hook_result, )
inputs = hook_result

return [self.class_instance] + list(inputs)

def apply_post_hooks(self, inputs, outputs):
"""
Apply _forward_post_hooks from the outermost layer.
"""
if not self.need_apply_hook: return outputs

inputs = inputs[1:]
for forward_post_hook in self.class_instance._forward_post_hooks.values():
hook_result = forward_post_hook(self.class_instance, inputs,
outputs)
if hook_result is not None:
outputs = hook_result

inputs.insert(0, self.class_instance)
return outputs
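# A rough sketch of how this helper is driven during tracing (mirroring
# from_func_spec below; `net` is an illustrative paddle.nn.Layer instance
# with hooks registered and `x` an illustrative Paddle tensor):
#
#   hook_helper = HookHelper(SimpleNet.forward, net, with_hook=True)
#   inputs = hook_helper.apply_pre_hooks([net, x])   # runs net._forward_pre_hooks on x
#   outputs = SimpleNet.forward(*inputs)             # inputs[0] is the layer itself
#   outputs = hook_helper.apply_post_hooks(inputs, outputs)  # runs net._forward_post_hooks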


class ConcreteProgram(object):

__slots__ = [
@@ -629,6 +692,9 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
# Transforms dygraph function into static function and caches it.
dygraph_function = func_spec.dygraph_function
static_func = convert_to_static(dygraph_function)
# apply pre/post hooks for the outermost layer
hook_helper = HookHelper(dygraph_function, class_instance,
kwargs.get("with_hook", False))

main_program, startup_program = framework.Program(), framework.Program()
# Note: The random seed should be synchronized into cached program
@@ -642,12 +708,13 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(is_declarative=True):
# 1. Adds `fluid.data` layers for input if needed
static_inputs = func_spec.to_static_inputs_with_spec(
input_spec, main_program)
_kwargs = func_spec.to_static_inputs_with_spec(
input_kwargs_spec, main_program)
if class_instance:
static_inputs = tuple([class_instance] + list(
static_inputs))

# 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = _extract_indeed_params_buffers(
@@ -658,10 +725,13 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
class_instance, False)), param_guard(
get_buffers(class_instance, False)):
try:
# only takes effect for jit.save; does nothing during train and eval
inputs = hook_helper.apply_pre_hooks(static_inputs)
if _kwargs:
outputs = static_func(*inputs, **_kwargs)
else:
outputs = static_func(*inputs)
outputs = hook_helper.apply_post_hooks(inputs, outputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
error.attach_error_data(e)
@@ -679,7 +749,7 @@ def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
main_program = update_op_callstack_with_origin_info(main_program)

return ConcreteProgram(
inputs=static_inputs,
outputs=outputs,
parameters=all_parameters_and_buffers,
function=dygraph_function,
@@ -709,6 +779,7 @@ def __init__(self):
self._caches = collections.OrderedDict()
# trace the most recently used program
self._recent_key = None
self._recent_cache_key = None

def _build_once(self, cache_key):
concrete_program = ConcreteProgram.from_func_spec(
@@ -724,6 +795,7 @@ def __getitem__(self, item):
raise ValueError('type(item) should be CacheKey, but received %s' %
type_name(item))
item_id = hash(item)
self._recent_cache_key = item
self._recent_key = item_id
if item_id not in self._caches:
self._caches[item_id] = self._build_once(item)
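Because `with_hook` is now folded into `CacheKey`'s hash (see the `__hash__` change above), tracing the same function with and without outermost-layer hooks produces two distinct cache entries instead of a stale hit. A hedged illustration, where `static_func` is a decorated StaticFunction and `spec` its list of InputSpecs (both illustrative):

prog_plain, _ = static_func.get_concrete_program(*spec)  # with_hook defaults to False
prog_hooked, _ = static_func.get_concrete_program(*spec, with_hook=True)
assert prog_plain is not prog_hooked  # different cache keys, different programs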
25 changes: 20 additions & 5 deletions python/paddle/fluid/dygraph/jit.py
@@ -302,6 +302,7 @@ def __init__(self):

# If True, it will save the inference program only, without the params of the Program
self._program_only = False
self.with_hook = False

@property
def output_spec(self):
@@ -370,7 +371,7 @@ def keep_name_table(self, value):


def _parse_save_configs(configs):
supported_configs = ['output_spec', 'with_hook']

# input check
for key in configs:
@@ -382,6 +383,7 @@ def _parse_save_configs(configs):
# construct inner config
inner_config = _SaveLoadConfig()
inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False)

return inner_config

Expand Down Expand Up @@ -454,11 +456,15 @@ def _get_input_var_names(inputs, input_spec):
return result_list


def _get_output_vars(outputs, output_spec, with_hook=False):
name_no_exists_error = "The tensor `%s` does not exist. " \
"Please make sure the name of example Tensor " \
"in configs.output_spec is the output tensor of " \
"Layer.forward method."
if output_spec and with_hook:
raise RuntimeError(
"Currently not support specify output_spec while founding pre/post hooks in your outermost layer."
)
result_list = []
output_vars_dict = OrderedDict()
for var in flatten(outputs):
@@ -830,10 +836,16 @@ def fun(inputs):

# parse configs
configs = _parse_save_configs(configs)
# Whether the outermost layer has pre/post hooks; if it does, these hook
# operators must be saved into the program as well.
with_hook = configs.with_hook

scope = core.Scope()
extra_var_info = dict()
if isinstance(layer, Layer):
functions = dir(inner_layer)
if inner_layer._forward_pre_hooks or inner_layer._forward_post_hooks:
with_hook = True
Contributor: We can discuss later whether the `with_hook` parameter is really necessary.

Contributor Author: OK.

else:
# layer is function
functions = [layer, ]
@@ -842,7 +854,7 @@ def fun(inputs):
static_func = getattr(inner_layer, attr_func, None)
if isinstance(static_func, StaticFunction):
concrete_program = static_func.concrete_program_specify_input_spec(
inner_input_spec, with_hook=with_hook)
elif 'forward' == attr_func:
# transform in jit.save, if input_spec is incomplete, declarative will throw error
# inner_input_spec is list[InputSpec], it should be packed with same structure
@@ -852,7 +864,8 @@ def fun(inputs):
inner_input_spec)
static_forward = declarative(
inner_layer.forward, input_spec=inner_input_spec)
concrete_program = static_forward.concrete_program_specify_input_spec(
with_hook=with_hook)
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# avoid needless warning
@@ -943,8 +956,10 @@ def fun(inputs):
# the rule is like [ Get input variables name ]. For output var,
# we only support VarBase spec, and actually, we only need the
# var name of the output, and we do not recommend using output_spec
output_vars = _get_output_vars(concrete_program.outputs,
configs.output_spec, with_hook)

# 5. save inference model
from paddle.fluid.io import save_inference_model
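One consequence of the `_get_output_vars` guard above: an explicit `output_spec` cannot be combined with hooks on the outermost layer. A hedged sketch, reusing the `SimpleNet` from the unit test below:

net = paddle.jit.to_static(SimpleNet())  # outermost layer has pre/post hooks registered
out = net(paddle.randn([4, 10]))
paddle.jit.save(net, "./net_hook")  # OK: hooks are detected and saved automatically
paddle.jit.save(net, "./net_hook", output_spec=[out])  # raises RuntimeError per the guard above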
@@ -0,0 +1,90 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle

import numpy as np


def forward_post_hook1(layer, input, output):
return output + output


def forward_pre_hook1(layer, input):
input_return = (input[0] * 2, )
return input_return


class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = paddle.nn.Linear(10, 10)
# sublayer fc1 registers a forward post-hook
self.fc1.register_forward_post_hook(forward_post_hook1)

self.fc2 = paddle.nn.Linear(10, 10)
# sublayer fc2 registers a forward pre-hook
self.fc2.register_forward_pre_hook(forward_pre_hook1)

# register pre/post hooks on the outermost layer
self.register_forward_pre_hook(forward_pre_hook1)
self.register_forward_post_hook(forward_post_hook1)

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
out = paddle.mean(x)

return out


class TestNestLayerHook(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([4, 10])
self.path = "./net_hook"

def train_net(self, to_static=False):
paddle.seed(2022)
net = SimpleNet()
if to_static:
net = paddle.jit.to_static(net)
out = net(self.x)

if to_static:
paddle.jit.save(net, self.path)

return out.numpy()[0]

def load_train(self):
net = paddle.jit.load(self.path)
out = net(self.x)
return out.numpy()[0]

def test_hook(self):
dy_out = self.train_net(to_static=False)
st_out = self.train_net(to_static=True)
load_out = self.load_train()
self.assertTrue(
np.allclose(st_out, dy_out),
msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out))
self.assertTrue(
np.allclose(st_out, load_out),
msg='load_out is {}\nstatic_res is {}'.format(load_out, st_out))


if __name__ == "__main__":
unittest.main()