Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bincount op #36317

Merged
merged 16 commits into from
Oct 25, 2021
Merged

Add bincount op #36317

merged 16 commits into from
Oct 25, 2021

Conversation

smallv0221
Copy link
Contributor

@smallv0221 smallv0221 commented Oct 9, 2021

PR types

New features

PR changes

OPs

Describe

Add bincount op
中文文档pr:PaddlePaddle/docs#3959
英文文档:
image

@paddle-bot-old
Copy link

paddle-bot-old bot commented Oct 9, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for op_function_generator.cc


Args:
x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default is None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Args:
x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input.
minlength (int): Minimum number of bins. Should be non-negative integer.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int -> int, optional
Default is 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

static_cast<double>(0));
for (int64_t i = 0; i < input_numel; i++) {
output_data[input_data[i]] += static_cast<double>(weights_data[i]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个float和double的分支代码是否可以使用T合并,也和int类型的Weights更适配

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这边是为了保证与竞品一致。weights可以是任意数据类型,只有当weights float32时,output的类型才是float32。其他三种情况下output都是float64。


KernelBincount<T, InputT, double><<<GET_BLOCKS(input_numel),
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
input_data, input_numel, has_weights, weights_data, output_data);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同CPU,float和double的分支代码是否可以使用T合并

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上


Args:
x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input. Default is None.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weight的dtype是否也要说明

@@ -44,6 +44,7 @@
from .linalg import cholesky # noqa: F401
from .linalg import bmm # noqa: F401
from .linalg import histogram # noqa: F401
from .linalg import bincount # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shall also add bincount in tensor_method_func list below to get paddle.Tensor.bincount

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines +86 to +89
PADDLE_ENFORCE_GE(
input_min, static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check of PADDLE_ENFORCE* should in InferShape rather than in Compute, so we can detect illegal input earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines 147 to 155
PADDLE_ENFORCE_EQ(input_type_match, true,
platform::errors::InvalidArgument(
"Input(X) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(input_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check of PADDLE_ENFORCE* should in InferShape rather than in Compute, so we can detect illegal input earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines 139 to 141
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
Copy link
Contributor

@jeff41404 jeff41404 Oct 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check of PADDLE_ENFORCE* should in InferShape rather than in Compute, so we can detect illegal input earlier. but in this case, because register BincountCUDAKernel in cuda, no need this check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines +46 to +50
PADDLE_ENFORCE_GE(
*std::min_element(input_data, input_data + input_numel),
static_cast<InputT>(0),
platform::errors::InvalidArgument(
"The elements in input tensor must be non-negative ints"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check of PADDLE_ENFORCE* should in InferShape rather than in Compute, so we can detect illegal input earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines 101 to 109
PADDLE_ENFORCE_EQ(input_type_match, true,
platform::errors::InvalidArgument(
"Input(X) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(input_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the check of PADDLE_ENFORCE* should in InferShape rather than in Compute, so we can detect illegal input earlier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

Comment on lines 1323 to 1324
if paddle.max(x) < 0:
raise ValueError("Elements in Input(x) should all be non-negative")
Copy link
Contributor

@jeff41404 jeff41404 Oct 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
shall we must check:

  1. paddle.min(x) >= 0:
  2. x.ndim == 1;
  3. x.numel() != 0

jeff41404
jeff41404 previously approved these changes Oct 20, 2021
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

jeff41404
jeff41404 previously approved these changes Oct 21, 2021
XiaoguangHu01
XiaoguangHu01 previously approved these changes Oct 21, 2021
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG API

Args:
x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor.
weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input. Default is None.
minlength (int, optional): Minimum number of bins. Should be non-negative integer. Default is 0.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

少了name参数

TCChenlong
TCChenlong previously approved these changes Oct 22, 2021
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG API

Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for op_function_generator.cc

@jeff41404 jeff41404 merged commit 39f1912 into PaddlePaddle:develop Oct 25, 2021
smallv0221 added a commit to smallv0221/Paddle that referenced this pull request Oct 25, 2021
* Add bincount op

* upload cpu version

* fix unitest

* fix unittest

* fix unittest

* fix en doc

* add more test

* fix en doc

* add more test case

* fix test

* fix input vailidation

* fix input check

* fix unittest

* fix test

* fix en doc

cherry-pick
XiaoguangHu01 pushed a commit that referenced this pull request Oct 26, 2021
* Add bincount op

* upload cpu version

* fix unitest

* fix unittest

* fix unittest

* fix en doc

* add more test

* fix en doc

* add more test case

* fix test

* fix input vailidation

* fix input check

* fix unittest

* fix test

* fix en doc

cherry-pick
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants