Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Move selected_rows PR #5] VisitDataType use Pten::DataType #39236

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f533aa0
Added selected_rows and rw_lock to pten
veyron95 Jan 20, 2022
832e903
Renamed the unit test target to fix CI
veyron95 Jan 21, 2022
318ef67
Removed Class SelectedRows in Fluid, changed include/cmake relationsh…
veyron95 Jan 21, 2022
40f635d
Remove rw_lock.h,rw_lock_test.cc in fluid
veyron95 Jan 22, 2022
aa5d93f
Use pten::RWLock and pten::AutoRDLock, fix CI
veyron95 Jan 22, 2022
8c51888
Use pten::SelectedRows
veyron95 Jan 22, 2022
15328f3
Use pten::SelectedRows
veyron95 Jan 22, 2022
7416f99
Fix to pass NPU CI
veyron95 Jan 22, 2022
794f7ef
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
4511b17
Merge branch 'develop' into fluid_move_selected_rows_to_pten_3
veyron95 Jan 24, 2022
47e3ccb
Selected_Rows inherits from TensorBase
veyron95 Jan 24, 2022
7afa032
Fix conflict
veyron95 Jan 24, 2022
3942e0f
Use pten::SelectedRows, to pass NPU CI
veyron95 Jan 24, 2022
75de13d
To fix NPU CI
veyron95 Jan 24, 2022
c650f51
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
d241507
To fix NPU CI again
veyron95 Jan 24, 2022
ef71e4a
Merge commit 'refs/pull/39128/head' of https://github.com/PaddlePaddl…
veyron95 Jan 24, 2022
51b0f24
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
ef79f84
Use paddle/pten/core/enforce and polish code
veyron95 Jan 25, 2022
d815960
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
ccd0e78
Use pten::DataType instead of using proto_type
veyron95 Jan 25, 2022
835330e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 25, 2022
5397c54
Move part of data_type to pten
veyron95 Jan 26, 2022
8e61902
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Jan 26, 2022
53db5d3
Polish Code
veyron95 Jan 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/pten/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )

cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector pten_enforce ddim)

cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
Expand Down
20 changes: 9 additions & 11 deletions paddle/pten/core/selected_rows.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/selected_rows.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/data_type.h"
#include "paddle/pten/core/utils/data_type.h"

