diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index ac5259436005..dc01114af0a1 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -192,12 +192,13 @@ class RandomEngine {
     struct ParallelTask {
       static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
         ParallelTask* task = static_cast<ParallelTask*>(cdata);
-        task->Run(task_id);
+        task->Run(task_id, penv->num_task);
         return 0;
       }
 
-      void Run(int i) {
-        int64_t chunk_size = size / num_threads;
-        int64_t st = i * chunk_size;
+      void Run(int i, int num_tasks) {
+        // Round the chunk size up so the tail (size % num_tasks elements) is filled as well.
+        int64_t chunk_size = (size + num_tasks - 1) / num_tasks;
+        int64_t st = std::min(i * chunk_size, size);
         int64_t ed = std::min(st + chunk_size, size);
         self->FillDataImpl(data, st, ed, dtype);
@@ -205,7 +206,6 @@ class RandomEngine {
 
       RandomEngine* self;
       void* data;
-      int num_threads;
       int64_t size;
       DLDataType dtype;
     };
@@ -220,8 +220,7 @@ class RandomEngine {
     }
     if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
         dtype.bits == 32 || dtype.bits == 64) {
-      int num_threads = task.num_threads = runtime::threading::MaxConcurrency();
-      int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, num_threads);
+      int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, 0);
       ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed";
     } else {
       LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;
diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py
index 7a52c0dbf1ea..ddc06b07110e 100644
--- a/tests/python/contrib/test_random.py
+++ b/tests/python/contrib/test_random.py
@@ -20,6 +20,7 @@
 from tvm.contrib import random
 from tvm import rpc
 import tvm.testing
+import threading
 
 
 def test_randint():
@@ -155,8 +156,35 @@ def check_remote(server):
         test_rpc(dtype)
 
 
+def test_random_fill_mt():
+    """Check random filler applicability in case of nontrivial thread pool configuration.
+    Particularly when MaxConcurrency != num_workers_used_, which happens on big.LITTLE systems.
+    """
+    no_exception_happened = True
+
+    def test_body():
+        try:
+            num_thread_used = 1
+            configure_threads = tvm.get_global_func("runtime.config_threadpool")
+            configure_threads(1, num_thread_used)
+
+            test_input = tvm.runtime.ndarray.empty((10, 10))
+            random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
+            random_fill(test_input)
+        except Exception:
+            nonlocal no_exception_happened
+            no_exception_happened = False
+
+    # ThreadPool is thread-local, so run the body in its own thread to keep other tests unaffected.
+    x = threading.Thread(target=test_body)
+    x.start()
+    x.join()
+    assert no_exception_happened
+
+
 if __name__ == "__main__":
     test_randint()
     test_uniform()
     test_normal()
     test_random_fill()
+    test_random_fill_mt()