From 290bf411734a43dcb92ae1078d98ba9bffaefa5d Mon Sep 17 00:00:00 2001
From: NeroLoh <745827440@qq.com>
Date: Tue, 2 Jan 2024 20:28:15 +0800
Subject: [PATCH] [xpu]add sine_pos fuse pass and sine_pos xpu kernel (#60025)

---
 cmake/external/xpu.cmake                      |   2 +-
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../framework/ir/xpu/sine_pos_fuse_pass.cc    | 286 ++++++++++++++++++
 .../inference/api/paddle_pass_builder.cc      |   1 +
 paddle/phi/api/yaml/fused_ops.yaml            |   9 +
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   7 +-
 paddle/phi/infermeta/fusion.cc                |  31 ++
 paddle/phi/infermeta/fusion.h                 |   4 +
 .../kernels/fusion/xpu/sine_pos_xpu_kernel.cc |  55 ++++
 .../kernels/legacy/xpu/reduce_max_kernel.cc   |   8 +-
 paddle/phi/kernels/xpu/activation_kernel.cc   |  13 +-
 paddle/phi/kernels/xpu/reduce_max_kernel.cc   |   9 +-
 test/ir/inference/test_xpu_sine_pos_pass.py   | 132 ++++++++
 13 files changed, 548 insertions(+), 10 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc
 create mode 100644 paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc
 create mode 100644 test/ir/inference/test_xpu_sine_pos_pass.py

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index c0aea59730832..2b5c94872a36c 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -26,7 +26,7 @@ set(XPU_XBLAS_LIB_NAME "libxpu_blas.so")
 set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
 
 if(NOT DEFINED XPU_BASE_DATE)
-  set(XPU_BASE_DATE "20231203")
+  set(XPU_BASE_DATE "20231218")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
   set(XPU_XHPC_BASE_DATE "20231229")
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 3c7560b69e332..35f5ba1522368 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -318,6 +318,7 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
diff --git a/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc
new file mode 100644
index 0000000000000..6c398b775abf5
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/sine_pos_fuse_pass.cc
@@ -0,0 +1,286 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/quantize_helper.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+/*
+fuse block in vision model to sine_pos_xpu op
+------------------------------------------------------
+sub block:
+          x         y
+           \       /
+            \     /
+             \   /
+              mul
+             /   \
+            /     \
+           /       \
+       slice       slice
+         |           |
+         |           |
+        sin         cos
+           \       /
+            \     /
+             \   /
+             stack
+               |
+               |
+            flatten
+               |
+              out
+------------------------------------------------------
+After the pass is applied:
+          x         y
+           \       /
+            \     /
+             \   /
+        sine_pos_xpu
+               |
+               |
+              out
+*/
+
+struct SinePosXPUPattern : public PatternBase {
+  SinePosXPUPattern(PDPattern* pattern, const std::string& name_scope);
+  // declare operator node's name
+  PATTERN_DECL_NODE(ew_mul);
+  PATTERN_DECL_NODE(slice1);
+  PATTERN_DECL_NODE(slice2);
+  PATTERN_DECL_NODE(sin);
+  PATTERN_DECL_NODE(cos);
+  PATTERN_DECL_NODE(stack);
+  PATTERN_DECL_NODE(flatten);
+  // declare variable node's name
+  PATTERN_DECL_NODE(x);
+  PATTERN_DECL_NODE(y);
+  PATTERN_DECL_NODE(ew_mul_out);
+  PATTERN_DECL_NODE(slice1_out);
+  PATTERN_DECL_NODE(slice2_out);
+  PATTERN_DECL_NODE(sin_out);
+  PATTERN_DECL_NODE(cos_out);
+  PATTERN_DECL_NODE(stack_out);
+  PATTERN_DECL_NODE(flatten_out);
+};
+
+SinePosXPUPattern::SinePosXPUPattern(PDPattern* pattern,
+                                     const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto x = pattern->NewNode(x_repr())
+               ->assert_is_op_input("elementwise_mul", "X")
+               ->assert_more([&](Node* node) {
+                 auto x_shape = node->Var()->GetShape();
+                 size_t x_rank = x_shape.size();
+                 return x_rank == 3 && x_shape.back() == 1;
+               });
+  auto y = pattern->NewNode(y_repr())
+               ->assert_is_op_input("elementwise_mul", "Y")
+               ->assert_more([&](Node* node) {
+                 auto x_shape = node->Var()->GetShape();
+                 size_t x_rank = x_shape.size();
+                 return x_rank == 1 && x_shape[0] % 2 == 0;
+               });
+  auto* ew_mul = pattern->NewNode(ew_mul_repr())
+                     ->assert_is_op("elementwise_mul")
+                     ->assert_more([&](Node* node) {
+                       auto* op_desc = node->Op();
+                       return op_desc->GetAttrIfExists<int>("axis") == -1;
+                     });
+  auto* ew_mul_out = pattern->NewNode(ew_mul_out_repr())
+                         ->assert_is_op_output("elementwise_mul", "Out")
+                         ->assert_is_op_input("strided_slice", "Input");
+  ew_mul->LinksFrom({x, y}).LinksTo({ew_mul_out});
+  auto* slice1 =
+      pattern->NewNode(slice1_repr())
+          ->assert_is_op("strided_slice")
+          ->assert_more([&](Node* node) {
+            auto* op_desc = node->Op();
+            return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
+                       std::vector<int>{2} &&
+                   op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
+                       std::vector<int>{0} &&
+                   op_desc->GetAttrIfExists<std::vector<int>>("strides") ==
+                       std::vector<int>{2};
+          });
+  auto* slice1_out = pattern->NewNode(slice1_out_repr())
+                         ->assert_is_op_output("strided_slice", "Out")
+                         ->assert_is_op_input("sin", "X");
+  slice1->LinksFrom({ew_mul_out}).LinksTo({slice1_out});
+  auto* sin = pattern->NewNode(sin_repr())->assert_is_op("sin");
+  auto* sin_out = pattern->NewNode(sin_out_repr())
+                      ->assert_is_op_output("sin", "Out")
+                      ->assert_is_op_nth_input("stack", "X", 0);
+  sin->LinksFrom({slice1_out}).LinksTo({sin_out});
+  auto* slice2 =
+      pattern->NewNode(slice2_repr())
+          ->assert_is_op("strided_slice")
+          ->assert_more([&](Node* node) {
+            auto* op_desc = node->Op();
+            return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
+                       std::vector<int>{2} &&
+                   op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
+                       std::vector<int>{1} &&
+                   op_desc->GetAttrIfExists<std::vector<int>>("strides") ==
+                       std::vector<int>{2};
+          });
+  auto* slice2_out = pattern->NewNode(slice2_out_repr())
+                         ->assert_is_op_output("strided_slice", "Out")
+                         ->assert_is_op_input("cos", "X");
+  slice2->LinksFrom({ew_mul_out}).LinksTo({slice2_out});
+  auto* cos = pattern->NewNode(cos_repr())->assert_is_op("cos");
+  auto* cos_out = pattern->NewNode(cos_out_repr())
+                      ->assert_is_op_output("cos", "Out")
+                      ->assert_is_op_nth_input("stack", "X", 1);
+  cos->LinksFrom({slice2_out}).LinksTo({cos_out});
+  auto* stack = pattern->NewNode(stack_repr())
+                    ->assert_is_op("stack")
+                    ->assert_more([&](Node* node) {
+                      auto* op_desc = node->Op();
+                      return op_desc->GetAttrIfExists<int>("axis") == 3;
+                    });
+  auto* stack_out = pattern->NewNode(stack_out_repr())
+                        ->assert_is_op_output("stack", "Y")
+                        ->assert_is_op_input("flatten_contiguous_range", "X");
+  stack->LinksFrom({sin_out, cos_out}).LinksTo({stack_out});
+
+  auto* flatten =
+      pattern->NewNode(flatten_repr())
+          ->assert_is_op("flatten_contiguous_range")
+          ->assert_more([&](Node* node) {
+            auto* op_desc = node->Op();
+            return op_desc->GetAttrIfExists<int>("start_axis") == 2 &&
+                   op_desc->GetAttrIfExists<int>("stop_axis") == 3;
+          });
+  auto* flatten_out =
+      pattern->NewNode(flatten_out_repr())
+          ->assert_is_op_output("flatten_contiguous_range", "Out")
+          ->AsOutput();
+  flatten->LinksFrom({stack_out}).LinksTo({flatten_out});
+}
+
+}  // namespace patterns
+
+class SinePosFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  const std::string name_scope_{"sine_pos_fuse_pass"};
+};
+
+void SinePosFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  GraphPatternDetector gpd;
+  patterns::SinePosXPUPattern pattern(gpd.mutable_pattern(), name_scope_);
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle SinePosFusePass fuse";
+    // declare operator node's name
+    GET_IR_NODE(ew_mul);
+    GET_IR_NODE(slice1);
+    GET_IR_NODE(slice2);
+    GET_IR_NODE(sin);
+    GET_IR_NODE(cos);
+    GET_IR_NODE(stack);
+    GET_IR_NODE(flatten);
+    // declare variable node's name
+    GET_IR_NODE(x);
+    GET_IR_NODE(y);
+    GET_IR_NODE(ew_mul_out);
+    GET_IR_NODE(slice1_out);
+    GET_IR_NODE(slice2_out);
+    GET_IR_NODE(sin_out);
+    GET_IR_NODE(cos_out);
+    GET_IR_NODE(stack_out);
+    GET_IR_NODE(flatten_out);
+    auto* block = flatten->Op()->Block();
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+    // Generate sine_pos_xpu fused op
+    framework::OpDesc fused_op_desc(block);
+    fused_op_desc.SetType("sine_pos_xpu");
+    // set inputs for fused op
+    fused_op_desc.SetInput("x", {x->Name()});
+    fused_op_desc.SetInput("y", {y->Name()});
+
+    fused_op_desc.SetOutput("out", {flatten_out->Name()});
+    // relink fused op
+    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+    IR_NODE_LINK_TO(x, fused_op);
+    IR_NODE_LINK_TO(y, fused_op);
+    IR_NODE_LINK_TO(fused_op, flatten_out);
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes = {ew_mul,
+                                                    ew_mul_out,
+                                                    slice1,
+                                                    slice1_out,
+                                                    slice2,
+                                                    slice2_out,
+                                                    sin,
+                                                    sin_out,
+                                                    cos,
+                                                    cos_out,
+                                                    stack,
+                                                    stack_out,
+                                                    flatten};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(sine_pos_fuse_pass, paddle::framework::ir::SinePosFusePass);
+
+REGISTER_PASS_CAPABILITY(sine_pos_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "sine_pos_xpu", 0));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 726e833fd515a..0a0e6b591ef89 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -577,6 +577,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "yolo_box_xpu_fuse_pass",
       "fast_where_xpu_fuse_pass",
       "elementwise_mul_add_fuse_pass",
+      "sine_pos_fuse_pass",
       // "auto_mixed_precision_pass",
       "cast_mixed_precision_op_fuse_pass",
       "xpu_quantize_op_pass",
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index a31dee6a4c27d..f1d253945139e 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -431,6 +431,15 @@
     func : self_dp_attention
     data_type : x
 
+- op : sine_pos_xpu
+  args : (Tensor x, Tensor y)
+  output : Tensor(out)
+  infer_meta :
+    func : SinePosXPUInferMeta
+  kernel :
+    func : sine_pos_xpu
+    data_type : x
+
 - op : skip_layernorm
   args : (Tensor x, Tensor y, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis)
   output : Tensor(out)
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 31d16aaf5c0a3..1d388b2a47d5a 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -680,7 +680,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"pool3d",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
-    {"pow", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"pow", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})},
     {"prior_box", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -707,7 +707,8 @@ XPUOpMap& get_kl2_ops() {
     {"reduce_max",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::INT32,
-                   phi::DataType::INT64})},
+                   phi::DataType::INT64,
+                   phi::DataType::FLOAT16})},
     {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"reduce_mean",
      XPUKernelSet({phi::DataType::FLOAT32,
@@ -1171,6 +1172,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
                    phi::DataType::INT32})},
+    {"sine_pos_xpu",
+     XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
   };
 
   return s_xpu2_kernels;
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index f38ffe0f1fc9d..41329efaa86d5 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -3687,4 +3687,35 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q,
   qkv_max->set_dtype(out_dtype);
   qkv_max->set_layout(q.layout());
 }
+void SinePosXPUInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         MetaTensor* out) {
+  auto x_dims = x.dims();
+  auto x_dims_size = x_dims.size();
+  PADDLE_ENFORCE_EQ(
+      x_dims_size,
+      3,
+      phi::errors::InvalidArgument(
+          "x_dims_size should be 3, but received x_dims_size is %d",
+          x_dims_size));
+  PADDLE_ENFORCE_EQ(x_dims[x_dims_size - 1],
+                    1,
+                    phi::errors::InvalidArgument(
+                        "x last dim size should be 1, but received is %d",
+                        x_dims[x_dims_size - 1]));
+  auto y_dims = y.dims();
+  auto y_dims_size = y_dims.size();
+  PADDLE_ENFORCE_EQ(
+      y_dims_size,
+      1,
+      phi::errors::InvalidArgument(
+          "y_dims_size should be 1, but received y_dims_size is %d",
+          y_dims_size));
+
+  phi::DDim out_dim = phi::make_ddim({x_dims[0], x_dims[1], y_dims[0]});
+
+  out->set_dims(out_dim);
+  out->set_dtype(x.dtype());
+}
+
 }  // namespace phi
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index ade4e38d457a6..e294e67aa1c95 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -834,4 +834,8 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q,
                               DataType out_dtype,
                               MetaTensor* qkv,
                               MetaTensor* qkv_max);
+void SinePosXPUInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         MetaTensor* out);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc
new file mode 100644
index 0000000000000..0936f7be2f0ab
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/sine_pos_xpu_kernel.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void SinePosXPUKernel(const Context& ctx,
+                      const DenseTensor& x,
+                      const DenseTensor& y,
+                      DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
+  auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
+  // fix precision of fp16 model
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  std::vector<int64_t> x_shape = phi::vectorize(x.dims());
+  std::vector<int64_t> y_shape = phi::vectorize(y.dims());
+  // sine_pos_fusion only supports fp32 and fp16 precision
+  int r = xpu::sine_pos_fusion<XPUType>(
+      /* baidu::xpu::api::Context* ctx */ ctx.x_context(),
+      /* const T* x */ x_data,
+      /* const T* y */ y_data,
+      /* T* out */ out_data,
+      /* int64_t batch */ x_shape[0],
+      /* int64_t n */ x_shape[1],
+      /* int64_t dim */ y_shape[0]);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sine_pos_xpu");
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sine_pos_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::SinePosXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc
index cb9ff8f6bdb80..9e21dfd6ba30e 100644
--- a/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc
@@ -49,4 +49,10 @@ void MaxRawKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(max_raw, XPU, ALL_LAYOUT, phi::MaxRawKernel, float, int) {}
+PD_REGISTER_KERNEL(max_raw,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MaxRawKernel,
+                   float,
+                   int,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
index 449be30474193..4f82566ca45f1 100644
--- a/paddle/phi/kernels/xpu/activation_kernel.cc
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -195,15 +195,16 @@ void PowKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const Scalar& factor,
                DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(out);
-  float pow_factor = factor.to<float>();
-  const T* x_data = x.data<T>();
-  T* y_data = out->data<T>();
+  T pow_factor = factor.to<T>();
+  const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
 
   auto xpu_context = dev_ctx.x_context();
   // allocate temp memory for factor on xpu
   xpu::ctx_guard RAII_GUARD(xpu_context);
-  T* factor_data = RAII_GUARD.alloc_l3_or_gm<T>(1);
+  XPUType* factor_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(1);
   PADDLE_ENFORCE_NOT_NULL(
       factor_data, errors::External("XPU alloc_l3_or_gm returns nullptr"));
   memory_utils::Copy(dev_ctx.GetPlace(),
@@ -653,6 +654,9 @@ PD_REGISTER_KERNEL(cos,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
 
+PD_REGISTER_KERNEL(
+    pow, XPU, ALL_LAYOUT, phi::PowKernel, float, phi::dtype::float16) {}
+
 #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
   PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
 
@@ -660,7 +664,6 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)  // no grad
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
diff --git a/paddle/phi/kernels/xpu/reduce_max_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
index 8842f86b0c9fb..72ce736ddcad2 100644
--- a/paddle/phi/kernels/xpu/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_max_kernel.cc
@@ -57,4 +57,11 @@ void MaxKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(max, XPU, ALL_LAYOUT, phi::MaxKernel, float, int, int64_t) {}
+PD_REGISTER_KERNEL(max,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MaxKernel,
+                   float,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/test/ir/inference/test_xpu_sine_pos_pass.py b/test/ir/inference/test_xpu_sine_pos_pass.py
new file mode 100644
index 0000000000000..8d8abbfdfb184
--- /dev/null
+++ b/test/ir/inference/test_xpu_sine_pos_pass.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestSinePosXPUFusePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ["sine_pos_xpu"], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=10), min_size=3, max_size=3
+            )
+        )
+        x_shape[1] = draw(st.integers(min_value=100, max_value=512))
+        x_shape[2] = draw(st.integers(min_value=1, max_value=1))
+        y_shape = draw(
+            st.lists(
+                st.integers(min_value=128, max_value=128),
+                min_size=1,
+                max_size=1,
+            )
+        )
+
+        def generate_data(shape):
+            return np.random.random(shape).astype(np.float32)
+
+        # Here we will compose a program
+        # There is still some risk that the program is invalid or causes a bug while running
+        # Use function `is_program_valid` to filter the invalid programs before running
+        # Use function `add_skip_pass_case` to ignore the programs even if they cause a bug while running
+        mul_op = OpConfig(
+            "elementwise_mul",
+            inputs={"X": ["x"], "Y": ["y"]},
+            outputs={"Out": ["mul_out"]},
+            axis=-1,
+        )
+        slice1_op = OpConfig(
+            "strided_slice",
+            inputs={"Input": ["mul_out"]},
+            outputs={"Out": ["slice1_out"]},
+            axes=[2],
+            starts=[0],
+            strides=[2],
+            ends=[128],
+            infer_flags=[1],
+        )
+        sin_op = OpConfig(
+            "sin",
+            inputs={"X": ["slice1_out"]},
+            outputs={"Out": ["sin_out"]},
+        )
+        slice2_op = OpConfig(
+            "strided_slice",
+            inputs={"Input": ["mul_out"]},
+            outputs={"Out": ["slice2_out"]},
+            axes=[2],
+            starts=[1],
+            strides=[2],
+            ends=[128],
+            infer_flags=[1],
+        )
+        cos_op = OpConfig(
+            "cos",
+            inputs={"X": ["slice2_out"]},
+            outputs={"Out": ["cos_out"]},
+        )
+        stack_op = OpConfig(
+            "stack",
+            inputs={"X": ["sin_out", "cos_out"]},
+            outputs={"Y": ["stack_out"]},
+            axis=3,
+        )
+        flatten_op = OpConfig(
+            "flatten_contiguous_range",
+            inputs={"X": ["stack_out"]},
+            outputs={"Out": ["flatten_out"]},
+            start_axis=2,
+            stop_axis=3,
+        )
+
+        ops = [
+            mul_op,
+            slice1_op,
+            slice2_op,
+            sin_op,
+            cos_op,
+            stack_op,
+            flatten_op,
+        ]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            inputs={
+                "x": TensorConfig(data_gen=partial(generate_data, x_shape)),
+                "y": TensorConfig(data_gen=partial(generate_data, y_shape)),
+            },
+            weights={},
+            outputs=ops[-1].outputs["Out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["sine_pos_fuse_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
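
For reference, the subgraph matched by SinePosXPUPattern (and replaced by the sine_pos_xpu op) computes an interleaved sin/cos positional embedding. The NumPy sketch below is illustrative only and not part of the patch; it assumes x of shape [batch, n, 1] and y of shape [dim] with an even dim, as the pattern asserts.

import numpy as np

def sine_pos_reference(x, y):
    # elementwise_mul with broadcasting: [batch, n, 1] * [dim] -> [batch, n, dim]
    mul = x * y
    # strided_slice(starts=0, strides=2) + sin; strided_slice(starts=1, strides=2) + cos
    sin_part = np.sin(mul[:, :, 0::2])
    cos_part = np.cos(mul[:, :, 1::2])
    # stack on axis 3 -> [batch, n, dim/2, 2], then flatten axes 2..3 -> [batch, n, dim]
    stacked = np.stack([sin_part, cos_part], axis=3)
    return stacked.reshape(x.shape[0], x.shape[1], -1)

out = sine_pos_reference(
    np.random.rand(1, 100, 1).astype("float32"),
    np.random.rand(128).astype("float32"),
)
print(out.shape)  # (1, 100, 128), matching SinePosXPUInferMeta's {x_dims[0], x_dims[1], y_dims[0]}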