diff --git a/cmake/neuware.cmake b/cmake/neuware.cmake index 811c8d664a097..a371a0032d991 100644 --- a/cmake/neuware.cmake +++ b/cmake/neuware.cmake @@ -17,13 +17,16 @@ INCLUDE_DIRECTORIES(${NEUWARE_INCLUDE_DIR}) set(CNNL_LIB ${NEUWARE_LIB_DIR}/libcnnl.so) set(CNRT_LIB ${NEUWARE_LIB_DIR}/libcnrt.so) set(CNDRV_LIB ${NEUWARE_LIB_DIR}/libcndrv.so) +set(CNPAPI_LIB ${NEUWARE_LIB_DIR}/libcnpapi.so) generate_dummy_static_lib(LIB_NAME "neuware_lib" GENERATOR "neuware.cmake") +set(NEUWARE_LIB_DEPS ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB} ${CNPAPI_LIB}) + if(WITH_CNCL) MESSAGE(STATUS "Compile with CNCL!") ADD_DEFINITIONS(-DPADDLE_WITH_CNCL) set(CNCL_LIB ${NEUWARE_LIB_DIR}/libcncl.so) - TARGET_LINK_LIBRARIES(neuware_lib ${CNCL_LIB} ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB}) -else() - TARGET_LINK_LIBRARIES(neuware_lib ${CNNL_LIB} ${CNRT_LIB} ${CNDRV_LIB}) + list(APPEND NEUWARE_LIB_DEPS ${CNCL_LIB}) endif() + +TARGET_LINK_LIBRARIES(neuware_lib ${NEUWARE_LIB_DEPS}) diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 589d09bf81c1d..1a4f283f511da 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -34,14 +34,6 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place, return; } - // NOTE(hqp): Special case for CPU->MLU, avoid stream sync. - if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) { - paddle::framework::TensorCopy( - in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place), - out); - return; - } - // NOTE(yy): TransDataDevice should wait for computation of input. if (!platform::is_cuda_pinned_place(in.place())) { platform::DeviceContextPool::Instance().Get(in.place())->Wait(); diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc index 43d662830c0c8..f66b75fd1f319 100644 --- a/paddle/fluid/operators/activation_op_mlu.cc +++ b/paddle/fluid/operators/activation_op_mlu.cc @@ -15,12 +15,8 @@ limitations under the Licnse. */ #include #include -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" -#include "paddle/fluid/platform/device/mlu/device_context.h" -#include "paddle/phi/core/ddim.h" namespace paddle { namespace operators { @@ -38,20 +34,39 @@ class ActivationMLUKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); MLUCnnlActivationDesc act_desc(act_mode, alpha); - MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(input->dtype())); - MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(output->dtype())); - - MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), - reinterpret_cast(input->data()), - output_desc.get(), - reinterpret_cast(output->data())); + MLUCnnlTensorDesc input_desc(*input); + MLUCnnlTensorDesc output_desc(*output); + + MLUCnnl::Active(ctx, act_desc.get(), input_desc.get(), GetBasePtr(input), + output_desc.get(), GetBasePtr(output)); } }; +// For gelu, leaky_relu template -class ActivationGradMLUKernel : public framework::OpKernel { +class ActivationGradMLUKernelV1 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); + MLUCnnlActivationDesc act_desc(act_mode, alpha); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, + dout_desc.get(), GetBasePtr(dout), x_desc.get(), + GetBasePtr(x), dx_desc.get(), GetBasePtr(dx)); + } +}; + +// For tanh, sigmoid +template +class ActivationGradMLUKernelV2 : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* out = ctx.Input("Out"); @@ -61,18 +76,35 @@ class ActivationGradMLUKernel : public framework::OpKernel { dx->mutable_data(ctx.GetPlace()); - MLUCnnlTensorDesc dout_desc(*dout, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(dout->dtype())); - MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(out->dtype())); - MLUCnnlTensorDesc dx_desc(*dx, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(dx->dtype())); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); MLUCnnlActivationDesc act_desc(act_mode, alpha); - MLUCnnl::ActiveGrad( - ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, - dout_desc.get(), reinterpret_cast(dout->data()), - out_desc.get(), reinterpret_cast(out->data()), - dx_desc.get(), reinterpret_cast(dx->data())); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, out_desc.get(), + GetBasePtr(out), dout_desc.get(), GetBasePtr(dout), + nullptr, nullptr, dx_desc.get(), GetBasePtr(dx)); + } +}; + +// For relu, relu6 +template +class ActivationGradMLUKernelV3 : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 1.0f; + + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlTensorDesc dout_desc(*dout); + MLUCnnlTensorDesc dx_desc(*dx); + MLUCnnlActivationDesc act_desc(act_mode, alpha); + MLUCnnl::ActiveGrad(ctx, act_desc.get(), nullptr, nullptr, nullptr, nullptr, + dout_desc.get(), GetBasePtr(dout), out_desc.get(), + GetBasePtr(out), dx_desc.get(), GetBasePtr(dx)); } }; @@ -81,10 +113,60 @@ class ActivationGradMLUKernel : public framework::OpKernel { namespace ops = paddle::operators; +// relu REGISTER_OP_MLU_KERNEL( relu, ops::ActivationMLUKernel, ops::ActivationMLUKernel); REGISTER_OP_MLU_KERNEL( - relu_grad, ops::ActivationGradMLUKernel, - ops::ActivationGradMLUKernel); + relu_grad, ops::ActivationGradMLUKernelV3, + ops::ActivationGradMLUKernelV3); + +// relu6 +REGISTER_OP_MLU_KERNEL( + relu6, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + relu6_grad, ops::ActivationGradMLUKernelV3, + ops::ActivationGradMLUKernelV3); + +// sigmoid +REGISTER_OP_MLU_KERNEL(sigmoid, + ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + sigmoid_grad, + ops::ActivationGradMLUKernelV2, + ops::ActivationGradMLUKernelV2); + +// tanh +REGISTER_OP_MLU_KERNEL( + tanh, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + tanh_grad, ops::ActivationGradMLUKernelV2, + ops::ActivationGradMLUKernelV2); + +// gelu +REGISTER_OP_MLU_KERNEL( + gelu, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + gelu_grad, ops::ActivationGradMLUKernelV1, + ops::ActivationGradMLUKernelV1); + +// leaky_relu +REGISTER_OP_MLU_KERNEL( + leaky_relu, ops::ActivationMLUKernel, + ops::ActivationMLUKernel); +REGISTER_OP_MLU_KERNEL( + leaky_relu_grad, + ops::ActivationGradMLUKernelV1, + ops::ActivationGradMLUKernelV1); diff --git a/paddle/fluid/operators/fill_constant_op_mlu.cc b/paddle/fluid/operators/fill_constant_op_mlu.cc index 10e7c72d158e6..f7463c5dd8821 100644 --- a/paddle/fluid/operators/fill_constant_op_mlu.cc +++ b/paddle/fluid/operators/fill_constant_op_mlu.cc @@ -51,6 +51,8 @@ class FillConstantMLUKernel : public framework::OpKernel { } } } + const T *value_data = &value; + cnnlPointerMode_t pointer_mode = CNNL_POINTER_MODE_HOST; if (ctx.HasInput("ValueTensor")) { auto *value_tensor = ctx.Input("ValueTensor"); PADDLE_ENFORCE_EQ( @@ -59,22 +61,18 @@ class FillConstantMLUKernel : public framework::OpKernel { "When use Tensor as value to set Tensor value in fill_cosntant, " "value input(ValueTensor) size must be 1, but get %d", value_tensor->numel())); - const T *tensor_data = value_tensor->data(); - framework::Tensor mlu_tensor; + value_data = value_tensor->data(); auto tmp_place = value_tensor->place(); if (platform::is_mlu_place(tmp_place)) { - framework::TensorCopySync(*value_tensor, platform::CPUPlace(), - &mlu_tensor); - tensor_data = mlu_tensor.data(); + pointer_mode = CNNL_POINTER_MODE_DEVICE; } - value = tensor_data[0]; } auto shape = GetShape(ctx); out_var->mutable_data(shape, ctx.GetPlace()); - MLUCnnlTensorDesc output_desc(*out_var, CNNL_LAYOUT_ARRAY, - ToCnnlDataType(out_var->dtype())); - MLUCnnl::Fill(ctx, value, output_desc.get(), GetBasePtr(out_var)); + MLUCnnlTensorDesc output_desc(*out_var); + MLUCnnl::Fill(ctx, pointer_mode, value_data, output_desc.get(), + GetBasePtr(out_var)); } }; } // namespace operators diff --git a/paddle/fluid/operators/gather_op_mlu.cc b/paddle/fluid/operators/gather_op_mlu.cc new file mode 100644 index 0000000000000..220d045952643 --- /dev/null +++ b/paddle/fluid/operators/gather_op_mlu.cc @@ -0,0 +1,75 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class GatherOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto axis = ctx.Attr("axis"); + + auto *out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc x_desc(*x); + MLUCnnlTensorDesc index_desc(*index); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(), + GetBasePtr(x), index_desc.get(), GetBasePtr(index), + out_desc.get(), GetBasePtr(out)); + } +}; + +template +class GatherGradOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *index = ctx.Input("Index"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc dx_desc(*dx); + auto value = static_cast(0); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(), + GetBasePtr(dx)); + + MLUCnnlTensorDesc index_desc(*index); + MLUCnnlTensorDesc dout_desc(*dout); + const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE; + MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(), + GetBasePtr(dout), index_desc.get(), + GetBasePtr(index), mode); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(gather, ops::GatherOpMLUKernel, + ops::GatherOpMLUKernel, + ops::GatherOpMLUKernel); + +REGISTER_OP_MLU_KERNEL(gather_grad, ops::GatherGradOpMLUKernel, + ops::GatherGradOpMLUKernel, + ops::GatherGradOpMLUKernel); diff --git a/paddle/fluid/operators/mean_op_mlu.cc b/paddle/fluid/operators/mean_op_mlu.cc index 1fed01194c1a6..1456e749b1343 100644 --- a/paddle/fluid/operators/mean_op_mlu.cc +++ b/paddle/fluid/operators/mean_op_mlu.cc @@ -95,7 +95,8 @@ class MeanMLUGradKernel : public framework::OpKernel { MLUCnnlTensorDesc mean_var_desc(mean_var, CNNL_LAYOUT_ARRAY, ToCnnlDataType(mean_var.dtype())); auto value = static_cast(1.0 / static_cast(input_grad->numel())); - MLUCnnl::Fill(context, value, mean_var_desc.get(), GetBasePtr(&mean_var)); + MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, mean_var_desc.get(), + GetBasePtr(&mean_var)); // means mul output_grad MLUCnnlTensorDesc in_desc(*output_grad, CNNL_LAYOUT_ARRAY, diff --git a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc index 1ce02ff4525c9..26c31d82e36eb 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc @@ -136,15 +136,17 @@ class AccuracyMLUKernel : public framework::OpKernel { // [total] total->mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc total_desc(*total); - MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples, total_desc.get(), + GetBasePtr(total)); // use `total` of type `float32` for calculating accuracy Tensor total_fp32(framework::TransToPhiDataType(VT::FP32)); total_fp32.Resize(total->dims()); total_fp32.mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc total_fp32_desc(total_fp32); - MLUCnnl::Fill(ctx, static_cast(num_samples), total_fp32_desc.get(), - GetBasePtr(&total_fp32)); + float num_samples_fp32 = static_cast(num_samples); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &num_samples_fp32, + total_fp32_desc.get(), GetBasePtr(&total_fp32)); // [accuracy] accuracy->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 1fdaa153e3c27..b09f7e33fd12f 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -208,8 +208,20 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { MLUCnnlActivationDesc::MLUCnnlActivationDesc( const cnnlActivationMode_t act_mode, const float ceof) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor( - active_desc_, act_mode, CNNL_NOT_PROPAGATE_NAN, ceof)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4( + active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, ceof, 1.0f /*sliced_dim*/, + 1.67326319217681884765625 /*selu_alpha*/, + 1.05070102214813232421875 /*selu_lambda*/)); +} + +MLUCnnlActivationDesc::MLUCnnlActivationDesc( + const cnnlActivationMode_t act_mode, const float ceof, + const float sliced_dim, const float selu_alpha, const float selu_lambda) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4( + active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, ceof, sliced_dim, selu_alpha, selu_lambda)); } const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const { @@ -541,12 +553,15 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { output_desc, output)); } -/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, float value, +/* static */ void MLUCnnl::Fill(const ExecutionContext& ctx, + const cnnlPointerMode_t pointer_mode, + const void* value_ptr, const cnnlTensorDescriptor_t output_desc, void* output) { cnnlHandle_t handle = GetHandleFromCTX(ctx); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlFill(handle, value, output_desc, output)); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlFill_v3(handle, pointer_mode, value_ptr, output_desc, output)); } /* static */ void MLUCnnl::QuantifyOffline( @@ -919,9 +934,8 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() { beta_ptr = static_cast(&beta_int); } - PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize_v2( - handle, op_tensor_desc, alpha1_ptr, a_desc, a, alpha2_ptr, b_desc, b, - beta_ptr, output_desc, output, &workspace_size)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize( + handle, a_desc, b_desc, output_desc, &workspace_size)); auto& dev_ctx = GetDevCtxFromCTX(ctx); Tensor workspace = ctx.AllocateTmpTensor( diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index b55b10686e92e..64a99b2a6d273 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -218,6 +218,9 @@ class MLUCnnlActivationDesc { MLUCnnlActivationDesc(const MLUCnnlActivationDesc& desc) = delete; MLUCnnlActivationDesc& operator=(const MLUCnnlActivationDesc& desc) = delete; MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof); + MLUCnnlActivationDesc(const cnnlActivationMode_t act_mode, const float ceof, + const float sliced_dim, const float selu_alpha, + const float selu_lambda); const cnnlActivationDescriptor_t get() const; ~MLUCnnlActivationDesc(); @@ -418,7 +421,8 @@ class MLUCnnl { const cnnlTensorDescriptor_t in1_desc, const void* in1, const cnnlTensorDescriptor_t output_desc, void* output); - static void Fill(const ExecutionContext& ctx, float value, + static void Fill(const ExecutionContext& ctx, + const cnnlPointerMode_t pointer_mode, const void* value_ptr, const cnnlTensorDescriptor_t output_desc, void* output); static void LRN(const ExecutionContext& ctx, const int local_size, diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc index e5399ee36ba7f..54ead6d3df7f0 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_mlu.cc @@ -69,7 +69,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { "the same Tensors.")); } - auto mu = ctx.Attr("mu"); + auto mu = static_cast(ctx.Attr("mu")); auto lrs = ctx.MultiInput("LearningRate"); if (lrs.size() != 1) { PADDLE_ENFORCE_EQ( @@ -114,14 +114,15 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { Tensor mu_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc mu_tensor_desc(mu_tensor); - MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(), + GetBasePtr(&mu_tensor)); for (size_t idx = 0; idx < n; ++idx) { - RegularizationType regularization_flag = + phi::RegularizationType regularization_flag = regularization_methods.size() > 0 && regularization_methods[idx] == "l2_decay" - ? RegularizationType::kL2DECAY - : RegularizationType::kNONE; + ? phi::RegularizationType::kL2DECAY + : phi::RegularizationType::kNONE; T regularization_coeff = static_cast(0.0); if (regularization_coeffs.size() != 0) { regularization_coeff = static_cast(regularization_coeffs[idx]); @@ -134,7 +135,7 @@ class MLUMergedMomentumOpKernel : public framework::OpKernel { auto grad = grads[idx]; Tensor regularized_grad; MLUCnnlTensorDesc param_desc(*param_out); - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad = ctx.AllocateTmpTensor( param_out->dims(), dev_ctx); MLUCnnlOpTensorDesc op_tensor_desc( diff --git a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc index 91e8aa643b981..b8fa81b2e7123 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_mlu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_mlu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/optimizers/momentum_op.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" namespace paddle { namespace operators { @@ -27,10 +28,10 @@ class MLUMomentumOpKernel : public framework::OpKernel { std::string regularization_method = ctx.Attr("regularization_method"); auto regularization_coeff = ctx.Attr("regularization_coeff"); - RegularizationType regularization_flag{ - RegularizationType::kNONE}; // disable regularization + phi::RegularizationType regularization_flag{ + phi::RegularizationType::kNONE}; // disable regularization if (regularization_method == "l2_decay") { - regularization_flag = RegularizationType::kL2DECAY; + regularization_flag = phi::RegularizationType::kL2DECAY; } T mu = static_cast(ctx.Attr("mu")); @@ -52,11 +53,12 @@ class MLUMomentumOpKernel : public framework::OpKernel { Tensor mu_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc mu_tensor_desc(mu_tensor); - MLUCnnl::Fill(ctx, mu, mu_tensor_desc.get(), GetBasePtr(&mu_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &mu, mu_tensor_desc.get(), + GetBasePtr(&mu_tensor)); Tensor regularized_grad; MLUCnnlTensorDesc param_desc(*param); - if (regularization_flag == RegularizationType::kL2DECAY) { + if (regularization_flag == phi::RegularizationType::kL2DECAY) { regularized_grad = ctx.AllocateTmpTensor(param->dims(), dev_ctx); MLUCnnlOpTensorDesc op_tensor_desc( diff --git a/paddle/fluid/operators/pool_op_mlu.cc b/paddle/fluid/operators/pool_op_mlu.cc index fa88d128a9a1d..c1bcf82c33256 100644 --- a/paddle/fluid/operators/pool_op_mlu.cc +++ b/paddle/fluid/operators/pool_op_mlu.cc @@ -116,11 +116,16 @@ class MLUPoolOpKernel : public framework::OpKernel { framework::Tensor extra_device_tensor = ctx.AllocateTmpTensor( {static_cast(extra_input_size)}, dev_ctx); - // TODO(fwg): use Async copy, and add a callback to stream that free - // host - // memory. - framework::TensorCopySync(extra_host_tensor, ctx.GetPlace(), - &extra_device_tensor); + framework::TensorCopy(extra_host_tensor, ctx.GetPlace(), + &extra_device_tensor); + // Increase extra_host_tensor holder_ reference count until copy + // complete. + auto increase_ref_count = [extra_host_tensor]() { + VLOG(4) << "Finished copying extra_host_tensor[" + << GetBasePtr(&extra_host_tensor) + << "] in mlu pooling kernel."; + }; + dev_ctx.AddStreamCallback(increase_ref_count); MLUCnnl::PoolingForward( ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/, in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/, diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc index 45f4e43378f44..89e578dbdb6b7 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op_mlu.cc @@ -103,8 +103,8 @@ class ReduceMeanGradMLUKernel : public framework::OpKernel { ToCnnlDataType(input_grad->dtype())); auto value = static_cast(1.0 / static_cast(reduce_numel)); - MLUCnnl::Fill(context, value, input_grad_desc.get(), - GetBasePtr(input_grad)); + MLUCnnl::Fill(context, CNNL_POINTER_MODE_HOST, &value, + input_grad_desc.get(), GetBasePtr(input_grad)); MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN); diff --git a/paddle/fluid/operators/scale_op_mlu.cc b/paddle/fluid/operators/scale_op_mlu.cc index 5237e70e319ad..f9e313e64b1e1 100644 --- a/paddle/fluid/operators/scale_op_mlu.cc +++ b/paddle/fluid/operators/scale_op_mlu.cc @@ -27,7 +27,7 @@ class ScaleMLUKernel : public framework::OpKernel { auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var); // cnnl require input, scale, bias with same type. And all in device side. - auto& scale = ctx.Attr("scale"); + auto scale = static_cast(ctx.Attr("scale")); framework::Tensor scale_tensor; if (ctx.HasInput("ScaleTensor")) { framework::Tensor float_scale_tensor = @@ -49,14 +49,16 @@ class ScaleMLUKernel : public framework::OpKernel { } else { scale_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc scale_desc(scale_tensor); - MLUCnnl::Fill(ctx, scale, scale_desc.get(), GetBasePtr(&scale_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &scale, scale_desc.get(), + GetBasePtr(&scale_tensor)); } - auto& bias = ctx.Attr("bias"); + auto bias = static_cast(ctx.Attr("bias")); framework::Tensor bias_tensor = ctx.AllocateTmpTensor({1}, dev_ctx); MLUCnnlTensorDesc bias_desc(bias_tensor); - MLUCnnl::Fill(ctx, bias, bias_desc.get(), GetBasePtr(&bias_tensor)); + MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &bias, bias_desc.get(), + GetBasePtr(&bias_tensor)); auto* out_var = ctx.OutputVar("Out"); if (in_var->IsType() && in_var != out_var) { diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h index fcf06cb4f1c40..12c206ef2c445 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.h +++ b/paddle/fluid/platform/device/mlu/mlu_info.h @@ -16,7 +16,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_MLU #include +#include #include +#include #include #ifdef PADDLE_WITH_CNCL #include @@ -33,7 +35,7 @@ using cnclStatus = cnclResult_t; #endif using mluStream = cnrtQueue_t; using mluCnnlHandle = cnnlHandle_t; -using mluEventHandle = CNnotifier; +using mluEventHandle = cnrtNotifier_t; using mluDeviceHandle = CNdev; namespace platform { diff --git a/paddle/fluid/platform/device/mlu/mlu_stream.h b/paddle/fluid/platform/device/mlu/mlu_stream.h index 3f4b27e370f2e..b20949f3bfe85 100644 --- a/paddle/fluid/platform/device/mlu/mlu_stream.h +++ b/paddle/fluid/platform/device/mlu/mlu_stream.h @@ -40,7 +40,6 @@ class MLUStream final { template void AddCallback(Callback&& callback) const { - // TODO(mlu): mlu not support AddCallback callback_manager_->AddCallback(callback); } diff --git a/paddle/fluid/platform/profiler/CMakeLists.txt b/paddle/fluid/platform/profiler/CMakeLists.txt index c903a52530ccb..084bc44dbc78b 100755 --- a/paddle/fluid/platform/profiler/CMakeLists.txt +++ b/paddle/fluid/platform/profiler/CMakeLists.txt @@ -1,12 +1,13 @@ cc_library(host_tracer SRCS host_tracer.cc DEPS enforce) cc_library(cuda_tracer SRCS cuda_tracer.cc cupti_data_process.cc DEPS workqueue_utils enforce glog) +add_subdirectory(mlu) cc_library(event_node SRCS event_node.cc DEPS enforce) cc_library(profiler_utils SRCS utils.cc DEPS enforce glog) add_subdirectory(dump) cc_library(profiler_logger SRCS chrometracing_logger.cc dump/serialization_logger.cc dump/deserialization_reader.cc DEPS nodetreeproto event_node profiler_utils) cc_library(event_bind SRCS event_python.cc DEPS profiler_logger) cc_library(cpu_utilization SRCS cpu_utilization.cc DEPS cpu_info os_info enforce glog) -cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind) +cc_library(new_profiler SRCS profiler.cc DEPS host_tracer cuda_tracer profiler_utils cpu_utilization event_bind mlu_tracer) cc_test(test_event_node SRCS test_event_node.cc DEPS event_node profiler_logger) cc_test(test_extra_info SRCS test_extra_info.cc DEPS profiler_utils) cc_test(test_serialization_logger SRCS dump/test_serialization_logger.cc DEPS event_bind) diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.cc b/paddle/fluid/platform/profiler/chrometracing_logger.cc index d7879e7be517e..4ee95a530fb43 100644 --- a/paddle/fluid/platform/profiler/chrometracing_logger.cc +++ b/paddle/fluid/platform/profiler/chrometracing_logger.cc @@ -38,10 +38,12 @@ static std::string DefaultFileName() { } const char* ChromeTracingLogger::categary_name_[] = { - "Operator", "Dataloader", "ProfileStep", "CudaRuntime", - "Kernel", "Memcpy", "Memset", "UserDefined", - "OperatorInner", "Forward", "Backward", "Optimization", - "Communication", "PythonOp", "PythonUserDefined"}; + "Operator", "Dataloader", "ProfileStep", + "CudaRuntime", "Kernel", "Memcpy", + "Memset", "UserDefined", "OperatorInner", + "Forward", "Backward", "Optimization", + "Communication", "PythonOp", "PythonUserDefined", + "MluRuntime"}; void ChromeTracingLogger::OpenFile() { output_file_stream_.open(filename_, @@ -598,6 +600,12 @@ void ChromeTracingLogger::RefineDisplayName( (*it).second * 2, (*it).first, (*it).second, (*it).second * 2 + 1); } +#ifdef PADDLE_WITH_MLU + static std::string device_type("MLU"); +#else + static std::string device_type("GPU"); +#endif + for (auto it = deviceid_streamid_set_.begin(); it != deviceid_streamid_set_.end(); ++it) { output_file_stream_ << string_format( @@ -607,7 +615,7 @@ void ChromeTracingLogger::RefineDisplayName( "name": "process_name", "pid": %lld, "tid": %lld, "ph": "M", "args": { - "name": "Deivce %lld (GPU)" + "name": "Deivce %lld (%s)" } }, { @@ -632,9 +640,9 @@ void ChromeTracingLogger::RefineDisplayName( } }, )JSON"), - (*it).first, (*it).second, (*it).first, (*it).first, (*it).second, - (*it).second, (*it).first, (*it).second, (*it).first + 0x10000000, - (*it).first, (*it).second, (*it).second); + (*it).first, (*it).second, (*it).first, device_type.c_str(), + (*it).first, (*it).second, (*it).second, (*it).first, (*it).second, + (*it).first + 0x10000000, (*it).first, (*it).second, (*it).second); } } diff --git a/paddle/fluid/platform/profiler/mlu/CMakeLists.txt b/paddle/fluid/platform/profiler/mlu/CMakeLists.txt new file mode 100644 index 0000000000000..01b3757ea6912 --- /dev/null +++ b/paddle/fluid/platform/profiler/mlu/CMakeLists.txt @@ -0,0 +1,5 @@ +if(WITH_MLU) + set(MLU_INFO mlu_info) +endif() + +cc_library(mlu_tracer SRCS mlu_tracer.cc cnpapi_data_process.cc DEPS workqueue_utils enforce glog ${MLU_INFO}) diff --git a/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.cc b/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.cc new file mode 100644 index 0000000000000..36abf77279d06 --- /dev/null +++ b/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.cc @@ -0,0 +1,264 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h" +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/os_info.h" + +#ifdef PADDLE_WITH_MLU +namespace paddle { +namespace platform { + +namespace { + +inline uint64_t GetTimeGap() { + static uint64_t time_gap = []() -> uint64_t { + uint64_t cpu_time = PosixInNsec(); + uint64_t mlu_time = cnpapiGetTimestamp(); + return (cpu_time - mlu_time); + }(); + return time_gap; +} + +void AddKernelRecord(const cnpapiActivityKernel* kernel, uint64_t start_ns, + TraceEventCollector* collector) { + static uint64_t time_gap = GetTimeGap(); + if (kernel->start + time_gap < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = demangle(kernel->name); + event.type = TracerEventType::Kernel; + event.start_ns = kernel->start + time_gap; + event.end_ns = kernel->end + time_gap; + event.device_id = kernel->device_id; + event.context_id = kernel->context_id; + event.stream_id = kernel->queue_id; + event.correlation_id = kernel->correlation_id; + event.kernel_info.block_x = kernel->dimx; + event.kernel_info.block_y = kernel->dimy; + event.kernel_info.block_z = kernel->dimz; + event.kernel_info.grid_x = kernel->kernel_type; + event.kernel_info.grid_y = 0; + event.kernel_info.grid_z = 0; + event.kernel_info.queued = kernel->queued; + event.kernel_info.submitted = kernel->submitted; + event.kernel_info.completed = kernel->received; + collector->AddDeviceEvent(std::move(event)); +} + +const char* MemcpyKind(cnpapiActivityMemcpyType kind) { + switch (kind) { + case CNPAPI_ACTIVITY_MEMCPY_TYPE_HTOD: + return "MEMCPY_HtoD"; + case CNPAPI_ACTIVITY_MEMCPY_TYPE_DTOH: + return "MEMCPY_DtoH"; + case CNPAPI_ACTIVITY_MEMCPY_TYPE_DTOD: + return "MEMCPY_DtoD"; + case CNPAPI_ACTIVITY_MEMCPY_TYPE_HTOH: + return "MEMCPY_HtoH"; + case CNPAPI_ACTIVITY_MEMCPY_TYPE_PTOP: + return "MEMCPY_PtoP"; + default: + break; + } + return "MEMCPY"; +} + +void AddMemcpyRecord(const cnpapiActivityMemcpy* memcpy, uint64_t start_ns, + TraceEventCollector* collector) { + static uint64_t time_gap = GetTimeGap(); + if (memcpy->start + time_gap < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = MemcpyKind(memcpy->copy_type); + event.type = TracerEventType::Memcpy; + event.start_ns = memcpy->start + time_gap; + event.end_ns = memcpy->end + time_gap; + event.device_id = memcpy->device_id; + event.context_id = memcpy->context_id; + event.stream_id = memcpy->queue_id; + event.correlation_id = memcpy->correlation_id; + event.memcpy_info.num_bytes = memcpy->bytes; + snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy->copy_type)); + collector->AddDeviceEvent(std::move(event)); +} + +void AddMemcpy2Record(const cnpapiActivityMemcpyPtoP* memcpy2, + uint64_t start_ns, TraceEventCollector* collector) { + static uint64_t time_gap = GetTimeGap(); + if (memcpy2->start + time_gap < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = MemcpyKind(memcpy2->copy_type); + event.type = TracerEventType::Memcpy; + event.start_ns = memcpy2->start + time_gap; + event.end_ns = memcpy2->end + time_gap; + event.device_id = memcpy2->device_id; + event.context_id = memcpy2->context_id; + event.stream_id = memcpy2->queue_id; + event.correlation_id = memcpy2->correlation_id; + event.memcpy_info.num_bytes = memcpy2->bytes; + snprintf(event.memcpy_info.copy_kind, kMemKindMaxLen, "%s", + MemcpyKind(memcpy2->copy_type)); + collector->AddDeviceEvent(std::move(event)); +} + +void AddMemsetRecord(const cnpapiActivityMemset* memset, uint64_t start_ns, + TraceEventCollector* collector) { + static uint64_t time_gap = GetTimeGap(); + if (memset->start + time_gap < start_ns) { + return; + } + DeviceTraceEvent event; + event.name = "MEMSET"; + event.type = TracerEventType::Memset; + event.start_ns = memset->start + time_gap; + event.end_ns = memset->end + time_gap; + event.device_id = memset->device_id; + event.context_id = memset->context_id; + event.stream_id = memset->queue_id; + event.correlation_id = memset->correlation_id; + event.memset_info.num_bytes = memset->bytes; + event.memset_info.value = memset->value; + collector->AddDeviceEvent(std::move(event)); +} + +class CnpapiRuntimeCbidStr { + public: + static const CnpapiRuntimeCbidStr& GetInstance() { + static CnpapiRuntimeCbidStr inst; + return inst; + } + + std::string RuntimeKind(cnpapi_CallbackId cbid) const { + auto iter = cbid_str_.find(cbid); + if (iter == cbid_str_.end()) { + return "MLU Runtime API " + std::to_string(cbid); + } + return iter->second; + } + + private: + CnpapiRuntimeCbidStr(); + + std::unordered_map cbid_str_; +}; + +CnpapiRuntimeCbidStr::CnpapiRuntimeCbidStr() { +#define REGISTER_RUNTIME_CBID_STR(cbid) \ + cbid_str_[CNPAPI_CNDRV_TRACE_CBID_##cbid] = #cbid + + REGISTER_RUNTIME_CBID_STR(cnMalloc); + REGISTER_RUNTIME_CBID_STR(cnMallocHost); + REGISTER_RUNTIME_CBID_STR(cnFree); + REGISTER_RUNTIME_CBID_STR(cnFreeHost); + REGISTER_RUNTIME_CBID_STR(cnMemcpy); + REGISTER_RUNTIME_CBID_STR(cnMemcpyPeer); + REGISTER_RUNTIME_CBID_STR(cnMemcpyHtoD); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoH); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD); + REGISTER_RUNTIME_CBID_STR(cnMemcpyAsync); + REGISTER_RUNTIME_CBID_STR(cnMemcpyHtoDAsync); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoHAsync); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoDAsync); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD2D); + REGISTER_RUNTIME_CBID_STR(cnMemcpyDtoD3D); + REGISTER_RUNTIME_CBID_STR(cnMemcpy2D); + REGISTER_RUNTIME_CBID_STR(cnMemcpy3D); + REGISTER_RUNTIME_CBID_STR(cnMemsetD8); + REGISTER_RUNTIME_CBID_STR(cnMemsetD16); + REGISTER_RUNTIME_CBID_STR(cnMemsetD32); + REGISTER_RUNTIME_CBID_STR(cnMemsetD8Async); + REGISTER_RUNTIME_CBID_STR(cnMemsetD16Async); + REGISTER_RUNTIME_CBID_STR(cnMemsetD32Async); + REGISTER_RUNTIME_CBID_STR(cnInvokeKernel); + REGISTER_RUNTIME_CBID_STR(cnCreateQueue); + REGISTER_RUNTIME_CBID_STR(cnDestroyQueue); + REGISTER_RUNTIME_CBID_STR(cnQueueSync); + REGISTER_RUNTIME_CBID_STR(cnQueueWaitNotifier); + REGISTER_RUNTIME_CBID_STR(cnWaitNotifier); + REGISTER_RUNTIME_CBID_STR(cnCreateNotifier); + REGISTER_RUNTIME_CBID_STR(cnDestroyNotifier); + REGISTER_RUNTIME_CBID_STR(cnPlaceNotifier); + REGISTER_RUNTIME_CBID_STR(cnCtxCreate); + REGISTER_RUNTIME_CBID_STR(cnCtxDestroy); + REGISTER_RUNTIME_CBID_STR(cnCtxGetCurrent); + REGISTER_RUNTIME_CBID_STR(cnCtxSetCurrent); + REGISTER_RUNTIME_CBID_STR(cnCtxGetDevice); + REGISTER_RUNTIME_CBID_STR(cnCtxSync); + REGISTER_RUNTIME_CBID_STR(cnInvokeHostFunc); +#undef REGISTER_RUNTIME_CBID_STR +} + +void AddApiRecord(const cnpapiActivityAPI* api, uint64_t start_ns, + TraceEventCollector* collector) { + static uint64_t time_gap = GetTimeGap(); + if (api->start + time_gap < start_ns) { + return; + } + RuntimeTraceEvent event; + event.name = CnpapiRuntimeCbidStr::GetInstance().RuntimeKind(api->cbid); + event.start_ns = api->start + time_gap; + event.end_ns = api->end + time_gap; + event.process_id = api->process_id; + event.thread_id = api->thread_id; + event.correlation_id = api->correlation_id; + event.callback_id = api->cbid; + event.type = TracerEventType::MluRuntime; + collector->AddRuntimeEvent(std::move(event)); +} + +} // namespace + +namespace details { + +void ProcessCnpapiActivityRecord(const cnpapiActivity* record, + uint64_t start_ns, + TraceEventCollector* collector) { + switch (record->type) { + case CNPAPI_ACTIVITY_TYPE_KERNEL: + AddKernelRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CNPAPI_ACTIVITY_TYPE_MEMCPY: + AddMemcpyRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP: + AddMemcpy2Record( + reinterpret_cast(record), start_ns, + collector); + break; + case CNPAPI_ACTIVITY_TYPE_MEMSET: + AddMemsetRecord(reinterpret_cast(record), + start_ns, collector); + break; + case CNPAPI_ACTIVITY_TYPE_CNDRV_API: + AddApiRecord(reinterpret_cast(record), start_ns, + collector); + break; + default: + break; + } +} + +} // namespace details +} // namespace platform +} // namespace paddle +#endif diff --git a/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h b/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h new file mode 100644 index 0000000000000..1f00b46d2c2ae --- /dev/null +++ b/paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif +#include "paddle/fluid/platform/profiler/trace_event_collector.h" + +namespace paddle { +namespace platform { +namespace details { + +#ifdef PADDLE_WITH_MLU +void ProcessCnpapiActivityRecord(const cnpapiActivity* record, + uint64_t start_ns, + TraceEventCollector* collector); +#endif + +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/mlu/mlu_tracer.cc b/paddle/fluid/platform/profiler/mlu/mlu_tracer.cc new file mode 100644 index 0000000000000..2d719a8bbfdcb --- /dev/null +++ b/paddle/fluid/platform/profiler/mlu/mlu_tracer.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/framework/new_executor/workqueue/workqueue_utils.h" +#include "paddle/fluid/platform/os_info.h" +#include "paddle/fluid/platform/profiler/mlu/cnpapi_data_process.h" + +#define CNPAPI_CALL(call) \ + do { \ + cnpapiResult _status = call; \ + if (_status != CNPAPI_SUCCESS) { \ + const char* errstr; \ + cnpapiGetResultString(_status, &errstr); \ + LOG(ERROR) << "Function " << #call << " failed with error " << errstr; \ + } \ + } while (0) + +namespace paddle { +namespace platform { + +namespace { + +void BufferRequestedCallback(uint64_t** buffer, size_t* size, + size_t* max_num_records) { + constexpr size_t kBufferSize = 1 << 23; // 8 MB + constexpr size_t kBufferAlignSize = 8; + *buffer = reinterpret_cast( + paddle::framework::AlignedMalloc(kBufferSize, kBufferAlignSize)); + *size = kBufferSize; + *max_num_records = 0; +} + +void BufferCompletedCallback(uint64_t* buffer, size_t size, size_t valid_size) { + if (buffer == nullptr || valid_size == 0) { + return; + } + auto mlu_tracer = &MluTracer::GetInstance(); + mlu_tracer->ProcessCnpapiActivity(buffer, valid_size); + + paddle::framework::AlignedFree(buffer); +} + +} // namespace + +MluTracer::MluTracer() { +#ifdef PADDLE_WITH_MLU + CNPAPI_CALL(cnpapiInit()); + CNPAPI_CALL(cnpapiActivityRegisterCallbacks(BufferRequestedCallback, + BufferCompletedCallback)); +#endif +} + +void MluTracer::PrepareTracing() { + PADDLE_ENFORCE_EQ( + state_ == TracerState::UNINITED || state_ == TracerState::STOPED, true, + platform::errors::PreconditionNotMet("MluTracer must be UNINITED")); + EnableCnpapiActivity(); + state_ = TracerState::READY; +} + +void MluTracer::StartTracing() { + PADDLE_ENFORCE_EQ(state_ == TracerState::READY, true, + platform::errors::PreconditionNotMet( + "MluTracer must be READY or STOPPED")); + tracing_start_ns_ = PosixInNsec(); + state_ = TracerState::STARTED; +} + +void MluTracer::StopTracing() { + PADDLE_ENFORCE_EQ( + state_, TracerState::STARTED, + platform::errors::PreconditionNotMet("MluTracer must be STARTED")); + DisableCnpapiActivity(); + state_ = TracerState::STOPED; +} + +void MluTracer::CollectTraceData(TraceEventCollector* collector) { + PADDLE_ENFORCE_EQ( + state_, TracerState::STOPED, + platform::errors::PreconditionNotMet("MluTracer must be STOPED")); + for (auto he : collector_.HostEvents()) { + collector->AddHostEvent(std::move(he)); + } + for (auto rte : collector_.RuntimeEvents()) { + collector->AddRuntimeEvent(std::move(rte)); + } + for (auto de : collector_.DeviceEvents()) { + collector->AddDeviceEvent(std::move(de)); + } + for (auto tn : collector_.ThreadNames()) { + collector->AddThreadName(tn.first, tn.second); + } + collector_.ClearAll(); +} + +void MluTracer::ProcessCnpapiActivity(uint64_t* buffer, size_t valid_size) { +#ifdef PADDLE_WITH_MLU + cnpapiActivity* record = nullptr; + while (true) { + cnpapiResult status = + cnpapiActivityGetNextRecord(buffer, valid_size, &record); + if (status == CNPAPI_SUCCESS) { + details::ProcessCnpapiActivityRecord(record, tracing_start_ns_, + &collector_); + } else if (status == CNPAPI_ERROR_INSUFFICIENT_MEMORY || + status == CNPAPI_ERROR_MAX_LIMIT_REACHED) { + break; + } else { + CNPAPI_CALL(status); + } + } +#endif +} + +void MluTracer::EnableCnpapiActivity() { +#ifdef PADDLE_WITH_MLU + CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_KERNEL)); + CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMCPY)); + CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP)); + CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_MEMSET)); + CNPAPI_CALL(cnpapiActivityEnable(CNPAPI_ACTIVITY_TYPE_CNDRV_API)); + VLOG(3) << "enable cnpapi activity"; +#endif +} + +void MluTracer::DisableCnpapiActivity() { +#ifdef PADDLE_WITH_MLU + CNPAPI_CALL(cnpapiActivityFlushAll()); + CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_KERNEL)); + CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMCPY)); + CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMCPY_PTOP)); + CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_MEMSET)); + CNPAPI_CALL(cnpapiActivityDisable(CNPAPI_ACTIVITY_TYPE_CNDRV_API)); + VLOG(3) << "disable cnpapi activity"; +#endif +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/mlu/mlu_tracer.h b/paddle/fluid/platform/profiler/mlu/mlu_tracer.h new file mode 100644 index 0000000000000..43c712b13ae2c --- /dev/null +++ b/paddle/fluid/platform/profiler/mlu/mlu_tracer.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif +#include "paddle/fluid/platform/macros.h" +#include "paddle/fluid/platform/profiler/tracer_base.h" + +namespace paddle { +namespace platform { + +class MluTracer : public TracerBase { + public: + static MluTracer& GetInstance() { + static MluTracer instance; + return instance; + } + + void PrepareTracing() override; + + void StartTracing() override; + + void StopTracing() override; + + void CollectTraceData(TraceEventCollector* collector) override; + + void ProcessCnpapiActivity(uint64_t* buffer, size_t valid_size); + + private: + MluTracer(); + + DISABLE_COPY_AND_ASSIGN(MluTracer); + + void EnableCnpapiActivity(); + + void DisableCnpapiActivity(); + + uint64_t tracing_start_ns_ = UINT64_MAX; + + TraceEventCollector collector_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index ac46fbed10a20..a417eda1509e5 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -27,6 +27,7 @@ #include "paddle/fluid/platform/profiler/cuda_tracer.h" #include "paddle/fluid/platform/profiler/extra_info.h" #include "paddle/fluid/platform/profiler/host_tracer.h" +#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" #include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/fluid/platform/profiler/utils.h" @@ -52,6 +53,14 @@ bool Profiler::IsCuptiSupported() { return supported; } +bool Profiler::IsCnpapiSupported() { + bool supported = false; +#ifdef PADDLE_WITH_MLU + supported = true; +#endif + return supported; +} + Profiler::Profiler(const ProfilerOptions& options) { options_ = options; std::bitset<32> trace_switch(options_.trace_switch); @@ -63,6 +72,9 @@ Profiler::Profiler(const ProfilerOptions& options) { if (trace_switch.test(kProfileGPUOptionBit)) { tracers_.emplace_back(&CudaTracer::GetInstance(), false); } + if (trace_switch.test(kProfileMLUOptionBit)) { + tracers_.emplace_back(&MluTracer::GetInstance(), false); + } } Profiler::~Profiler() { alive_.store(false); } diff --git a/paddle/fluid/platform/profiler/profiler.h b/paddle/fluid/platform/profiler/profiler.h index d24ee504bc640..ea346a4fb748d 100644 --- a/paddle/fluid/platform/profiler/profiler.h +++ b/paddle/fluid/platform/profiler/profiler.h @@ -33,9 +33,10 @@ namespace platform { static constexpr uint32_t kProfileCPUOptionBit = 0; static constexpr uint32_t kProfileGPUOptionBit = 1; +static constexpr uint32_t kProfileMLUOptionBit = 2; struct ProfilerOptions { - uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu + uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu, bit 2: mlu uint32_t trace_level = FLAGS_host_trace_level; }; @@ -45,6 +46,8 @@ class Profiler { static bool IsCuptiSupported(); + static bool IsCnpapiSupported(); + void Prepare(); void Start(); diff --git a/paddle/fluid/platform/profiler/trace_event.h b/paddle/fluid/platform/profiler/trace_event.h index 16ef62fb51555..6d398a26eda10 100644 --- a/paddle/fluid/platform/profiler/trace_event.h +++ b/paddle/fluid/platform/profiler/trace_event.h @@ -50,6 +50,8 @@ enum class TracerEventType { PythonOp = 13, // Used to mark python level userdefined PythonUserDefined = 14, + // Used to mark mlu runtime record returned by cnpapi + MluRuntime = 15, // A flag to denote the number of current types NumTypes }; diff --git a/paddle/fluid/platform/profiler/trace_event_collector.h b/paddle/fluid/platform/profiler/trace_event_collector.h index cc85a178d14e5..5f2bc9dc90db9 100644 --- a/paddle/fluid/platform/profiler/trace_event_collector.h +++ b/paddle/fluid/platform/profiler/trace_event_collector.h @@ -52,6 +52,13 @@ class TraceEventCollector { return thread_names_; } + void ClearAll() { + thread_names_.clear(); + host_events_.clear(); + runtime_events_.clear(); + device_events_.clear(); + } + private: std::unordered_map thread_names_; std::list host_events_; diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index c0b7fd417f272..709cbd401b4cb 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -34,6 +34,10 @@ limitations under the License. */ #ifdef PADDLE_WITH_HIP #include #endif +#ifdef PADDLE_WITH_MLU +#include "paddle/fluid/platform/device/mlu/enforce.h" +#include "paddle/fluid/platform/device/mlu/mlu_info.h" +#endif namespace paddle { namespace platform { @@ -132,6 +136,13 @@ void SynchronizeAllDevice() { PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); } #endif +#ifdef PADDLE_WITH_MLU + int count = GetMLUDeviceCount(); + for (int i = 0; i < count; i++) { + SetMLUDeviceId(i); + PADDLE_ENFORCE_MLU_SUCCESS(cnrtSyncDevice()); + } +#endif } // Print results diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc index 7148afee273fd..6fa326d57bc67 100644 --- a/paddle/fluid/platform/stream_callback_manager.cc +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -80,10 +80,8 @@ void StreamCallbackManager::AddCallback( #endif #if PADDLE_WITH_MLU - VLOG(3) << "MLULaunchCallback at stream: " << stream_ - << " Failed to call MLULaunchCallback, " - << "because mlu not support StreamAddCallback yet. " - << "function: " << func; + VLOG(3) << "MLULaunchCallback at stream: " << stream_; + cnrtInvokeHostFunc(stream_, StreamCallbackFunc, func); #endif } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 51a863d4922fe..32a101106534c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3342,6 +3342,8 @@ All parameter, weight, gradient are variables in Paddle. .def("create", &paddle::platform::Profiler::Create, py::return_value_policy::take_ownership) .def("is_cupti_supported", &paddle::platform::Profiler::IsCuptiSupported) + .def("is_cnpapi_supported", + &paddle::platform::Profiler::IsCnpapiSupported) .def("prepare", [](paddle::platform::Profiler *profiler) { platform::EnableHostEventRecorder(); diff --git a/paddle/phi/common/backend.h b/paddle/phi/common/backend.h index 5543bee144b3b..48b9db113b728 100644 --- a/paddle/phi/common/backend.h +++ b/paddle/phi/common/backend.h @@ -47,6 +47,7 @@ enum class Backend : uint8_t { GPU, XPU, // XPU currently does not exist at the same time as CUDA NPU, // NPU currently does not exist at the same time as CUDA + MLU, // MLU currently does not exist at the same time as CUDA // the third library backend MKLDNN, @@ -114,6 +115,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { case Backend::NPU: os << "NPU"; break; + case Backend::MLU: + os << "MLU"; + break; case Backend::MKLDNN: os << "MKLDNN"; break; @@ -154,6 +158,8 @@ inline Backend StringToBackend(const char* backend_cstr) { return Backend::XPU; } else if (s == std::string("NPU")) { return Backend::NPU; + } else if (s == std::string("MLU")) { + return Backend::MLU; } else if (s == std::string("MKLDNN")) { return Backend::MKLDNN; } else if (s == std::string("GPUDNN")) { diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc index 9dd3bbd59a19b..99899b494e2d5 100644 --- a/paddle/phi/core/compat/convert_utils.cc +++ b/paddle/phi/core/compat/convert_utils.cc @@ -41,6 +41,8 @@ Backend TransToPhiBackend(const phi::Place& place) { return Backend::NPU; case AllocationType::IPU: return Backend::IPU; + case AllocationType::MLU: + return Backend::MLU; case AllocationType::CUSTOM: return static_cast( static_cast(Backend::NUM_BACKENDS) + diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index e9b04a183fdc0..83e77b5ecf60c 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -273,7 +273,8 @@ def backward(self, grad_tensor=None, retain_graph=False): if _grad_scalar: # When using amp with Fleet DistributedStrategy, we do loss scaling implicitly. self = _grad_scalar.scale(self) - if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu(): + if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu( + ) or paddle.is_compiled_with_mlu(): # TODO(liuyuhui): Currently only for xpu. Will be removed in the future. scaled_loss = scale_loss(self) if framework._in_eager_mode_: diff --git a/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py new file mode 100644 index 0000000000000..f0aff986fa1ff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_gather_op_mlu.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append('..') +from op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.fluid as fluid +from paddle.framework import core +from paddle.fluid.dygraph.base import switch_to_static_graph + +paddle.enable_static() + + +def gather_numpy(x, index, axis): + x_transpose = np.swapaxes(x, 0, axis) + tmp_gather = x_transpose[index, ...] + gather = np.swapaxes(tmp_gather, 0, axis) + return gather + + +class TestGatherOp(OpTest): + def setUp(self): + self.op_type = "gather" + self.place = paddle.MLUPlace(0) + self.__class__.use_mlu = True + self.python_api = paddle.gather + self.config() + xnp = np.random.random(self.x_shape).astype(self.x_type) + self.inputs = { + 'X': xnp, + 'Index': np.array(self.index).astype(self.index_type) + } + self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def config(self): + """ + For multi-dimension input + """ + self.x_shape = (10, 20) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase1(TestGatherOp): + def config(self): + """ + For one dimension input + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int32" + + +class TestCase2(TestGatherOp): + def config(self): + """ + For int64_t index type + """ + self.x_shape = (100) + self.x_type = "float32" + self.index = [1, 3, 5] + self.index_type = "int64" + + +class API_TestDygraphGather(unittest.TestCase): + def test_out1(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32') + index_1 = np.array([1, 2]) + input = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.fluid.layers.gather(input, index) + output_np = output.numpy() + expected_output = np.array([[3, 4], [5, 6]]).astype('int32') + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + def test_out12(self): + paddle.disable_static() + input_1 = np.array([[1, 2], [3, 4], [5, 6]]).astype('int32') + index_1 = np.array([1, 2]) + x = paddle.to_tensor(input_1) + index = paddle.to_tensor(index_1) + output = paddle.gather(x, index, axis=0) + output_np = output.numpy() + expected_output = gather_numpy(input_1, index_1, axis=0) + self.assertTrue(np.allclose(output_np, expected_output)) + paddle.enable_static() + + def test_zero_index(self): + paddle.disable_static() + x = paddle.to_tensor([[1, 2], [3, 4]]).astype('int32') + index = paddle.to_tensor(np.array([]).astype('int64')) + for axis in range(len(x.shape)): + out = paddle.gather(x, index, axis) + expected_shape = list(x.shape) + expected_shape[axis] = 0 + self.assertEqual(list(out.shape), expected_shape) + paddle.enable_static() + + +class TestGathertError(unittest.TestCase): + def test_error1(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + + shape = [8, 9, 6] + x = paddle.fluid.data(shape=shape, dtype='int8', name='x') + axis = paddle.fluid.data(shape=[1], dtype='float32', name='axis') + index = paddle.fluid.data(shape=shape, dtype='int32', name='index') + index_float = paddle.fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + def test_axis_dtype(): + paddle.gather(x, index, axis=1.11) + + self.assertRaises(TypeError, test_axis_dtype) + + def test_axis_dtype1(): + paddle.gather(x, index, axis=axis) + + self.assertRaises(TypeError, test_axis_dtype1) + + def test_error2(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + + shape = [8, 9, 6] + x = fluid.data(shape=shape, dtype='int8', name='x') + index = fluid.data(shape=shape, dtype='int32', name='mask') + index_float = fluid.data( + shape=shape, dtype='float32', name='index_float') + + def test_x_type(): + paddle.fluid.layers.gather(x, index) + + self.assertRaises(TypeError, test_x_type) + + def test_index_type(): + paddle.fluid.layers.gather(x, index_float) + + self.assertRaises(TypeError, test_index_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py new file mode 100644 index 0000000000000..c62d30d43c089 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_gelu_op_mlu.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +from scipy import special +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +def np_gelu(x): + y = 0.5 * x * (1 + special.erf(x / np.sqrt(2))) + return y + + +class TestGelu(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "gelu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np_gelu(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.007) + + +class TestGeluFp16(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "gelu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) + out = np_gelu(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestGeluNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + c = paddle.multiply(a, b) + + fc_1 = fluid.layers.fc(input=c, size=128) + fc_1_gelu = fluid.layers.gelu(fc_1) + prediction = fluid.layers.fc(input=fc_1_gelu, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred, atol=1e-3)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss, atol=1e-3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py new file mode 100644 index 0000000000000..ec2150fceb133 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_leaky_relu_op_mlu.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +from test_activation_op import ref_leaky_relu +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestLeadyRelu(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "leaky_relu" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + + self.set_inputs() + self.set_attrs() + self.set_outputs() + + def set_inputs(self): + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + + def set_attrs(self): + self.attrs = {} + + def set_outputs(self): + alpha = 0.02 if 'alpha' not in self.attrs else self.attrs['alpha'] + out = ref_leaky_relu(self.inputs['X'], alpha) + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.006) + else: + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestLeadyReluFP16(TestLeadyRelu): + def init_dtype(self): + self.dtype = np.float16 + + +class TestLeadyRelu2(TestLeadyRelu): + def set_attrs(self): + self.attrs = {'alpha': 0.5} + + +class TestLeadyRelu3(TestLeadyRelu): + def set_attrs(self): + self.attrs = {'alpha': -0.5} + + +class TestLeakyReluNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + x_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + x = paddle.static.data(name="x", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + y = paddle.nn.functional.leaky_relu(x) + + fc_1 = fluid.layers.fc(input=y, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run(main_prog, + feed={"x": x_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py new file mode 100644 index 0000000000000..54b1afd036331 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_relu6_op_mlu.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle.fluid as fluid +import paddle +from op_test import OpTest + +import numpy as np +import unittest +import sys +sys.path.append("..") + +paddle.enable_static() +SEED = 2021 + + +def ref_relu6(x, threshold=6.0): + out = np.copy(x) + out[np.abs(x - threshold) < 0.005] = threshold + 0.02 + out = np.minimum(np.maximum(x, 0), threshold) + return out + + +class TestRelu6(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "relu6" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(-1, 10, [10, 12]).astype(self.dtype) + x[np.abs(x) < 0.005] = 0.02 + out = ref_relu6(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': 6.0} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + def init_dtype(self): + self.dtype = np.float32 + + +class TestRelu6Float16(TestRelu6): + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def set_attrs(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestReluNeg(TestRelu6): + def setUp(self): + self.set_mlu() + self.op_type = "relu6" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(-10, -1, [10, 12]).astype(self.dtype) + x[np.abs(x) < 0.005] = 0.02 + out = ref_relu6(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'threshold': 6.0} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestRelu6Net(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + sum = paddle.add(a, b) + z = paddle.nn.functional.relu6(sum) + + fc_1 = fluid.layers.fc(input=z, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py new file mode 100644 index 0000000000000..f4c5612377e1c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_sigmoid_op_mlu.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +from paddle.fluid.tests.unittests.op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestMLUSigmoid(OpTest): + def setUp(self): + self.op_type = "sigmoid" + self.set_mlu() + self.init_dtype() + + np.random.seed(SEED) + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = 1 / (1 + np.exp(-x)) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', max_relative_error=0.01) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestMLUSigmoidFp16(TestMLUSigmoid): + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + def init_dtype(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py new file mode 100644 index 0000000000000..a5aeeac0ffb9e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_tanh_op_mlu.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid + +paddle.enable_static() +SEED = 2021 + + +class TestTanh(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "tanh" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + if self.dtype == np.float16: + self.check_grad(['X'], 'Out', max_relative_error=0.009) + else: + self.check_grad(['X'], 'Out', max_relative_error=0.009) + + +class TestTanhFp16(OpTest): + def setUp(self): + self.set_mlu() + self.op_type = "tanh" + self.place = paddle.MLUPlace(0) + + self.init_dtype() + np.random.seed(SEED) + x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) + out = np.tanh(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {} + self.outputs = {'Out': out} + + def set_mlu(self): + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestTanhNet(unittest.TestCase): + def _test(self, run_mlu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.random.random(size=(32, 32)).astype('float32') + b_np = np.random.random(size=(32, 32)).astype('float32') + label_np = np.random.randint(2, size=(32, 1)).astype('int64') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + label = paddle.static.data( + name="label", shape=[32, 1], dtype='int64') + + c = paddle.multiply(a, b) + d = paddle.tanh(c) + + fc_1 = fluid.layers.fc(input=d, size=128) + prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.reduce_mean(cost) + sgd = fluid.optimizer.SGD(learning_rate=0.01) + sgd.minimize(loss) + + if run_mlu: + place = paddle.MLUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + print("Start run on {}".format(place)) + for epoch in range(100): + + pred_res, loss_res = exe.run( + main_prog, + feed={"a": a_np, + "b": b_np, + "label": label_np}, + fetch_list=[prediction, loss]) + if epoch % 10 == 0: + print("Epoch {} | Prediction[0]: {}, Loss: {}".format( + epoch, pred_res[0], loss_res)) + + return pred_res, loss_res + + def test_mlu(self): + cpu_pred, cpu_loss = self._test(False) + mlu_pred, mlu_loss = self._test(True) + + self.assertTrue(np.allclose(mlu_pred, cpu_pred)) + self.assertTrue(np.allclose(mlu_loss, cpu_loss)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/profiler/profiler.py b/python/paddle/profiler/profiler.py index ca35348e4cd09..77adbaff34859 100644 --- a/python/paddle/profiler/profiler.py +++ b/python/paddle/profiler/profiler.py @@ -53,16 +53,19 @@ class ProfilerState(Enum): class ProfilerTarget(Enum): r""" - ProfilerTarget is used to specify target device for :ref:`profiling ` . Only CPU and GPU are supported currently. + ProfilerTarget is used to specify target device for :ref:`profiling ` . Only CPU, GPU and MLU are supported currently. The meaning of each ProfilerState is as following - **ProfilerTarget.CPU** : Profile events on CPU. - **ProfilerTarget.GPU** : Profile events on GPU. + + - **ProfilerTarget.MLU** : Profile events on MLU. """ CPU = 0 GPU = 1 + MLU = 2 def make_scheduler(*, @@ -259,6 +262,8 @@ def _get_supported_targets() -> Iterable[ProfilerTarget]: """ if _Profiler.is_cupti_supported(): return [ProfilerTarget.CPU, ProfilerTarget.GPU] + if _Profiler.is_cnpapi_supported(): + return [ProfilerTarget.CPU, ProfilerTarget.MLU] return [ProfilerTarget.CPU] @@ -267,7 +272,7 @@ class Profiler: Profiler context manager, user interface to manage profiling process to start, stop, export profiling data and print summary table. Args: - targets (list, optional): specify target devices to profile, and all existing and supported devices will be chosen by default. Currently supported values, :ref:`ProfilerTarget.CPU ` and :ref:`ProfilerTarget.GPU ` . + targets (list, optional): specify target devices to profile, and all existing and supported devices will be chosen by default. Currently supported values, :ref:`ProfilerTarget.CPU ` , :ref:`ProfilerTarget.GPU ` and :ref:`ProfilerTarget.MLU ` . scheduler (Callable|tuple, optional): If it is a callable object, it takes a step number as parameter and return the corresponding :ref:`ProfilerState `. This callable object can be generated by :ref:`make_scheduler ` function. If not provided (None), the default scheduler will keep tracing until the profiler exits. If it is a tuple, it has two values start_batch and end_batch, which means profiling range [start_batch, end_batch). @@ -408,6 +413,8 @@ def __init__( profileoption.trace_switch |= 1 if ProfilerTarget.GPU in self.targets: profileoption.trace_switch |= (1 << 1) + if ProfilerTarget.MLU in self.targets: + profileoption.trace_switch |= (1 << 2) wrap_optimizers() self.profiler = _Profiler.create(profileoption) if callable(scheduler): diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 95f145cf447b5..5c561060564ec 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -105,9 +105,9 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): place = _current_expected_place() elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace, core.XPUPlace, - core.CustomPlace)): + core.MLUPlace, core.CustomPlace)): raise ValueError( - "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.CustomPlace" + "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace, paddle.NPUPlace, paddle.XPUPlace, paddle.MLUPlace, paddle.CustomPlace" ) if not isinstance(data, np.ndarray): diff --git a/tools/dockerfile/Dockerfile.mlu b/tools/dockerfile/Dockerfile.mlu index f7823738afc53..07535a637431e 100644 --- a/tools/dockerfile/Dockerfile.mlu +++ b/tools/dockerfile/Dockerfile.mlu @@ -2,9 +2,9 @@ # Update CNTOOLKIT_VERSION, CNNL_VERSION and CNCL_VERSION if using other versions # # Build: -# - CNTOOLKIT_VERSION 2.6.5-1 -# - CNNL_VERSION 1.8.3-1 -# - CNCL_VERSION 1.0.2-1 +# - CNTOOLKIT_VERSION 2.8.1-1 +# - CNNL_VERSION 1.9.3-1 +# - CNCL_VERSION 1.0.4-1 # # Download three packages from FTP (need to connect cambricon AE to get FTP url) # - cntoolkit_2.6.5-1.ubuntu18.04_amd64.deb @@ -21,9 +21,9 @@ # (get cncl pkg) # # docker build -f Dockerfile.mlu \ -# --build-arg CNTOOLKIT_VERSION=2.6.5-1 \ -# --build-arg CNNL_VERSION=1.8.3-1 \ -# --build-arg CNCL_VERSION=1.0.2-1 \ +# --build-arg CNTOOLKIT_VERSION=2.8.1-1 \ +# --build-arg CNNL_VERSION=1.9.3-1 \ +# --build-arg CNCL_VERSION=1.0.4-1 \ # -t paddlepaddle/paddle:latest-dev-mlu . # # without mlu device: @@ -40,9 +40,9 @@ MAINTAINER PaddlePaddle Authors ENV WITH_GPU=OFF -ARG CNTOOLKIT_VERSION=2.6.5-1 -ARG CNNL_VERSION=1.8.3-1 -ARG CNCL_VERSION=1.0.2-1 +ARG CNTOOLKIT_VERSION=2.8.1-1 +ARG CNNL_VERSION=1.9.3-1 +ARG CNCL_VERSION=1.0.4-1 ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb