add fused_elemwise_add_activation
yangguohao committed Oct 30, 2023
1 parent 6a25e9d commit 4be6756
Showing 13 changed files with 306 additions and 9 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -96,6 +96,7 @@
'c_allgather',
'seed',
"fused_attention",
'fused_elemwise_add_activation',
]


10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -185,3 +185,13 @@
infer_meta:
func: FusedAttentionInferMeta
optional: cache_kv, ln_scale, ln_bias, qkv_bias, src_mask, out_linear_bias, ln_scale_2, ln_bias_2

- op: fused_elemwise_add_activation
args: (Tensor x, Tensor y, str[] functor_list, int axis=-1, bool save_intermediate_out=false)
output: Tensor(out), Tensor(intermediate_out)
kernel:
func: fused_elemwise_add_activation
infer_meta:
func : FusedElemwiseAddActivationInferMeta
backward: fused_elemwise_add_activation_grad
intermediate: intermediate_out
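
Taken together with the ops_api_gen.py entry above, this YAML declaration is what the PIR op generators consume to emit the op definition and its Python binding. A minimal invocation sketch follows; it assumes the generated binding is reachable as paddle._C_ops.fused_elemwise_add_activation in PIR mode and follows the (x, y, functor_list, axis, save_intermediate_out) argument order declared here. The module path and return convention are assumptions, not part of this commit.

import paddle

# Hypothetical sketch: the binding path and whether intermediate_out is
# returned depend on how the generator handles the "intermediate:" tag.
x = paddle.rand([4, 8])
y = paddle.rand([4, 8])
# functor_list pairs the binary functor with the unary activation,
# e.g. elementwise_add followed by relu.
out = paddle._C_ops.fused_elemwise_add_activation(
    x, y, ["elementwise_add", "relu"], -1, False
)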
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -7,3 +7,13 @@
kernel:
func: set_value_grad
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]

- backward_op: fused_elemwise_add_activation_grad
forward: fused_elemwise_add_activation(Tensor x, Tensor y, str[] functor_list, int axis=-1, bool save_intermediate_out=false) -> Tensor(out), Tensor(intermediate_out)
args: (Tensor x, Tensor y, Tensor out, Tensor intermediate_out, Tensor out_grad, str[] functor_list, int axis=-1, bool save_intermediate_out=false)
output: Tensor(x_grad), Tensor(y_grad)
infer_meta:
func: FusedElemwiseAddActivationGradInferMeta
kernel:
func: fused_elemwise_add_activation_grad
optional : x, intermediate_out
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -35,7 +35,9 @@ const std::unordered_set<std::string> LegacyOpList = {
"pd_op.c_allreduce_max_",
"pd_op.c_allgather",
"pd_op.seed",
"pd_op.share_data"};
"pd_op.share_data",
"pd_op.fused_elemwise_add_activation",
"pd_op.fused_elemwise_add_activation_grad"};

enum class AttrType {
UNDEFINED = 0,
19 changes: 16 additions & 3 deletions paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
@@ -621,7 +621,9 @@ void BuildRuntimeContext(
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
pir::Value ptr = op->operand_source(index);

if (ptr == nullptr) {
continue;
}
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;

@@ -695,7 +697,9 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
pir::Value ptr = op->operand_source(index);

if (ptr == nullptr) {
continue;
}
auto in_var_name = name_map.at(ptr);

auto legacy_attr_name = op_normalizer.GetLegacyArgName(fluid_op_name, name);
@@ -758,6 +762,13 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
attribute.dyn_cast<pir::DoubleAttribute>().data()); // NOLINT
}
attr_map[name] = vec_double;
} else if (array_list[0].isa<pir::StrAttribute>()) {
std::vector<std::string> vec_string;
for (auto attribute : array_list) {
vec_string.push_back(
attribute.dyn_cast<pir::StrAttribute>().AsString()); // NOLINT
}
attr_map[name] = vec_string;
} else {
std::stringstream ss;
val.Print(ss);
@@ -779,7 +790,9 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
for (size_t i = 0; i < output_name_list.size(); ++i) {
auto name = output_name_list[i];
pir::Value ptr = op->result(i);

if (ptr == nullptr) {
continue;
}
auto out_var_name = name_map.at(ptr);

auto type = ptr.type();
19 changes: 19 additions & 0 deletions paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.h
@@ -288,6 +288,25 @@ void BuildPhiContext(
}
}
ctx->EmplaceBackAttr(vec_res);

} else if (attr_type_name == "pir::ArrayAttribute<pir::StrAttribute>") {
auto array_list = attr_map[t].dyn_cast<pir::ArrayAttribute>().AsVector();

std::vector<std::string> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
array_list[0].isa<pir::StrAttribute>(),
true,
phi::errors::PreconditionNotMet(
"Element in array list MUST be pir::StrAttribute "));

for (size_t i = 0; i < array_list.size(); ++i) {
vec_res.push_back(
array_list[i].dyn_cast<pir::StrAttribute>().AsString());
}
}
ctx->EmplaceBackAttr(vec_res);

} else if (attr_type_name == "paddle::dialect::PlaceAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::PlaceAttribute>().data());
2 changes: 1 addition & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -138,7 +138,6 @@ bool NeedFallBackFromGPUDNN2GPU(pir::Operation* op,
return true;
}
}

return false;
}

@@ -1240,6 +1239,7 @@ pir::Operation* BuildPhiKernelOp(

pir::OpInfo legacy_kernel_op_info =
ctx->GetRegisteredOpInfo(paddle::dialect::LegacyKernelOp::name());

pir::Operation* op;
if (dialect::IsLegacyOp(op_item->name())) {
op = pir::Operation::Create(
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3158,6 +3158,18 @@
outputs:
{out: Out}

- op: fused_elemwise_add_activation
backward: fused_elemwise_add_activation_grad
inputs :
{x: X, y: Y}
outputs :
{out : Out, intermediate_out : IntermediateOut}
attrs :
functor_list: functor_list
extra :
attrs : [int axis=-1, bool save_intermediate_out=false]
outputs : [intermediate_out]

- op: get_tensor_from_selected_rows
inputs :
x : X
111 changes: 110 additions & 1 deletion paddle/phi/infermeta/backward.cc
@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unordered_set>

#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"

namespace phi {
@@ -232,6 +233,114 @@ void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
}
}

static bool IsUnaryCompound(const std::vector<std::string>& functor_list) {
PADDLE_ENFORCE_EQ(
functor_list.size(),
2,
phi::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functor_list.size(),
2));
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul",
"elementwise_add_grad",
"elementwise_mul_grad"};
return binary_fun.count(functor_list[1]) != 0;
}
static bool InputXCanBeAbsent(const std::vector<std::string>& functor_list) {
PADDLE_ENFORCE_EQ(
functor_list.size(),
2,
phi::errors::InvalidArgument(
"Invalid functor list size %d, which should be equal to %d.",
functor_list.size(),
2));
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
binary_fun.count(functor_list[1]) != 0;
}

void FusedElemwiseAddActivationGradInferMeta(
const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& out,
const MetaTensor& intermediate_out,
const MetaTensor& out_grad,
const std::vector<std::string>& functor_list,
int axis,
bool save_intermediate_out,
MetaTensor* x_grad,
MetaTensor* y_grad) {
PADDLE_ENFORCE_NOT_NULL(
out_grad,
phi::errors::InvalidArgument("Input(Out@Grad) should not be null."));

if (save_intermediate_out) {
PADDLE_ENFORCE_NOT_NULL(intermediate_out,
phi::errors::InvalidArgument(
"Input(IntermediateOut) should not be null."));
} else {
if (!InputXCanBeAbsent(functor_list)) {
PADDLE_ENFORCE_NOT_NULL(
x, phi::errors::InvalidArgument("Input(X) should not be null."));
}
}

if (x_grad) {
if (x) {
x_grad->set_dims(x.dims());
x_grad->share_lod(x);
} else {
// Currently, "X" may be absent only when the binary functor is
// elementwise_add.
PADDLE_ENFORCE_EQ(
InputXCanBeAbsent(functor_list),
true,
phi::errors::InvalidArgument(
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent."));

// Note: if "X" is absent, the shape of Y must be a continuous
// subsequence of X; otherwise we could not infer the shape of dx.

x_grad->set_dims(out_grad.dims());
x_grad->share_lod(out_grad);
}
}

if (y_grad) {
PADDLE_ENFORCE_NOT_NULL(
y, phi::errors::InvalidArgument("Input(Y) should not be null."));
y_grad->set_dims(y.dims());
y_grad->share_lod(y);
}

// if (intermediate_out_grad) {
// // For Unary(Binary(X, Y)), IntermediateOut should not be empty.
// if (IsUnaryCompound(functor_list)) {
// intermediate_out_grad->set_dims(out_grad.dims());
// intermediate_out_grad->share_lod(out_grad);
// } else {
// intermediate_out_grad->set_dims(y.dims());
// intermediate_out_grad->share_lod(y);
// }
// }
bool elementwise_add_grad_detected = false;
for (const auto& name : functor_list) {
if (name == "elementwise_add_grad") {
elementwise_add_grad_detected = true;
break;
}
}
PADDLE_ENFORCE_EQ(
elementwise_add_grad_detected,
true,
phi::errors::InvalidArgument(
"When FusedElemwiseAddActivationOpGrad is used in a fuse pass, the "
"elementwise_add_grad op must be detected and used. Please check "
"the fuse pass pattern."));
}
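
The x_grad branch above encodes a compact rule; an illustrative re-expression in Python (not Paddle code, just restating InputXCanBeAbsent and the fallback) may help when reading the diff:

# Illustrative re-expression of InputXCanBeAbsent and the x_grad fallback above
# (not Paddle code): X may be omitted only when elementwise_add_grad appears in
# functor_list, and dx then borrows out_grad's dims and LoD.
def input_x_can_be_absent(functor_list):
    assert len(functor_list) == 2, "functor_list must hold exactly two functors"
    return "elementwise_add_grad" in functor_list

print(input_x_can_be_absent(["relu_grad", "elementwise_add_grad"]))  # True
print(input_x_can_be_absent(["relu_grad", "elementwise_mul_grad"]))  # False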

void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
12 changes: 11 additions & 1 deletion paddle/phi/infermeta/backward.h
@@ -194,7 +194,17 @@ void FusedRopeGradInferMeta(const MetaTensor& sin,
MetaTensor* dq,
MetaTensor* dk,
MetaTensor* dv);

void FusedElemwiseAddActivationGradInferMeta(
const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& out,
const MetaTensor& intermediate_out,
const MetaTensor& out_grad,
const std::vector<std::string>& functor_list,
int axis,
bool save_intermediate_out,
MetaTensor* x_grad,
MetaTensor* y_grad);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,