Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Bypass ThreadedEngine in test_operator_gpu.py:test_convolution_multip…
Browse files Browse the repository at this point in the history
…le_streams. (#14338)
  • Loading branch information
DickJC123 authored and nswamy committed Apr 5, 2019
1 parent d014ff3 commit ae2dda2
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,18 @@ def _conv_with_num_streams(seed):

@with_seed()
def test_convolution_multiple_streams():
engines = ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']

if os.getenv('MXNET_ENGINE_TYPE') is not None:
engines = [os.getenv('MXNET_ENGINE_TYPE'),]
print("Only running against '%s'" % engines[0], file=sys.stderr, end='')
# Remove this else clause when the ThreadedEngine can handle this test
else:
engines.remove('ThreadedEngine')
print("SKIP: 'ThreadedEngine', only running against %s" % engines, file=sys.stderr, end='')

for num_streams in [1, 2]:
for engine in ['NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']:
for engine in engines:
_test_in_separate_process(_conv_with_num_streams,
{'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})

Expand Down

0 comments on commit ae2dda2

Please sign in to comment.