[Prim][PIR]Fix addn prim (#57675)
* support prim backward

* fix jit support pir prim

* move new_ir to pir

* add test case

* fix prim pir addn
cyber-pioneer authored Sep 25, 2023
1 parent 1fcefb1 commit 6a25e9d
Showing 16 changed files with 191 additions and 26 deletions.
@@ -27,7 +27,9 @@ for (auto arg: stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
{% if 'composite' in api and api.name in vjp_comp_white_list %}
if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled()) {
std::string op_name = "{{api.name}}";
auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(op_name);
if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) {
{% filter indent(2, True) %}{{body_prim(api)}}{% endfilter %}
} else {
{% filter indent(2, True) %}{{body_unprim(api)}}{% endfilter %}
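The guard added above makes the generated VJP fall back to the non-composite gradient whenever the op is in the backward skip set. A minimal sketch of how that set is populated from Python, using only the calls exercised by the new test further down; the exact link between `_set_prim_backward_blacklist` and `CheckSkipCompOps` is an assumption drawn from this diff, not a documented contract:

```python
# Sketch: driving the generated CheckSkipCompOps guard from Python.
# Assumption (from this diff): ops registered via _set_prim_backward_blacklist
# are the ones CheckSkipCompOps reports as "skip", so their composite VJP is
# bypassed even though prim backward is enabled globally.
from paddle.base import core

core._set_prim_all_enabled(True)                            # enable prim fwd + bwd
core._set_prim_backward_blacklist("tanh_grad", "exp_grad")  # keep these grads whole
```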
33 changes: 30 additions & 3 deletions python/paddle/decomposition/decomp.py
@@ -38,6 +38,10 @@ def _prepare_python_api_arguments(op):
op (Operator): The target operator.
"""
op_inputs = [x.source() for x in op.operands()]
# The inputs of the PIR builtin.combine op are returned as a list of tensors.
if op.name() in ["builtin.combine"]:
return (op_inputs,)

op_attrs_dict = op.attrs()
op_attrs_name = op.get_attr_names()
op_attrs = [op_attrs_dict[x] for x in op_attrs_name]
@@ -198,15 +202,28 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):

if isinstance(block, Block):
ops_list = block.ops
for op in ops_list:
temp_op = None
temp_inputs = None
for idx, op in enumerate(ops_list):
op_name = op.name()
decom_rule = register.get_decomp_rule(op_name)
lower = decom_rule and op_filter(op)

if op.name() == "builtin.combine":
temp_op = op
temp_inputs = _prepare_python_api_arguments(op)

if lower:
core.prim_config["composite_ops_record"].add(op_name)
input_args = _prepare_python_api_arguments(op)
pir.set_insertion_point(op)
if (
temp_op is not None
and ops_list[idx - 1].name() == "builtin.combine"
):
input_args = temp_inputs
pir.set_insertion_point(temp_op)
else:
input_args = _prepare_python_api_arguments(op)
pir.set_insertion_point(op)
orig_outs = op.results()
new_outs = _build_tensor_tuple(decom_rule(*input_args))

@@ -217,6 +234,16 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):

op.replace_all_uses_with(new_outs)
block.remove_op(op)

if temp_op is not None:
remove_op = True
for item in temp_op.results():
if item.has_one_use():
remove_op = False
break
if remove_op:
block.remove_op(temp_op)
temp_op = None
return

elif isinstance(block, typing.Sequence):
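The special-casing of `builtin.combine` above exists because, in PIR, an op such as `add_n` consumes one packed operand built by a preceding `builtin.combine`; the decomposer therefore reads the original tensor list off that combine op, inserts the replacement ops before it, and removes it once nothing else uses its result. A small eager-mode illustration of the equivalence the pass relies on; the IR shown in the comments is an assumption inferred from this diff, not an actual Paddle IR dump:

```python
# Hedged sketch: paddle.add_n([a, b, c]) is assumed to lower in PIR to
#   %packed = builtin.combine(a, b, c)
#   %out    = pd_op.add_n(%packed)
# so the decomposition pass hoists its inputs from the combine op.
import paddle

a = paddle.ones([2, 3])
b = paddle.full([2, 3], 2.0)
c = paddle.full([2, 3], 3.0)

out = paddle.add_n([a, b, c])
ref = a + b + c  # what the registered add_n decomposition rule produces
assert bool(paddle.allclose(out, ref))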
8 changes: 8 additions & 0 deletions python/paddle/decomposition/rules.py
@@ -143,3 +143,11 @@ def layernorm_composite(x, scale, bias, epsilon, begin_norm_axis):
if is_amp:
out = cast(out, dtype)
return out, mean_, variance


@register_decomp('pd_op.add_n')
def sum_composite(x):
ans = x[0]
for xi in x[1:]:
ans = xi + ans
return ans
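
The rule above lowers `pd_op.add_n` into a chain of elementwise adds. A standalone mirror of the same logic on plain Python numbers (no Paddle required), just to show the folding order:

```python
# Plain-Python mirror of the add_n decomposition rule above.
def sum_composite(xs):
    ans = xs[0]
    for xi in xs[1:]:
        ans = xi + ans  # each step is one elementwise add in the real rule
    return ans

print(sum_composite([1.0, 2.0, 3.5]))  # 6.5
```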
8 changes: 1 addition & 7 deletions python/paddle/jit/dy2static/newir_partial_program.py
@@ -958,14 +958,8 @@ def create_out(var_id):
else:
tensor_type = paddle.dtype(8) # SELECT ROW TENSOR

# TODO(xiongkun): more elegent way to do it.

ir_dtype_2_tensor_dtype = {
10: paddle.dtype(5),
}

out = core.eager.Tensor(
ir_dtype_2_tensor_dtype[int(var.dtype)],
framework.paddle_type_to_proto_type[var.dtype],
var.shape,
"",
tensor_type,
26 changes: 15 additions & 11 deletions python/paddle/jit/dy2static/program_translator.py
@@ -1495,26 +1495,30 @@ def before_append_backward(self, forward_program, src_vars):
dst_vars = decomposition.decompose(
forward_program, src_vars, blacklist=self.custom_vjps
)
return forward_program, dst_vars
return forward_program, dst_vars
return forward_program, src_vars

def after_append_backward(self, whole_program, src_vars, forward_end_idx):
with backend_guard(self.backend):
backward_length = (
len(whole_program.global_block().ops) - forward_end_idx
)
if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0:
# only process backward part of block
dst_vars = decomposition.decompose(whole_program, src_vars)
new_start_index = (
len(whole_program.global_block().ops) - backward_length
)
return whole_program, new_start_index, dst_vars
backward_length = (
len(whole_program.global_block().ops) - forward_end_idx
)
dst_vars = decomposition.decompose(
whole_program, src_vars, whitelist=self.custom_vjps
)
new_start_index = (
len(whole_program.global_block().ops) - backward_length
)
return whole_program, new_start_index, dst_vars
return whole_program, forward_end_idx, src_vars

def after_infer(self, infer_program, src_vars):
with backend_guard(self.backend):
if core._is_fwd_prim_enabled():
dst_vars = decomposition.decompose(infer_program, src_vars)
return infer_program, dst_vars
return infer_program, dst_vars
return infer_program, src_vars


class ProgramCache:
3 changes: 2 additions & 1 deletion test/legacy_test/test_sum_op.py
@@ -62,6 +62,7 @@ def test_check_output(self):
check_prim=True,
check_cinn=True,
check_new_ir=True,
check_prim_pir=True,
)

def test_check_grad(self):
@@ -70,8 +71,8 @@ def test_check_grad(self):
'Out',
check_prim=True,
check_cinn=True,
check_prim_pir=True,
check_new_ir=True,
check_prim_pir=True,
)


2 changes: 1 addition & 1 deletion test/prim/CMakeLists.txt
@@ -12,4 +12,4 @@ add_subdirectory(prim)
add_subdirectory(model)
add_subdirectory(composite_ops)
add_subdirectory(process)
add_subdirectory(new_ir_prim)
add_subdirectory(pir_prim)
@@ -1,5 +1,6 @@
set(TEST_PRIM_PURE_NEW_IR_CASES test_prim_program test_prim_simpnet
test_prim_custom_vjp test_prim_jit)
set(TEST_PRIM_PURE_NEW_IR_CASES
test_prim_program test_prim_simpnet test_prim_custom_vjp test_prim_jit
test_pir_prim_flags)

foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
File renamed without changes.
File renamed without changes.
128 changes: 128 additions & 0 deletions test/prim/pir_prim/test_pir_prim_flags.py
@@ -0,0 +1,128 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.base import core
from paddle.decomposition import decompose


class TestPrimBlacklistFlags(unittest.TestCase):
def not_in_blacklist(self):
inputs = np.random.random([2, 3, 4]).astype("float32")
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
y = F.gelu(x)

fwd_ops = [op.name() for op in main_program.global_block().ops]
# Ensure that gelu is in the original block
self.assertTrue('pd_op.gelu' in fwd_ops)

[y] = decompose(main_program, [y])

fwd_ops_new = [op.name() for op in main_program.global_block().ops]
# Ensure that gelu has been decomposed into smaller ops
self.assertTrue('pd_op.gelu' not in fwd_ops_new)

exe = paddle.static.Executor()
exe.run(startup_program)
_ = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
paddle.disable_static()
core._set_prim_forward_enabled(False)

def in_blacklist(self):
inputs = np.random.random([2, 3, 4]).astype("float32")
paddle.enable_static()
core._set_prim_forward_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
'x', shape=inputs.shape, dtype=str(inputs.dtype)
)
y = F.gelu(x)

fwd_ops = [op.name() for op in main_program.global_block().ops]
# Ensure that gelu is in the original block
self.assertTrue('pd_op.gelu' in fwd_ops)

_ = decompose(main_program, [y])

fwd_ops_new = [op.name() for op in main_program.global_block().ops]
# Ensure that gelu is not decomposed, since it is in the blacklist
self.assertTrue('pd_op.gelu' in fwd_ops_new)

exe = paddle.static.Executor()
exe.run(startup_program)
_ = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
paddle.disable_static()
core._set_prim_forward_enabled(False)

def test_prim_forward_blacklist(self):
self.not_in_blacklist()
core._set_prim_forward_blacklist("pd_op.gelu")
self.in_blacklist()


class PrimeNet(paddle.nn.Layer):
def __init__(self):
super().__init__()

def forward(self, x):
x1 = paddle.tanh(x)
x2 = paddle.exp(x)
x3 = x1 + x2
res = paddle.nn.functional.gelu(x3)
return res


class TestPrimBackwardBlacklistFlags(unittest.TestCase):
def train(self):
x = paddle.randn([2, 4])
x.stop_gradient = False
net = PrimeNet()
net = paddle.jit.to_static(net)
out = net(x)
loss = paddle.mean(out)
loss.backward()
self.check_prim(net)

def check_prim(self, net):
block = net.forward.program_cache.last()[-1][
-1
].train_program.global_block()
ops = [op.name() for op in block.ops]
self.assertTrue('pd_op.tanh_grad' in ops)
self.assertTrue('pd_op.exp_grad' in ops)
self.assertTrue('pd_op.gelu_grad' not in ops)

def test_prim_backward_blacklist(self):
core._set_prim_all_enabled(True)
core._set_prim_backward_blacklist("tanh_grad", "exp_grad")
self.train()
core._set_prim_all_enabled(False)


if __name__ == '__main__':
unittest.main()
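
TestPrimBlacklistFlags above relies on the forward decomposition blacklist, complementing the backward skip set sketched earlier. A condensed usage sketch of that flag alone, using only calls that appear in the test:

```python
from paddle.base import core

# With pd_op.gelu on the forward blacklist, decompose() leaves gelu intact
# and still decomposes every other op that has a registered rule.
core._set_prim_forward_enabled(True)
core._set_prim_forward_blacklist("pd_op.gelu")
```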
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
