Skip to content

Commit

Permalink
[feat] add isinf, trunc, round, hardsigmoid, elu, threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
Yin Hongyun committed Nov 20, 2024
1 parent 3e82955 commit 8fcc8e9
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 0 deletions.
104 changes: 104 additions & 0 deletions diopi_test/python/configs/diopi_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,110 @@


diopi_configs = {
# Test config for the elementwise isinf op: exercises scalar (), multi-dim,
# and zero-size shapes across float, signed-int, and uint8 dtypes.
'has_inf': dict(
name=["isinf"],
interface=["torch"],
atol=1e-3,
rtol=1e-4,
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((), (1024,), (2, 4096), (64, 28, 28),
(32, 64, 112, 112), (64, 3, 7, 28, 28),
(0,), (256, 0), (8, 0, 128)),
"dtype": [np.float16, np.float32, np.float64,
np.int16, np.int32, np.int64,
np.uint8, np.int8],
},
],
),
),

# Test config for torch.trunc (round toward zero) over float dtypes.
'trunc': dict(
name=["trunc"],
interface=["torch"],
atol=1e-3,
rtol=1e-4,
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
"dtype": [np.float32, np.float16, np.float64],
},
],
),
),

# Test config for torch.round (round to nearest integer) over float dtypes.
'round': dict(
name=["round"],
interface=["torch"],
atol=1e-3,
rtol=1e-4,
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
"dtype": [np.float32, np.float16, np.float64],
},
],
),
),

'round': dict(
name=["hardsigmoid"],
atol=1e-3,
rtol=1e-4,
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
"dtype": [np.float32, np.float16, np.float64],
},
],
),
),

# Test config for elu over float dtypes; alpha values include a negative
# (-10) and the default-like 1.0 scale.
'elu': dict(
name=["elu"],
atol=1e-3,
rtol=1e-4,
para=dict(
alpha=[0.234, 4.8, -10, 1.0],
),
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
"dtype": [np.float32, np.float16, np.float64],
},
],
),
),

# Test config for the threshold op; `threshold` and `value` lists are zipped
# positionally, so both must have the same length (4 cases here).
'threshold_relu': dict(
name=["threshold"],
atol=1e-3,
rtol=1e-4,
para=dict(
threshold=[0.234, 4.8, -10, 1.0],
value=[0.2, 4.2, -10, 2.0],
),
tensor_para=dict(
args=[
{
"ins": ['input'],
"shape": ((2, 16, 32, 56, 56), (2, 64, 32, 32), (2, 96, 28), (2, 16)),
"dtype": [np.float32, np.float16, np.float64],
},
],
),
),

# FIXME batch_norm reports an error when given a 0-size tensor input
'batch_norm': dict(
name=["batch_norm"],
Expand Down
45 changes: 45 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,51 @@ def promote_type(input: Tensor, promoted_dtype: Dtype) -> Dtype:
]
return dtype1 if dtype1 not in need_promote_types else promoted_dtype

def isinf(input) -> Tensor:
    """Elementwise infinity test.

    Allocates a boolean tensor shaped like ``input`` and fills it via the
    DIOPI runtime entry point ``diopiHasInf``.
    """
    result = Tensor(size=input.size(), dtype=Dtype.bool)
    diopi_fn = check_function("diopiHasInf")
    status = diopi_fn(input.context(), result, input)
    check_returncode(status)
    return result

def trunc(input) -> Tensor:
    """Elementwise truncation toward zero via DIOPI ``diopiTrunc``.

    Returns a new tensor with the same shape and dtype as ``input``.
    """
    result = Tensor(size=input.size(), dtype=input.get_dtype())
    diopi_fn = check_function("diopiTrunc")
    status = diopi_fn(input.context(), result, input)
    check_returncode(status)
    return result

def round(input) -> Tensor:
    """Elementwise rounding to the nearest integer.

    Dispatches to the DIOPI C entry point and returns a new tensor with the
    same shape and dtype as ``input``.
    """
    # BUG FIX: the original looked up "diopiTRound", which is not declared
    # anywhere — the impl and proto header both name this entry point
    # "diopiRound", so the lookup could never resolve.
    func = check_function("diopiRound")
    out = Tensor(size=input.size(), dtype=input.get_dtype())
    ret = func(input.context(), out, input)
    check_returncode(ret)
    return out

def hardsigmoid(input) -> Tensor:
    """Elementwise hardsigmoid activation via DIOPI ``diopiHardSigmoid``.

    Returns a new tensor with the same shape and dtype as ``input``.
    """
    result = Tensor(size=input.size(), dtype=input.get_dtype())
    diopi_fn = check_function("diopiHardSigmoid")
    status = diopi_fn(input.context(), result, input)
    check_returncode(status)
    return result

def elu(input, alpha) -> Tensor:
    """Elementwise ELU activation via DIOPI ``diopiElu``.

    ``alpha`` is a Python scalar wrapped into a DIOPI ``Scalar`` before the
    call; the result has the same shape and dtype as ``input``.
    """
    result = Tensor(size=input.size(), dtype=input.get_dtype())
    alpha_scalar = Scalar(alpha)
    diopi_fn = check_function("diopiElu")
    status = diopi_fn(input.context(), result, input, alpha_scalar)
    check_returncode(status)
    return result


