Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the argsort operator #11174

Merged
merged 13 commits
Jul 2, 2018
83 changes: 83 additions & 0 deletions paddle/fluid/operators/argsort_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/argsort_op.h"

namespace paddle {
namespace operators {

class ArgsortOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Validates the presence of Input(X), Output(Out) and Output(Indices),
  // checks that Attr(axis) is in range, and infers both output shapes,
  // which are identical to the shape of Input(X).
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ArgsortOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ArgsortOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Indices"),
                   "Output(Indices) of ArgsortOp should not be null.");

    auto in_dims = ctx->GetInputDim("X");
    // Attrs().Get<int>() already returns int; no cast is needed.
    int axis = ctx->Attrs().Get<int>("axis");

    auto num_dims = in_dims.size();
    PADDLE_ENFORCE(axis < num_dims,
                   "Attr(axis) %d of ArgsortOp is out of bounds for Input(X) "
                   "dimension %d.",
                   axis, num_dims);
    PADDLE_ENFORCE(axis >= 0 || axis == -1,
                   "Attr(axis) %d of ArgsortOp must be nonnegative or equal to "
                   "-1.",
                   axis);

    // Both outputs have the same shape (and LoD) as the input.
    ctx->SetOutputDim("Out", in_dims);
    ctx->SetOutputDim("Indices", in_dims);
    ctx->ShareLoD("X", "Out");
    ctx->ShareLoD("X", "Indices");
  }
};

class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the operator's inputs, outputs, attributes and documentation.
  void Make() override {
    AddInput("X", "(Tensor) The input of Argsort op.");
    // Document the output shapes explicitly: both match Input(X).
    AddOutput("Out",
              "(Tensor) The sorted tensor of Argsort op, with the same "
              "shape as Input(X).");
    AddOutput("Indices",
              "(Tensor) The indices of a tensor giving the sorted order, "
              "with the same shape as Input(X).");
    AddComment(R"DOC(
Argsort operator

Performs sorting on the input tensor along the given axis and outputs two
tensors, Output(Out) and Output(Indices). They preserve the same shape
with Input(X), and Output(Out) represents the sorted tensor while
Output(Indices) gives the sorted order along the given axis Attr(axis).

)DOC");
    AddAttr<int>("axis",
                 "(int, default -1) The axis along which to sort the tensor, "
                 "default -1, the last dimension.")
        .SetDefault(-1);
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
// Register the argsort operator; EmptyGradOpMaker is used because argsort
// has no gradient operator.
REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
                  paddle::framework::EmptyGradOpMaker);
// Register the CPU kernels for float and double element types.
REGISTER_OP_CPU_KERNEL(argsort,
                       ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
                       ops::ArgsortKernel<paddle::platform::CPUPlace, double>);
78 changes: 78 additions & 0 deletions paddle/fluid/operators/argsort_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
 public:
  // Sorts Input(X) along Attr(axis), writing the sorted values to
  // Output(Out) and the original positions (along the axis) of the sorted
  // elements to Output(Indices). Both outputs share the shape of Input(X).
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto* output = ctx.Output<framework::Tensor>("Out");
    auto* indices = ctx.Output<framework::Tensor>("Indices");
    // Attr<int>() already returns int; no cast is needed.
    int axis = ctx.Attr<int>("axis");

    auto in_dims = input->dims();
    // Normalize any negative axis (e.g. -1 for the last dimension) into
    // [0, rank). This generalizes the previous special case of -1 only.
    axis = (axis < 0) ? (in_dims.size() + axis) : axis;

    const T* in_data = input->data<T>();
    T* out_data = output->mutable_data<T>(ctx.GetPlace());
    int64_t* idx_data = indices->mutable_data<int64_t>(ctx.GetPlace());

    // Stride (in elements) between two consecutive entries along `axis`
    // in row-major (C) order.
    int64_t step = 1;
    for (int64_t dim = in_dims.size() - 1; dim > axis; --dim) {
      step *= in_dims[dim];
    }

    int64_t axis_size = in_dims[axis];
    // Number of independent 1-D slices to sort.
    int64_t part_dims_prod = input->numel() / axis_size;

    std::vector<int64_t> idx_vec(in_dims.size(), 0);
    std::vector<int64_t> flat_index_vec(axis_size, 0);

    for (int64_t i = 0; i < part_dims_prod; ++i) {
      // Decode slice number i into coordinates on every non-axis dim.
      int64_t remain = i;
      for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) {
        if (dim != axis) {
          idx_vec[dim] = remain % in_dims[dim];
          remain /= in_dims[dim];
        }
      }
      // Flat offset of the slice's first element (axis coordinate 0).
      idx_vec[axis] = 0;
      int64_t start_index = idx_vec[0];
      for (int64_t dim = 1; dim < in_dims.size(); ++dim) {
        start_index = start_index * in_dims[dim] + idx_vec[dim];
      }

      // Sort flat indices by the values they point at instead of copying
      // (value, index) pairs into a temporary vector; this saves memory
      // and avoids copying element values during the sort.
      for (int64_t j = 0; j < axis_size; ++j) {
        flat_index_vec[j] = start_index + j * step;
      }
      std::sort(flat_index_vec.begin(), flat_index_vec.end(),
                [in_data](int64_t idx1, int64_t idx2) {
                  return in_data[idx1] < in_data[idx2];
                });

      // Scatter results back: position j along the axis receives the j-th
      // smallest value; Indices stores that value's original axis position.
      for (int64_t j = 0; j < axis_size; ++j) {
        int64_t ret_index = start_index + j * step;
        out_data[ret_index] = in_data[flat_index_vec[j]];
        idx_data[ret_index] = (flat_index_vec[j] - start_index) / step;
      }
    }
  }
};

} // namespace operators
} // namespace paddle
49 changes: 49 additions & 0 deletions python/paddle/fluid/tests/unittests/test_argsort_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestArgsortOp(OpTest):
    """Checks the argsort op output against numpy's argsort/sort.

    This test has no gradient check, so a larger input shape is used to
    cover more cases (e.g. axis sizes exceeding PADDLE_CUDA_NUM_THREADS
    on the GPU kernel).
    """

    def setUp(self):
        self.init_axis()
        # Larger fixture than a toy shape so each sorted slice is long
        # enough to exercise multi-thread code paths.
        x = np.random.random((2, 84, 2, 28)).astype("float32")
        self.indices = np.argsort(x, kind='quicksort', axis=self.axis)
        self.out = np.sort(x, kind='quicksort', axis=self.axis)
        self.op_type = "argsort"
        self.inputs = {'X': x}
        self.attrs = {'axis': self.axis}
        self.outputs = {'Indices': self.indices, 'Out': self.out}

    def init_axis(self):
        # Default: sort along the last dimension (axis = -1).
        self.axis = -1

    def test_check_output(self):
        self.check_output()


class TestArgsortOpAxis0(TestArgsortOp):
    # Same as TestArgsortOp but sorting along the first dimension.
    def init_axis(self):
        self.axis = 0


class TestArgsortOpAxis1(TestArgsortOp):
    # Same as TestArgsortOp but sorting along the second dimension.
    def init_axis(self):
        self.axis = 1


# Run all argsort test cases when executed as a script.
if __name__ == "__main__":
    unittest.main()