Skip to content

Commit

Permalink
add xpu_support op function
Browse files Browse the repository at this point in the history
*test=kunlun
  • Loading branch information
QingshuChen committed Dec 5, 2022
1 parent 2af8219 commit 1f33737
Show file tree
Hide file tree
Showing 12 changed files with 746 additions and 469 deletions.
45 changes: 26 additions & 19 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1319,9 +1319,10 @@ bool OperatorWithKernel::SupportXPU() const {
op_kernels.end(),
[this](OpKernelMap::const_reference kern_pair) {
return platform::is_xpu_place(kern_pair.first.place_) &&
paddle::platform::is_xpu_support_op(type_,
kern_pair.first) &&
!paddle::platform::is_in_xpu_black_list(type_);
paddle::platform::is_xpu_support_op(
type_,
framework::TransToPhiDataType(
kern_pair.first.data_type_));
});
}
}
Expand Down Expand Up @@ -1409,16 +1410,17 @@ bool OperatorWithKernel::SupportsKernelType(
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
return kernel_iter != kernels.end() &&
paddle::platform::is_xpu_support_op(type_, kernel_type) &&
!paddle::platform::is_in_xpu_black_list(type_);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type.data_type_));
}
#endif

#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, kernel_type);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type.data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
Expand All @@ -1428,8 +1430,8 @@ bool OperatorWithKernel::SupportsKernelType(
return kernels.find(tmp_kernel_type) != kernels.end();
}
return kernel_iter != kernels.end() &&
paddle::platform::is_xpu_support_op(type_, kernel_type) &&
!paddle::platform::is_in_xpu_black_list(type_);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type.data_type_));
}
#endif

Expand Down Expand Up @@ -1591,7 +1593,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (use_xpu_kp_kernel_rt) {
Expand Down Expand Up @@ -1668,7 +1671,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (use_xpu_kp_kernel_rt) {
Expand Down Expand Up @@ -1709,14 +1713,15 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
#if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
!paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
paddle::platform::is_in_xpu_black_list(type_);
!paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
#endif
#ifdef PADDLE_WITH_XPU_KP
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(kernel_type_->place_) &&
paddle::platform::is_in_xpu_kpwhite_list(type_);
Expand Down Expand Up @@ -2051,8 +2056,9 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
if (platform::is_xpu_place(expected_kernel_key.place_) &&
(kernel_iter == kernels.end() ||
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(type_))) {
!paddle::platform::is_xpu_support_op(
type_,
framework::TransToPhiDataType(expected_kernel_key.data_type_)))) {
VLOG(3) << "fluid missing XPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
Expand All @@ -2065,7 +2071,9 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
paddle::platform::is_xpu_support_op(
type_,
framework::TransToPhiDataType(expected_kernel_key.data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(type_);
if (use_xpu_kp_kernel_rt) {
Expand Down Expand Up @@ -2093,9 +2101,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
<< ", using_kernel_key:" << expected_kernel_key;
}
}
bool is_xpu_unsupport =
(!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(type_));
bool is_xpu_unsupport = (!paddle::platform::is_xpu_support_op(
type_, framework::TransToPhiDataType(expected_kernel_key.data_type_)));
if (!is_xpu_kp_support &&
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
VLOG(3) << "fluid missing XPU kernel: " << type_
Expand Down
16 changes: 10 additions & 6 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ PreparedOp PrepareImpl(
#if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type());
!paddle::platform::is_xpu_support_op(
op.Type(),
framework::TransToPhiDataType(expected_kernel_key.data_type_));
#endif

#ifdef PADDLE_WITH_MLU
Expand Down Expand Up @@ -292,8 +292,10 @@ PreparedOp PrepareImpl(
#ifdef PADDLE_WITH_XPU_KP
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(
op.Type(), expected_kernel_key);
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_support_op(
op.Type(),
framework::TransToPhiDataType(expected_kernel_key.data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
if (use_xpu_kp_kernel_rt) {
Expand Down Expand Up @@ -368,7 +370,9 @@ PreparedOp PrepareImpl(
bool use_xpu_kp_kernel_rt =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
paddle::platform::is_xpu_support_op(
op.Type(),
framework::TransToPhiDataType(expected_kernel_key.data_type_));
bool use_xpu_kp_kernel_debug =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
Expand Down
Loading

0 comments on commit 1f33737

Please sign in to comment.