Skip to content

Commit f0ea9e4

Browse files
authored
[RUNTIME] Fix the manual determination of cores in FillDataForMeasure (#13849)
* Assertion failed during tuning * Cleanup * Do not commit * Do not commit * Undo fix + provide test for multithread random filling * Random fill test with fix enabled * Isolate the effect of this test on the other tests * Correct the typo in the function name * Import threading + lint
1 parent 49849c8 commit f0ea9e4

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

src/runtime/contrib/random/mt_random_engine.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,19 @@ class RandomEngine {
192192
struct ParallelTask {
193193
static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
194194
ParallelTask* task = static_cast<ParallelTask*>(cdata);
195-
task->Run(task_id);
195+
task->Run(task_id, penv->num_task);
196196
return 0;
197197
}
198198

199-
void Run(int i) {
200-
int64_t chunk_size = size / num_threads;
199+
void Run(int i, int num_tasks) {
200+
int64_t chunk_size = size / num_tasks;
201201
int64_t st = i * chunk_size;
202202
int64_t ed = std::min(st + chunk_size, size);
203203
self->FillDataImpl(data, st, ed, dtype);
204204
}
205205

206206
RandomEngine* self;
207207
void* data;
208-
int num_threads;
209208
int64_t size;
210209
DLDataType dtype;
211210
};
@@ -220,8 +219,7 @@ class RandomEngine {
220219
}
221220
if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
222221
dtype.bits == 32 || dtype.bits == 64) {
223-
int num_threads = task.num_threads = runtime::threading::MaxConcurrency();
224-
int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, num_threads);
222+
int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, 0);
225223
ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed";
226224
} else {
227225
LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;

tests/python/contrib/test_random.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from tvm.contrib import random
2121
from tvm import rpc
2222
import tvm.testing
23+
import threading
2324

2425

2526
def test_randint():
@@ -155,8 +156,35 @@ def check_remote(server):
155156
test_rpc(dtype)
156157

157158

159+
def test_random_fill_mt():
160+
"""Check random filler applicability in case of nontrivial thread pool configuration.
161+
Particularly when MaxConcurrency != num_workers_used_ which is actual for big-little systems.
162+
"""
163+
no_exception_happened = True
164+
165+
def test_body():
166+
try:
167+
num_thread_used = 1
168+
configure_threads = tvm.get_global_func("runtime.config_threadpool")
169+
configure_threads(1, num_thread_used)
170+
171+
test_input = tvm.runtime.ndarray.empty((10, 10))
172+
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
173+
random_fill(test_input)
174+
except:
175+
nonlocal no_exception_happened
176+
no_exception_happened = False
177+
178+
# ThreadPool object is thread local. To eliminate effect on other test cases put it into thread
179+
x = threading.Thread(target=test_body)
180+
x.start()
181+
x.join()
182+
assert no_exception_happened
183+
184+
158185
if __name__ == "__main__":
159186
test_randint()
160187
test_uniform()
161188
test_normal()
162189
test_random_fill()
190+
test_random_fill_mt()

0 commit comments

Comments
 (0)