From e9829e71a7f536d0fc78a0faf96f31336987770e Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Tue, 28 Jul 2020 18:53:29 -0700 Subject: [PATCH] Cherry-pick large tensor support from #18752. (#18804) Co-authored-by: Joe Evans --- CONTRIBUTORS.md | 1 + src/operator/tensor/la_op-inl.h | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f63b2412077b..4146d45b5c9d 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -254,6 +254,7 @@ List of Contributors * [Connor Goggins](https://github.com/connorgoggins) * [Wei Chu](https://github.com/waytrue17) * [Yang Shi](https://github.com/ys2843) +* [Joe Evans](https://github.com/josephevans) Label Bot --------- diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h index d580cced4ec5..7a5a602425fe 100644 --- a/src/operator/tensor/la_op-inl.h +++ b/src/operator/tensor/la_op-inl.h @@ -36,9 +36,10 @@ using namespace mshadow; // Copies lower/upper triangular part to upper/lower, i.e. to the opposite side. struct CopyTriangularToOppositeSide { template - MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, bool to_lower) { + MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride, + DType* data, bool to_lower) { // Below computation works even when we are dealing with a batch of matrices. - const int row((i % matrix_size) / stride), col(i % stride); + const index_t row((i % matrix_size) / stride), col(i % stride); if (row > col) { if (to_lower) { data[i] = data[i + (col - row) * (stride - 1)]; @@ -52,9 +53,9 @@ struct CopyTriangularToOppositeSide { // Zero's lower/upper triangular part of a matrix. struct ZeroTriangular { template - MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType* data, - bool zero_lower) { - const int row((i % matrix_size) / stride), col(i % stride); + MSHADOW_XINLINE static void Map(index_t i, size_t matrix_size, index_t stride, + DType* data, bool zero_lower) { + const index_t row((i % matrix_size) / stride), col(i % stride); if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] = 0; } };