def threshold(input, threshold, value) -> Tensor:
    """Thresholded activation via DIOPI ``diopiThresholdRelu``.

    ``threshold`` and ``value`` are Python scalars wrapped into DIOPI
    ``Scalar`` objects before the call; the result has the same shape and
    dtype as ``input``.
    """
    result = Tensor(size=input.size(), dtype=input.get_dtype())
    thr_scalar = Scalar(threshold)
    val_scalar = Scalar(value)
    diopi_fn = check_function("diopiThresholdRelu")
    status = diopi_fn(input.context(), result, input, thr_scalar, val_scalar)
    check_returncode(status)
    return result

def fill_(input, value):
func = check_function("diopiFill")
Expand Down
57 changes: 57 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,63 @@ const char* diopiGetImplVersion() {
return version;
}

diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
    // Elementwise isinf: fills `out` (expected to be a bool tensor of the
    // same shape as `input`) with the +/-Inf mask of `input`.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    CALL_ATEN_FUNC(isinf_out, outTensor, inTensor);
    return diopiSuccess;
}

diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
    // Elementwise truncation toward zero, dispatched to at::trunc_out.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    CALL_ATEN_FUNC(trunc_out, outTensor, inTensor);
    return diopiSuccess;
}

diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
    // Elementwise round to nearest integer, dispatched to at::round_out.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    CALL_ATEN_FUNC(round_out, outTensor, inTensor);
    return diopiSuccess;
}

diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
    // Elementwise hardsigmoid activation, dispatched to at::hardsigmoid_out.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    CALL_ATEN_FUNC(hardsigmoid_out, outTensor, inTensor);
    return diopiSuccess;
}

diopiError_t diopiThresholdRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* threshold,
const diopiScalar_t* value) {
    // Thresholded activation, dispatched to at::threshold_out with both
    // scalars converted to at::Scalar.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    auto thresholdScalar = impl::aten::buildAtScalar(threshold);
    auto valueScalar = impl::aten::buildAtScalar(value);
    CALL_ATEN_FUNC(threshold_out, outTensor, inTensor, thresholdScalar, valueScalar);
    return diopiSuccess;
}

diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha) {
    // Elementwise ELU activation, dispatched to at::elu_out with `alpha`
    // converted to an at::Scalar.
    impl::aten::setCurStream(ctx);
    auto outTensor = impl::aten::buildATen(out);
    auto inTensor = impl::aten::buildATen(input);
    auto alphaScalar = impl::aten::buildAtScalar(alpha);
    CALL_ATEN_FUNC(elu_out, outTensor, inTensor, alphaScalar);
    return diopiSuccess;
}

diopiError_t diopiRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input) {
impl::aten::setCurStream(ctx);
auto atOut = impl::aten::buildATen(out);
Expand Down
30 changes: 30 additions & 0 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,36 @@ extern "C" {
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetVendorName();
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetImplVersion();
DIOPI_RT_API DIOPI_ATTR_WEAK const char* diopiGetLastErrorString();
/**
 * @brief Elementwise infinity test: for every element of the input tensor,
 * records whether it is +Inf or -Inf. Despite the name, this is the
 * elementwise isinf op (out is a boolean tensor of the same shape as the
 * input, not a single any-Inf flag).
 * @param[in] ctx Context environment.
 * @param[out] out Boolean output tensor, same shape as the input.
 * @param[in] input Input tensor to test.
 */
DIOPI_API diopiError_t diopiHasInf(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
 * @brief Elementwise truncation: rounds each element of the input tensor
 * toward zero, discarding the fractional part.
 * @param[in] ctx Context environment.
 * @param[out] out Output tensor, same shape and dtype as the input.
 * @param[in] input Input tensor.
 */
DIOPI_API diopiError_t diopiTrunc(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
 * @brief Elementwise rounding: rounds each element of the input tensor to
 * the nearest integer value.
 * @param[in] ctx Context environment.
 * @param[out] out Output tensor, same shape and dtype as the input.
 * @param[in] input Input tensor.
 */
DIOPI_API diopiError_t diopiRound(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
 * @brief Applies the hard sigmoid activation function elementwise to the
 * input tensor.
 * @param[in] ctx Context environment.
 * @param[out] out Output tensor, same shape and dtype as the input.
 * @param[in] input Input tensor.
 */
DIOPI_API diopiError_t diopiHardSigmoid(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);

/**
 * @brief Thresholded activation: elements at or below `threshold` are
 * replaced by `value`; all other elements pass through unchanged.
 * @param[in] ctx Context environment.
 * @param[out] out Output tensor, same shape and dtype as the input.
 * @param[in] input Input tensor.
 * @param[in] threshold Scalar threshold each element is compared against.
 * @param[in] value Scalar written in place of thresholded elements.
 */
DIOPI_API diopiError_t diopiThresholdRelu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* threshold,
const diopiScalar_t* value);

/**
 * @brief Applies the exponential linear unit (ELU) activation function
 * elementwise to the input tensor, scaled by `alpha` on the negative side.
 * @param[in] ctx Context environment.
 * @param[out] out Output tensor, same shape and dtype as the input.
 * @param[in] input Input tensor.
 * @param[in] alpha Scalar scale applied to the exponential (negative) branch.
 */
DIOPI_API diopiError_t diopiElu(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* alpha);

/**
* @brief Applies a 2D convolution over an input image composed of several input planes.
Expand Down

0 comments on commit 8fcc8e9

Please sign in to comment.