Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Prevent after-fork number of OMP threads being bigger than 1. (#16999)
Browse files Browse the repository at this point in the history
* Prevent after-fork number of OMP threads being bigger than 1.
This could happen if OMP_NUM_THREADS was set in the environment. Because the child calls engine::OpenMP::Get()->set_enabled(false) in initialize.cc after forking, the behaviour goes back to what it was before #15762 was introduced.

Regions using OMP get their thread count from GetRecommendedOMPThreadCount, so if OMP is disabled they get 1 thread and run serially.

* add C++ unit test

* Add comment
  • Loading branch information
larroy authored and anirudh2290 committed Dec 11, 2019
1 parent c82af38 commit 04ebe45
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/engine/openmp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ void OpenMP::set_reserve_cores(int cores) {

int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
#ifdef _OPENMP
if (omp_num_threads_set_in_environment_) {
return omp_get_max_threads();
}
if (enabled_) {
// OMP_NUM_THREADS was set in the environment at the time of static initialization
if (omp_num_threads_set_in_environment_) {
return omp_get_max_threads();
}
int thread_count = omp_get_max_threads();
if (exclude_reserved) {
if (reserve_cores_ >= thread_count) {
Expand All @@ -107,8 +108,9 @@ int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const {
return thread_count;
}
return omp_thread_max_;
} else {
return 1;
}
return 1;
#else
return 1;
#endif
Expand Down
50 changes: 50 additions & 0 deletions tests/cpp/engine/omp_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>

#include "../include/test_util.h"
#include "../../src/engine/openmp.h"

#if defined(unix) || defined(__unix__) || defined(__unix)
#include <unistd.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <dmlc/logging.h>


TEST(OMPBehaviour, after_fork) {
  /*
   * Check that after fork, OMP is disabled, and the recommended thread count is 1 to prevent
   * process fanout.
   *
   * The expectations are evaluated in the forked child. The child reports their outcome through
   * its exit code, and the parent turns a non-zero exit code into a test failure. Without this,
   * a failing expectation in the child would never be seen by the parent's test run.
   */
  using namespace mxnet::engine;
  auto openmp = OpenMP::Get();
  pid_t pid = fork();
  if (pid == 0) {
    EXPECT_FALSE(openmp->enabled());
    EXPECT_EQ(openmp->GetRecommendedOMPThreadCount(), 1);
    // Terminate the child here: returning would let it fall back into the gtest runner and
    // execute the remaining tests a second time. _exit() is used instead of exit() so that the
    // atexit handlers and static destructors inherited from the parent do not run in the child.
    _exit(::testing::Test::HasFailure() ? 1 : 0);
  } else if (pid > 0) {
    int status = 0;
    int ret = waitpid(pid, &status, 0);
    CHECK_EQ(ret, pid) << "waitpid failed";
    // Propagate the child's verdict into this (the parent's) test result.
    EXPECT_TRUE(WIFEXITED(status)) << "forked child did not exit normally";
    if (WIFEXITED(status)) {
      EXPECT_EQ(WEXITSTATUS(status), 0) << "OMP expectations failed in the forked child";
    }
  } else {
    CHECK(false) << "fork failed";
  }
}
#endif
41 changes: 41 additions & 0 deletions tests/python/unittest/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import nose
import mxnet as mx
import os
import unittest
from mxnet.test_utils import EnvManager

def test_bulk():
with mx.engine.bulk(10):
Expand All @@ -30,6 +33,44 @@ def test_bulk():
x += 1
assert (x.asnumpy() == 104).all()

@unittest.skip("OMP platform dependent")
def test_engine_openmp_after_fork():
    """
    Test that the number of max threads in the child is 1. After forking we should not use a bigger
    OMP thread pool.
    With GOMP the child always has the same number when calling omp_get_max_threads, with LLVM OMP
    the child respects the number of max threads set in the parent.

    The check itself runs in the forked child, which reports the outcome through its exit code;
    the parent asserts on that exit code so a failure in the child fails the test.
    """
    # Imported locally: only this test needs them.
    import sys
    import traceback

    with EnvManager('OMP_NUM_THREADS', '42'):
        r, w = os.pipe()
        pid = os.fork()
        if pid:
            # Parent: the pipe is used to hold the child back until the parent has queried OMP.
            os.close(r)
            wfd = os.fdopen(w, 'w')
            wfd.write('a')
            omp_max_threads = mx.base._LIB.omp_get_max_threads()
            print("Parent omp max threads: {}".format(omp_max_threads))
            try:
                # Closing flushes the buffered byte and releases the child.
                wfd.close()
            except (IOError, OSError):
                # Best effort: the child may already be gone (broken pipe).
                pass
            # Do not swallow failures here, otherwise the test can never fail.
            (cpid, status) = os.waitpid(pid, 0)
            assert cpid == pid
            assert os.WIFEXITED(status), "child did not exit normally"
            exit_status = os.WEXITSTATUS(status)
            assert exit_status == 0, "OMP check failed in the forked child"
        else:
            # Child: always leave through os._exit so that we never return into the test runner
            # and run the remaining tests a second time in the forked process.
            exit_code = 1
            try:
                os.close(w)
                rfd = os.fdopen(r, 'r')
                rfd.read(1)
                omp_max_threads = mx.base._LIB.omp_get_max_threads()
                print("Child omp max threads: {}".format(omp_max_threads))
                assert omp_max_threads == 1
                exit_code = 0
            except BaseException:
                traceback.print_exc()
            finally:
                try:
                    # os._exit does not flush stdio buffers, so do it explicitly.
                    sys.stdout.flush()
                    sys.stderr.flush()
                finally:
                    os._exit(exit_code)



if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 04ebe45

Please sign in to comment.