
fix operators to support large arrays. #13036

Closed
zheng-da opened this issue Oct 30, 2018 · 22 comments

Comments

@zheng-da
Contributor

We're working on a model that requires very large NDArrays. For example, we want to create an NDArray as follows:

arr = mx.nd.random.normal(shape=(50000000, 100))

The current implementation doesn't fail with an error, but it doesn't generate the array correctly (only the rows at the beginning are filled).

mx.nd.zeros also fails.

It's unclear which operators support large arrays and which don't.

@frankfliu
Contributor

@mxnet-label-bot [Bug, Operator]

@wkcn
Member

wkcn commented Oct 31, 2018

The loop index type is int in the MXNet kernels, so a kernel fails when the number of iterations is greater than 2^31 - 1 (2,147,483,647).
50000000 * 100 = 5,000,000,000 > 2^31 - 1
A standalone sketch after the links below shows how the size gets truncated.

Please see the code:
CPU:
https://github.com/apache/incubator-mxnet/blob/master/src/operator/mxnet_op.h#L506

GPU:
https://github.com/apache/incubator-mxnet/blob/master/src/operator/mxnet_op.h#L627
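A minimal standalone sketch (illustration only, not the actual MXNet kernel code) of why only the leading rows get filled: the 5,000,000,000-element size wraps when it passes through a 32-bit int, so an int-indexed loop covers only the first 705,032,704 elements.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t rows = 50000000, cols = 100;
  const int64_t size = rows * cols;                      // 5,000,000,000 elements
  const int32_t truncated = static_cast<int32_t>(size);  // typically wraps modulo 2^32

  std::printf("true size     = %lld\n", static_cast<long long>(size));
  std::printf("as 32-bit int = %d\n", truncated);        // prints 705032704
  std::printf("INT32_MAX     = %d\n", INT32_MAX);
  // A loop like `for (int i = 0; i < truncated; ++i)` therefore touches only
  // about 14% of the tensor; using index_t / int64_t for the size and the loop
  // variable avoids the wrap.
  return 0;
}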

@apeforest
Contributor

Will take a look

@apeforest
Contributor

@wkcn is right. We need to change int to index_t. I am busy with other tasks now and can only get to this in a week. Let me know if it requires an immediate fix.

@apeforest
Contributor

JIRA task created: https://issues.apache.org/jira/browse/MXNET-1185

@wkcn
Member

wkcn commented Oct 31, 2018

@apeforest
Will performance drop if we change the iteration type to int64_t?
In PyTorch, the iteration type is a template type.

@zheng-da
Contributor Author

I'm fixing some of the operators, but we need a systematic fix. The problem is everywhere. I'll provide a temporary fix for some of the operators.

@zheng-da
Contributor Author

@wkcn On CPU it shouldn't be a problem. I have heard concerns about GPUs. Potentially, we can use int64_t for CPU and int for GPU.

@wkcn
Member

wkcn commented Nov 1, 2018

@zheng-da
Maybe we can try using int32_t for small for-loops and int64_t for large for-loops.
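A rough sketch of this idea (hypothetical names, not MXNet's actual Kernel::Launch): keep the loop index type as a template parameter and dispatch on the size at runtime, so the common case stays on a 32-bit index and only oversized tensors pay for 64-bit indexing.

#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Kernel loop with the index type as a template parameter.
template <typename OP, typename IndexType>
void launch_impl(IndexType n, float* out) {
  for (IndexType i = 0; i < n; ++i) {
    OP::Map(i, out);
  }
}

// Thin dispatcher: int32_t for small loops, int64_t for large ones.
template <typename OP>
void launch(int64_t n, float* out) {
  if (n <= std::numeric_limits<int32_t>::max()) {
    launch_impl<OP, int32_t>(static_cast<int32_t>(n), out);
  } else {
    launch_impl<OP, int64_t>(n, out);
  }
}

// Example element-wise op in the Map(i, ...) style.
struct set_one {
  template <typename IndexType>
  static void Map(IndexType i, float* out) { out[i] = 1.0f; }
};

int main() {
  std::vector<float> buf(1000, 0.0f);
  launch<set_one>(static_cast<int64_t>(buf.size()), buf.data());
  std::printf("buf[999] = %.1f\n", buf[999]);
  return 0;
}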

@zheng-da
Contributor Author

zheng-da commented Nov 1, 2018

@wkcn My concern is that this modification makes the code complex. As for using different int types for CPU and GPU, that's relatively easy: we can achieve it with a template argument, as sketched below.
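A sketch of the template-argument idea (hypothetical type names, not MXNet code): pick the loop index type per device at compile time, so CPU kernels can use int64_t while GPU kernels keep a 32-bit index if 64-bit indexing is a performance concern there.

#include <cstdint>
#include <cstdio>
#include <vector>

struct cpu {};  // stand-ins for device tag types
struct gpu {};

// Per-device loop index type, selected at compile time.
template <typename Device> struct LoopIndex;
template <> struct LoopIndex<cpu> { using type = int64_t; };
template <> struct LoopIndex<gpu> { using type = int32_t; };

template <typename Device, typename OP>
void launch(int64_t n, float* out) {
  using index_type = typename LoopIndex<Device>::type;
  for (index_type i = 0; i < static_cast<index_type>(n); ++i) {
    OP::Map(i, out);
  }
}

struct fill_two {
  template <typename IndexType>
  static void Map(IndexType i, float* out) { out[i] = 2.0f; }
};

int main() {
  std::vector<float> buf(8, 0.0f);
  launch<cpu, fill_two>(static_cast<int64_t>(buf.size()), buf.data());
  std::printf("buf[7] = %.1f\n", buf[7]);
  return 0;
}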

@zheng-da
Contributor Author

zheng-da commented Nov 1, 2018

@pengzhao-intel What is the performance difference between int32 and int64 on Intel CPUs?

@zheng-da
Contributor Author

zheng-da commented Nov 1, 2018

@apeforest I have fixed some of the operators, including all random generators, zeros, ones, full, arange, and gather_nd.
zheng-da@2c3d9a3
But more work is needed to fix the rest of the operators.

@wkcn
Member

wkcn commented Nov 2, 2018

@zheng-da Maybe size_t is better.

@apeforest
Contributor

@zheng-da Do you plan to create a PR with your change? I will be glad to review. Also, I have created an epic (https://issues.apache.org/jira/browse/MXNET-1184) to address this support in a systematic way. Please feel free to add additional tasks to it as needed. Thanks.

@pengzhao-intel
Contributor

@zheng-da In general, int64 gives only half the performance of int32.

@wkcn
Member

wkcn commented Nov 2, 2018

Hi, I modified src/operator/mxnet_op.h to support both int32_t and int64_t as the iterator type. It may be helpful.

I also wrote a script to replace the iterator type with IndexType.

Usage:

    1. Install the_silver_searcher
    2. Run: ag "MSHADOW_XINLINE static void Map" > map.txt
    3. Run: python replace_index.py

However, there are still some bugs in the script. :-(

@zheng-da
Contributor Author

zheng-da commented Nov 3, 2018

@apeforest I just fixed the operators I use in my model. Could you help add tests and fix the other operators?

@wkcn
Member

wkcn commented Nov 3, 2018

In my test, it seems that the performance of +/- with int32_t and int64_t is about the same.

CPU: Intel i7-7500U
OS: Arch Linux x64
Compiler: g++ 8.2.1
Compile command: g++ ctype.cpp -o test -g -lpthread -std=c++11 -Wno-invalid-source-encoding
Test Code: https://github.com/wkcn/c_performance

Test int8_t:  929 ms, 798 ms, 831 ms, 2024 ms
Test int16_t: 860 ms, 803 ms, 840 ms, 1950 ms
Test int32_t: 858 ms, 822 ms, 878 ms, 1947 ms
Test int64_t: 899 ms, 837 ms, 828 ms, 7345 ms
Test float:   1187 ms, 1191 ms, 1198 ms, 1199 ms
Test double:  1209 ms, 1211 ms, 1205 ms, 1205 ms

@zheng-da
Contributor Author

zheng-da commented Nov 3, 2018

Integer operations are cheap. Even if int64 is a little more expensive, it's hard to believe it can affect the overall performance by much.

@pengzhao-intel
Contributor

Try GEMM with int32 and int64 and see how much peak GFLOPS each can achieve.

@lou-k
Contributor

lou-k commented Dec 18, 2018

I believe the following is also a repro of this issue (51200 x 51200 = 2,621,440,000 elements, which exceeds 2^31 - 1):

import mxnet as mx
mx.nd.eye(10240 * 5) * 2
 
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

@apeforest
Contributor

This issue has been fixed. In the 1.5.0 release, users need to build MXNet from source with the compilation flag USE_INT64_TENSOR_SIZE=1. We are working to turn this flag on by default and make it available in the pip package in the next minor release. Closing this issue for now.
