Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Cherry-pick large tensor support from #18752. #18804

Merged
merged 1 commit into from
Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ List of Contributors
* [Connor Goggins](https://github.com/connorgoggins)
* [Wei Chu](https://github.com/waytrue17)
* [Yang Shi](https://github.com/ys2843)
* [Joe Evans](https://github.com/josephevans)

Label Bot
---------
Expand Down
11 changes: 6 additions & 5 deletions src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ using namespace mshadow;
// Copies lower/upper triangular part to upper/lower, i.e. to the opposite side.
struct CopyTriangularToOppositeSide {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) {
MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride,
DType* data, bool to_lower) {
// Below computation works even when we are dealing with a batch of matrices.
const int row((i % matrix_size) / stride), col(i % stride);
const index_t row((i % matrix_size) / stride), col(i % stride);
if (row > col) {
if (to_lower) {
data[i] = data[i + (col - row) * (stride - 1)];
Expand All @@ -52,9 +53,9 @@ struct CopyTriangularToOppositeSide {
// Zero's lower/upper triangular part of a matrix.
struct ZeroTriangular {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data,
bool zero_lower) {
const int row((i % matrix_size) / stride), col(i % stride);
MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride,
DType* data, bool zero_lower) {
const index_t row((i % matrix_size) / stride), col(i % stride);
if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0;
}
};
Expand Down