-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add isclose op, test=develop * add isclose op, test=develop * add isclose api, test=develop * rm useless code * rm useless code * update python api of isclose * add some unittest of isclose op, test=develop
- Loading branch information
1 parent
65c242e
commit d2200e9
Showing
7 changed files
with
641 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/isclose_op.h" | ||
#include <cmath> | ||
#include <string> | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/op_version_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
struct GetTensorValue<platform::CPUDeviceContext, T> { | ||
T operator()(const platform::CPUDeviceContext& dev_ctx, | ||
const framework::Tensor& tensor) const { | ||
return *(tensor.data<T>()); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct IscloseFunctor<platform::CPUDeviceContext, T> { | ||
void operator()(const platform::CPUDeviceContext& ctx, | ||
const framework::Tensor& in, const framework::Tensor& other, | ||
const double rtol, const double atol, bool equal_nan, | ||
framework::Tensor* output) { | ||
auto* in_a = in.data<T>(); | ||
auto* in_b = other.data<T>(); | ||
auto* out_data = output->mutable_data<bool>(ctx.GetPlace()); | ||
auto num = in.numel(); | ||
// *out_data = true; | ||
for (int i = 0; i < num; i++) { | ||
out_data[i] = true; | ||
} | ||
for (int i = 0; i < num; i++) { | ||
const T a = in_a[i], b = in_b[i]; | ||
bool val; | ||
if (std::isnan(a) || std::isnan(b)) { | ||
val = equal_nan && std::isnan(a) == std::isnan(b); | ||
} else { | ||
T left = (a > b ? a - b : b - a); | ||
T right = atol + (b > 0 ? rtol * b : (-rtol) * b); | ||
T diff = (left > right ? left - right : right - left); | ||
val = a == b || left <= right || diff <= 1e-15; | ||
} | ||
// *out_data &= val; | ||
out_data[i] = val; | ||
} | ||
} | ||
}; | ||
|
||
class IscloseOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Input", | ||
"The input tensor, it's data type should be float32, float64."); | ||
AddInput("Other", | ||
"The input tensor, it's data type should be float32, float64."); | ||
AddInput("Rtol", "The relative tolerance.").AsDispensable(); | ||
AddInput("Atol", "The absolute tolerance.").AsDispensable(); | ||
AddOutput("Out", "The output tensor, it's data type is bool."); | ||
AddAttr<std::string>("rtol", | ||
"The relative tolerance. Default: :math:`1e-5` .") | ||
.SetDefault("1e-5"); | ||
AddAttr<std::string>("atol", | ||
"The absolute tolerance. Default: :math:`1e-8` .") | ||
.SetDefault("1e-8"); | ||
AddAttr<bool>("equal_nan", | ||
"If :math:`True` , then two :math:`NaNs` will be " | ||
"compared as equal. Default: :math:`False` .") | ||
.SetDefault(false); | ||
|
||
AddComment(R"DOC( | ||
This operator checks if all :math:`x` and :math:`y` satisfy the condition: | ||
.. math:: | ||
\left| x - y \right| \leq atol + rtol \times \left| y \right| | ||
elementwise, for all elements of :math:`x` and :math:`y`. The behaviour of this | ||
operator is analogous to :math:`numpy.isclose`, namely that it returns :math:`True` if | ||
two tensors are elementwise equal within a tolerance. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
class IscloseOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Isclose"); | ||
OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Isclose"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Isclose"); | ||
|
||
auto input_dim = ctx->GetInputDim("Input"); | ||
auto other_dim = ctx->GetInputDim("Other"); | ||
PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(), | ||
platform::errors::PreconditionNotMet( | ||
"Input(Input) and Input(Other) must have the same " | ||
"dimension size.")); | ||
int n = input_dim.size(); | ||
bool is_runtime = ctx->IsRuntime(); | ||
for (int i = 0; i < n; i++) { | ||
if (is_runtime) { | ||
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], | ||
platform::errors::PreconditionNotMet( | ||
"The value at dim %d of Input(Input) is not " | ||
"equal to the Input(Other): %ld != %ld.", | ||
i, input_dim[i], other_dim[i])); | ||
} else { | ||
if (!(input_dim[i] < 0 || other_dim[i] < 0)) { | ||
PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i], | ||
platform::errors::PreconditionNotMet( | ||
"The value at dim %d of Input(Input) is not " | ||
"equal to the Input(Other): %ld != %ld.", | ||
i, input_dim[i], other_dim[i])); | ||
} | ||
} | ||
} | ||
|
||
ctx->SetOutputDim("Out", input_dim); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), | ||
ctx.device_context()); | ||
} | ||
}; | ||
|
||
class IscloseOpVarTypeInference : public framework::VarTypeInference { | ||
public: | ||
void operator()(framework::InferVarTypeContext* ctx) const override { | ||
ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
using CPU = paddle::platform::CPUDeviceContext; | ||
|
||
REGISTER_OPERATOR( | ||
isclose, ops::IscloseOp, ops::IscloseOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, | ||
ops::IscloseOpVarTypeInference); | ||
REGISTER_OP_CPU_KERNEL(isclose, ops::IscloseKernel<CPU, float>, | ||
ops::IscloseKernel<CPU, double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/operators/isclose_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
struct GetTensorValue<platform::CUDADeviceContext, T> { | ||
T operator()(const platform::CUDADeviceContext& dev_ctx, | ||
const framework::Tensor& tensor) const { | ||
const T* data = tensor.data<T>(); | ||
T value; | ||
const auto gpu_place = | ||
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()); | ||
memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T), | ||
dev_ctx.stream()); | ||
return value; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
__global__ void IscloseCUDAKernel(const T* in_data, const T* other_data, | ||
const double rtol, const double atol, | ||
bool equal_nan, int num, bool* out_data) { | ||
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x; | ||
bool val; | ||
for (int i = idx; i < num; i += blockDim.x * gridDim.x) { | ||
const T a = in_data[i], b = other_data[i]; | ||
if (isnan(a) || isnan(b)) { | ||
val = equal_nan && isnan(a) == isnan(b); | ||
} else { | ||
T left = (a > b ? a - b : b - a); | ||
T right = atol + (b > 0 ? rtol * b : (-rtol) * b); | ||
T diff = (left > right ? left - right : right - left); | ||
val = a == b || left <= right || diff <= 1e-15; | ||
} | ||
out_data[i] = val; | ||
// if (!val) *out_data = false; | ||
} | ||
} | ||
|
||
template <typename T> | ||
struct IscloseFunctor<platform::CUDADeviceContext, T> { | ||
void operator()(const platform::CUDADeviceContext& dev_ctx, | ||
const framework::Tensor& in, const framework::Tensor& other, | ||
const double rtol, const double atol, bool equal_nan, | ||
framework::Tensor* output) { | ||
int num = in.numel(); | ||
const T* in_data = in.data<T>(); | ||
const T* other_data = other.data<T>(); | ||
bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace()); | ||
int block = 1024; | ||
int grid = (block - 1 + num) / block; | ||
grid = (grid > block) ? block : grid; | ||
#ifdef PADDLE_WITH_HIP | ||
hipMemset(out_data, true, num * sizeof(bool)); | ||
#else | ||
cudaMemset(out_data, true, num * sizeof(bool)); | ||
#endif | ||
IscloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>( | ||
in_data, other_data, rtol, atol, equal_nan, num, out_data); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
using CUDA = paddle::platform::CUDADeviceContext; | ||
REGISTER_OP_CUDA_KERNEL(isclose, ops::IscloseKernel<CUDA, float>, | ||
ops::IscloseKernel<CUDA, double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/eigen.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/platform/place.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using Tensor = framework::Tensor; | ||
|
||
template <typename DeviceContext, typename T> | ||
struct GetTensorValue { | ||
T operator()(const platform::DeviceContext& ctx, | ||
const framework::Tensor& tensor) const; | ||
}; | ||
|
||
template <typename DeviceContext, typename T> | ||
struct IscloseFunctor { | ||
void operator()(const DeviceContext& ctx, const framework::Tensor& in, | ||
const framework::Tensor& other, const float rtol, | ||
const float atol, bool equal_nan, framework::Tensor* output); | ||
}; | ||
|
||
template <typename DeviceContext, typename T> | ||
class IscloseKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
// get attrs | ||
bool equal_nan = ctx.Attr<bool>("equal_nan"); | ||
// get input/output | ||
const auto* input = ctx.Input<Tensor>("Input"); | ||
const auto* other = ctx.Input<Tensor>("Other"); | ||
auto* out = ctx.Output<Tensor>("Out"); | ||
|
||
double rtol_v = std::stod(ctx.Attr<std::string>("rtol")); | ||
double atol_v = std::stod(ctx.Attr<std::string>("atol")); | ||
|
||
auto& dev_ctx = ctx.template device_context<DeviceContext>(); | ||
GetTensorValue<DeviceContext, double> get_tensor_value; | ||
if (ctx.HasInput("Rtol")) { | ||
const auto* rtol = ctx.Input<Tensor>("Rtol"); | ||
PADDLE_ENFORCE_EQ( | ||
rtol->numel(), 1, | ||
platform::errors::InvalidArgument( | ||
"Input(Rtol) size must be 1, but get %d.", rtol->numel())); | ||
PADDLE_ENFORCE_EQ(rtol->type(), framework::proto::VarType::FP64, | ||
platform::errors::InvalidArgument( | ||
"Input(Rtol) type must be double, but get %s.", | ||
framework::DataTypeToString(rtol->type()))); | ||
rtol_v = get_tensor_value(dev_ctx, *rtol); | ||
} | ||
if (ctx.HasInput("Atol")) { | ||
const auto* atol = ctx.Input<Tensor>("Atol"); | ||
PADDLE_ENFORCE_EQ( | ||
atol->numel(), 1, | ||
platform::errors::InvalidArgument( | ||
"Input(Atol) size must be 1, but get %d", atol->numel())); | ||
PADDLE_ENFORCE_EQ(atol->type(), framework::proto::VarType::FP64, | ||
platform::errors::InvalidArgument( | ||
"Input(Atol) type must be double, but get %s", | ||
framework::DataTypeToString(atol->type()))); | ||
atol_v = get_tensor_value(dev_ctx, *atol); | ||
} | ||
|
||
IscloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v, | ||
equal_nan, out); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.