Skip to content

Commit

Permalink
Prevent the number of OMP threads after a fork from being bigger than 1.
Browse files Browse the repository at this point in the history
This could happen if OMP_NUM_THREADS was set in the environment. Because initialize.cc calls engine::OpenMP::Get()->set_enabled(false) in the child after forking, the behaviour goes back to what it was before apache#15762 was introduced.

Regions using OMP get their thread count from GetRecommendedOMPThreadCount, so if OMP is disabled they get 1 thread and run serially.
  • Loading branch information
larroy committed Dec 7, 2019
1 parent fcc42de commit 68813a9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/engine/openmp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ void OpenMP::set_reserve_cores(int cores) {

int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
#ifdef _OPENMP
if (omp_num_threads_set_in_environment_) {
return omp_get_max_threads();
}
if (enabled_) {
if (omp_num_threads_set_in_environment_) {
return omp_get_max_threads();
}
int thread_count = omp_get_max_threads();
if (exclude_reserved) {
if (reserve_cores_ >= thread_count) {
Expand All @@ -100,8 +100,9 @@ int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
return thread_count;
}
return omp_thread_max_;
} else {
return 1;
}
return 1;
#else
return 1;
#endif
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import nose
import mxnet as mx
import os
import unittest
from mxnet.test_utils import EnvManager

def test_bulk():
with mx.engine.bulk(10):
Expand All @@ -30,6 +33,44 @@ def test_bulk():
x += 1
assert (x.asnumpy() == 104).all()

@unittest.skip("OMP platform dependent")
def test_engine_openmp_after_fork():
    """
    Test that the number of max OMP threads in a forked child is 1.

    After forking we should not use a bigger OMP thread pool, even when
    OMP_NUM_THREADS is set in the environment.
    With GOMP the child always has the same number when calling
    omp_get_max_threads, with LLVM OMP the child respects the number of max
    threads set in the parent.

    The child reports its result through its exit status (0 on success,
    1 on any failure); the parent asserts on that status.
    """
    # Local imports: only this test needs them.
    import sys
    import traceback

    with EnvManager('OMP_NUM_THREADS', '42'):
        r, w = os.pipe()
        pid = os.fork()
        if pid:
            # Parent: keep only the write end of the pipe.
            os.close(r)
            wfd = os.fdopen(w, 'w')
            try:
                wfd.write('a')
                omp_max_threads = mx.base._LIB.omp_get_max_threads()
                print("Parent omp max threads: {}".format(omp_max_threads))
            finally:
                try:
                    # Closing flushes the buffered byte and unblocks the child.
                    wfd.close()
                except OSError:
                    # Best effort: the child may already be gone (broken
                    # pipe); its exit status below tells us what happened.
                    pass
            # Do not swallow failures here, otherwise the test can never fail.
            cpid, status = os.waitpid(pid, 0)
            assert cpid == pid
            assert os.WIFEXITED(status), \
                "child did not exit normally (wait status {})".format(status)
            assert os.WEXITSTATUS(status) == 0, \
                "child exited with status {}".format(os.WEXITSTATUS(status))
        else:
            # Child: must never return into the test runner, so every path
            # ends in os._exit with a status the parent can check.
            exit_code = 1
            try:
                os.close(w)
                rfd = os.fdopen(r, 'r')
                # Block until the parent has written and closed its end.
                rfd.read(1)
                rfd.close()
                omp_max_threads = mx.base._LIB.omp_get_max_threads()
                print("Child omp max threads: {}".format(omp_max_threads))
                assert omp_max_threads == 1
                exit_code = 0
            except Exception:
                traceback.print_exc()
            finally:
                # os._exit skips interpreter cleanup, so flush explicitly.
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(exit_code)



if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 68813a9

Please sign in to comment.