namespace pten {

Expand Down Expand Up @@ -191,16 +189,16 @@ void SelectedRows::Get(const pten::DenseTensor& ids,
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0";
paddle::framework::VisitDataType(
value_->type(),
pten::VisitDataType(
value_->dtype(),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
paddle::framework::VisitDataType(value_->type(),
TensorCopyVisitor(value,
i * value_width,
*value_.get(),
index * value_width,
value_width));
pten::VisitDataType(value_->dtype(),
TensorCopyVisitor(value,
i * value_width,
*value_.get(),
index * value_width,
value_width));
}
}
}
Expand Down
56 changes: 45 additions & 11 deletions paddle/pten/core/selected_rows.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/rw_lock.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"

namespace pten {
class SelectedRows {
class SelectedRows : public TensorBase,
public TypeInfoTraits<TensorBase, SelectedRows> {
/*
* @brief We can use the SelectedRows structure to reproduce a sparse table.
* A sparse table is a key-value structure that the key is an `int64_t`,
Expand All @@ -51,21 +52,19 @@ class SelectedRows {
public:
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

SelectedRows() {
height_ = 0;
value_.reset(new pten::DenseTensor());
value_.reset(new DenseTensor());
rwlock_.reset(new RWLock);
}

const pten::Place& place() const { return value_->place(); }
const DenseTensor& value() const { return *value_; }

const pten::DenseTensor& value() const { return *value_; }

pten::DenseTensor* mutable_value() { return value_.get(); }
DenseTensor* mutable_value() { return value_.get(); }

int64_t height() const { return height_; }

Expand Down Expand Up @@ -109,8 +108,8 @@ class SelectedRows {
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
void Get(const pten::DenseTensor& ids,
pten::DenseTensor* value,
void Get(const DenseTensor& ids,
DenseTensor* value,
bool auto_grown = false,
bool is_test = false);

Expand Down Expand Up @@ -149,14 +148,49 @@ class SelectedRows {
return pten::framework::make_ddim(dims);
}

/// \brief Returns the name of the class for type traits.
/// \return The string literal "SelectedRows".
static const char* name() { return "SelectedRows"; }

/// \brief Returns the number of elements contained in the tensor.
/// \return The element count reported by the underlying value tensor.
int64_t numel() const override { return value_->numel(); }

/// \brief Reports the shape of the tensor.
/// \return The dims of the underlying value tensor.
const DDim& dims() const noexcept override {
  const DenseTensor& underlying = *value_;
  return underlying.dims();
}

/// \brief Reports the element data type of the tensor.
/// \return The data type of the underlying value tensor.
DataType dtype() const noexcept override {
  const DenseTensor& underlying = *value_;
  return underlying.dtype();
}

/// \brief Reports the data layout of the tensor.
/// \return The data layout of the underlying value tensor.
DataLayout layout() const noexcept override {
  const DenseTensor& underlying = *value_;
  return underlying.layout();
}

/// \brief Returns the data place of the tensor.
/// \return The place reported by the underlying value tensor.
const Place& place() const override { return value_->place(); }

/// \brief Test whether the metadata is valid.
/// \return Whether the metadata is valid.
bool valid() const noexcept override { return value_->valid(); }

/// \brief Test whether the storage is allocated.
/// return Whether the storage is allocated.
bool initialized() const override { return value_->initialized(); }

private:
// Note: rows may contain duplicates, e.g. {0, 4, 7, 0, 5, 7, 9}.
// When SelectedRows are added together their rows are simply concatenated;
// duplicate rows are only handled once a SelectedRows is added to a Tensor.
paddle::framework::Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t>
id_to_index_; // should not be used when rows_ has duplicate member
std::unique_ptr<pten::DenseTensor> value_{nullptr};
std::unique_ptr<DenseTensor> value_{nullptr};
int64_t height_; // height indicates the underlying tensor's height
std::unique_ptr<RWLock> rwlock_{nullptr};
};
Expand Down
63 changes: 63 additions & 0 deletions paddle/pten/core/utils/data_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <string>
#include <typeindex>

#include "paddle/pten/common/data_type.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/kernels/funcs/eigen/extensions.h"

namespace pten {

// Expands to a single invocation `callback(cpp_type, data_type);`.
// The expansion already ends with a semicolon, so a use site that appends its
// own `;` produces an additional empty statement.
// NOTE(review): identifiers that begin with an underscore followed by an
// uppercase letter are reserved for the implementation in C++; consider
// renaming in a follow-up.
#define _PtenForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type);

// Invokes `callback(cpp_type, data_type)` once for every (C++ type, DataType)
// pair listed below, in this order: float, float16, bfloat16, double, int,
// int64_t, bool, uint8_t, int16_t, int8_t, complex<float>, complex<double>.
// `DataType` is spelled unqualified, so the macro must be expanded where that
// name resolves (e.g. inside namespace pten).
// Each entry expands to a statement followed by `;`, so the macro is only
// usable where a sequence of statements is allowed.
#define _PtenForEachDataType_(callback) \
_PtenForEachDataTypeHelper_(callback, float, DataType::FLOAT32); \
_PtenForEachDataTypeHelper_( \
callback, ::paddle::platform::float16, DataType::FLOAT16); \
_PtenForEachDataTypeHelper_( \
callback, ::paddle::platform::bfloat16, DataType::BFLOAT16); \
_PtenForEachDataTypeHelper_(callback, double, DataType::FLOAT64); \
_PtenForEachDataTypeHelper_(callback, int, DataType::INT32); \
_PtenForEachDataTypeHelper_(callback, int64_t, DataType::INT64); \
_PtenForEachDataTypeHelper_(callback, bool, DataType::BOOL); \
_PtenForEachDataTypeHelper_(callback, uint8_t, DataType::UINT8); \
_PtenForEachDataTypeHelper_(callback, int16_t, DataType::INT16); \
_PtenForEachDataTypeHelper_(callback, int8_t, DataType::INT8); \
_PtenForEachDataTypeHelper_( \
callback, ::paddle::platform::complex<float>, DataType::COMPLEX64); \
_PtenForEachDataTypeHelper_( \
callback, ::paddle::platform::complex<double>, DataType::COMPLEX128);

template <typename Visitor>
inline void VisitDataType(pten::DataType type, Visitor visitor) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同步一下,这个VisitDataType我们认为写得不是很直观,之前讨论过说希望用switch-case的形式替换掉的,参考pten/api/ext/dispatch.h,不过短时间内好像也比较困难;

此外,后续如果还是有其他kernel使用,我们可能会考虑把这段实现直接放到common/data_type.h中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,后续再替换或者进一步优化,多谢。

// Per-type callback handed to _PtenForEachDataType_: when the runtime `type`
// equals `data_type`, call the visitor's `apply` member template instantiated
// with the matching C++ type and return from VisitDataType immediately.
// The do/while(0) wrapper makes the expansion a single statement.
#define PtenVisitDataTypeCallback(cpp_type, data_type) \
  do {                                                 \
    if (type == data_type) {                           \
      visitor.template apply<cpp_type>();              \
      return;                                          \
    }                                                  \
  } while (0)

  // Compare `type` against every supported (C++ type, DataType) pair in turn.
  _PtenForEachDataType_(PtenVisitDataTypeCallback);
#undef PtenVisitDataTypeCallback
  // No pair matched: report the unsupported value. The message names
  // pten::DataType because that is the type this function dispatches on.
  PADDLE_THROW(pten::errors::Unimplemented(
      "Not supported pten::DataType(%d) as data type.",
      static_cast<int>(type)));
}
} // namespace pten