Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/runtime/contrib/random/mt_random_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,20 +192,19 @@ class RandomEngine {
struct ParallelTask {
static int RunTask(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
ParallelTask* task = static_cast<ParallelTask*>(cdata);
task->Run(task_id);
task->Run(task_id, penv->num_task);
return 0;
}

void Run(int i) {
int64_t chunk_size = size / num_threads;
void Run(int i, int num_tasks) {
int64_t chunk_size = size / num_tasks;
int64_t st = i * chunk_size;
int64_t ed = std::min(st + chunk_size, size);
self->FillDataImpl(data, st, ed, dtype);
}

RandomEngine* self;
void* data;
int num_threads;
int64_t size;
DLDataType dtype;
};
Expand All @@ -220,8 +219,7 @@ class RandomEngine {
}
if (dtype.bits == 1 || dtype.bits == 4 || dtype.bits == 8 || dtype.bits == 16 ||
dtype.bits == 32 || dtype.bits == 64) {
int num_threads = task.num_threads = runtime::threading::MaxConcurrency();
int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, num_threads);
int res = TVMBackendParallelLaunch(ParallelTask::RunTask, &task, 0);
ICHECK_EQ(res, 0) << "RandomFillForMeasure: TVMBackendParallelLaunch failed";
} else {
LOG(FATAL) << "Doesn't support dtype code " << dtype.code << " dtype bits " << dtype.bits;
Expand Down
28 changes: 28 additions & 0 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm.contrib import random
from tvm import rpc
import tvm.testing
import threading


def test_randint():
Expand Down Expand Up @@ -155,8 +156,35 @@ def check_remote(server):
test_rpc(dtype)


def test_random_fill_mt():
"""Check random filler applicability in case of nontrivial thread pool configuration.
Particularly when MaxConcurrency != num_workers_used_ which is actual for big-little systems.
"""
no_exception_happened = True

def test_body():
try:
num_thread_used = 1
configure_threads = tvm.get_global_func("runtime.config_threadpool")
configure_threads(1, num_thread_used)

test_input = tvm.runtime.ndarray.empty((10, 10))
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
random_fill(test_input)
except:
nonlocal no_exception_happened
no_exception_happened = False

# ThreadPool object is thread local. To eliminate effect on other test cases put it into thread
x = threading.Thread(target=test_body)
x.start()
x.join()
assert no_exception_happened


if __name__ == "__main__":
test_randint()
test_uniform()
test_normal()
test_random_fill()
test_random_fill_mt()