This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

When predicting, does mxnet provide thread-safe interface? #3946

Open
gold-mango opened this issue Nov 23, 2016 · 28 comments

Comments

@gold-mango

When deploying online, multiple threads are usually required. If each thread loads its own copy of the model, the memory cost is high, so is there a thread-safe interface that shares the model parameters?

@zihaolucky
Member

I have the same question.

@piiswrong
Contributor

piiswrong commented Nov 23, 2016

The engine is not thread-safe, so there is no way to use multiple threads to push computation to the engine.

However, the engine already does threading and scheduling for you, so you really shouldn't need to.
Instead, consider a multi-threaded-producer, single-threaded-consumer model: use multiple threads to accept requests and push them onto a queue, and have a background thread read from the queue, push the work to mxnet, and send the results back to the producers.
You can bind multiple executors on the same set of weight ndarrays, and that is OK as long as you don't modify them and always run them from the same thread, although it probably won't bring you any benefit if you are doing it on the same card.
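
For illustration, here is a minimal Python sketch of that producer/consumer pattern (the already-bound `mod` Module and the single-output assumption are mine, not part of the comment above):

```python
import queue
import threading

import mxnet as mx

request_q = queue.Queue()

def consumer(mod):
    # Single consumer thread: the only thread that ever talks to MXNet.
    while True:
        data, result_q = request_q.get()
        mod.forward(mx.io.DataBatch(data=[data]), is_train=False)
        result_q.put(mod.get_outputs()[0].asnumpy())

def handle_request(x):
    # Called from any number of producer (request-handling) threads.
    result_q = queue.Queue(maxsize=1)
    request_q.put((mx.nd.array(x), result_q))
    return result_q.get()

# threading.Thread(target=consumer, args=(mod,), daemon=True).start()
```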

@piiswrong
Contributor

BTW, Python multi-threading doesn't really help here because of the GIL. Try multiprocessing instead.
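
A minimal sketch of the multiprocessing alternative (the checkpoint prefix "model", epoch 0, and the input shape are placeholders):

```python
import multiprocessing as mp

import numpy as np
import mxnet as mx

_mod = None

def _init():
    # Runs once per worker process: each process loads its own copy of the model.
    global _mod
    sym, arg, aux = mx.model.load_checkpoint("model", 0)
    _mod = mx.mod.Module(sym, label_names=None)
    _mod.bind(data_shapes=[("data", (1, 3, 224, 224))], for_training=False)
    _mod.set_params(arg, aux)

def predict(x):
    _mod.forward(mx.io.DataBatch(data=[mx.nd.array(x)]), is_train=False)
    return _mod.get_outputs()[0].asnumpy()

if __name__ == "__main__":
    with mp.Pool(processes=4, initializer=_init) as pool:
        outputs = pool.map(predict, [np.zeros((1, 3, 224, 224), np.float32)] * 8)
```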

@gold-mango
Author

@piiswrong On the same card, if I bind multiple executors on the same set of weight ndarrays and run them from different threads simultaneously, is that OK?
Does it bring any benefit? Since more than one model would be running simultaneously, would the throughput increase?

@phoenixbai

I have the same problem. Is it possible to provide a multi-threaded prediction service using mxnet? In the Caffe case, we have to copy the whole net as many times as there are threads. Is mxnet any better?

My trained models are normally around 1GB, so memory consumption is a big issue.

@fasahath

Any progress here?

@everwind

I have the same question. Any progress here?

@everwind

everwind commented Aug 23, 2017

I think we should separate the "data" from the model weights. When predicting, each thread can share the model weights while using different data; this way we consume much less memory.
In practice, we often train the model on GPUs but run it on CPU machines, so it is very important to provide efficient multi-threaded prediction.

@leopd
Contributor

leopd commented Sep 7, 2017

Tl;dr: if you use a high-performance Python web server like gunicorn, you'll get what you want. You'll have lots of cores running in parallel working on different requests, and you won't need a separate in-memory copy of the model for each worker.

Getting a complex codebase to be thread-safe is no small task, so this won't be "resolved" any time soon. Fortunately, it's not necessary for most of what you want here. Your answer lies in the Unix fork() call. If you want to understand it, go read: https://en.wikipedia.org/wiki/Fork_(system_call)

The magic of fork is copy-on-write memory semantics whereby each forked worker has its own virtual memory address space, but they all share the same physical memory. (Until one of them writes to the shared memory, in which case a private copy of that memory block is made in that process's virtual address space -- thus "copy-on-write".) So even though it's not multi-threaded, fork() & pre-fork worker servers like gunicorn let you accomplish almost the same thing with multiple processes instead of threads. Forked processes are somewhat more heavyweight than threads, but they're nowhere near as expensive as you running the same command multiple times.
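
For example, a minimal gunicorn configuration that relies on this (the `myapp:app` module path is hypothetical):

```python
# gunicorn.conf.py -- load the model once in the master process, then fork
# workers that share its memory via copy-on-write.
preload_app = True   # import the app (and load the model) before forking
workers = 4          # number of pre-forked worker processes
```

Start it with `gunicorn -c gunicorn.conf.py myapp:app`; as long as the workers only read the weights, the physical pages stay shared.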

@bhavinthaker
Contributor

Good points made by @piiswrong and @leopd.

Based on my experience of using fork() over many years, I would say that calling fork() from a multi-threaded process is typically NOT recommended. If you do, you will need to understand the details of how it works in order to use it correctly. A few caveats on the use of fork():

  1. The behavior of fork() can vary across operating systems. For example, Solaris has both a fork-one and a fork-all model [1], whereas Linux has only the fork-one model. The fork-one model duplicates only the calling thread, whereas the fork-all model duplicates all threads.

  2. The fork-one model can leave mutexes in a bad state in the child process. See the excellent description of the problems in the classic book, The Linux Programming Interface [2].

  3. A good summary of the problem is here: [3]

  4. More problems are described here: [4]. For example, you need to ensure that the close-on-exec flag is set on file descriptors (whether the language runtime does this for you varies; C does not), otherwise both processes hold a reference to the same open file, and close() in one process will NOT actually close the file, since it is still referenced from the other process. (See the Python sketch after the references.)

  5. There is a long discussion on why fork without exec is dangerous in large programs [5]. I always call exec() after fork() and set close-on-exec on file descriptors when calling fork() from a multi-threaded process.

References:
[1] fork-one and fork-all model: https://docs.oracle.com/cd/E19683-01/806-6867/gen-1/index.html
[2] The Linux Programming Interface: A Linux and UNIX System Programming Handbook
By Michael Kerrisk: https://books.google.com/books?id=Ps2SH727eCIC&pg=PA686&lpg=PA686&dq=fork+but+no+exec+mutex&source=bl&ots=kMBgx2BVwb&sig=PI74pt3fvO8gHKLLUaIUrr_shy8&hl=en&sa=X&ved=0ahUKEwjTu8qn75_WAhVJ7GMKHb46BQIQ6AEIOjAH#v=onepage&q=fork%20but%20no%20exec%20mutex&f=false
[3] https://thorstenball.com/blog/2014/10/13/why-threads-cant-fork/
[4] http://www.linuxprogrammingblog.com/threads-and-fork-think-twice-before-using-them
[5] https://news.ycombinator.com/item?id=12302539
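
For point 4, a minimal Python sketch of setting the close-on-exec flag explicitly (the file name is only an example; Python 3.4+ already makes descriptors non-inheritable by default):

```python
import fcntl
import os

fd = os.open("model-0000.params", os.O_RDONLY)
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)  # closed automatically on exec()
```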

@bhavinthaker
Contributor

phoenixbai> In the Caffe case, we have to copy the whole net as many times as there are threads. Is mxnet any better? My trained models are normally around 1GB, so memory consumption is a big issue.

I would suggest saving the model to a file and accessing it via memory-mapped I/O from multiple threads/processes. This will significantly reduce the memory requirements of your solution.

Some good references:
[1] https://www.codeproject.com/Tips/683614/Things-to-Know-about-Memory-Mapped-File-in-Java
[2] https://howtodoinjava.com/java-7/nio/java-nio-2-0-memory-mapped-files-mappedbytebuffer-tutorial/
[3] http://javarevisited.blogspot.com/2012/01/memorymapped-file-and-io-in-java.html
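
In Python, a minimal sketch of the same idea (the parameter file name is only an example; MXNet's own checkpoint loading does not go through mmap, so this just illustrates the memory-sharing mechanism):

```python
import mmap

# Map the saved parameter file read-only. Processes (or forked workers) that
# map the same file share the underlying physical pages via the page cache.
with open("model-0000.params", "rb") as f:
    buf = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    header = buf[:8]  # slices read from the mapping without copying the whole file
```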

@tqchen tqchen closed this as completed Oct 19, 2017
@nswamy nswamy reopened this Jan 9, 2018
@nswamy
Member

nswamy commented Jan 9, 2018

@piiswrong can you expand on why you are suggesting that the executors should be accessed from the same thread? This means that for the lifetime of the process only one thread will be responsible for interacting with MXNet.

@TaoLv
Member

TaoLv commented Jan 16, 2018

@piiswrong @eric-haibin-lin Maybe not related: can mxnet find two independent ops in a computation graph and execute them in parallel, each on its own core of one CPU? Or can mxnet figure out how many cores there are and give the first half to the first operator and the second half to the second operator?

@eric-haibin-lin
Member

Can mxnet find two independent ops in a computation graph and execute them in parallel, each on its own core of one CPU?

If a graph has two parallel paths, MXNet can detect that and execute them in parallel, provided it has enough worker threads: https://github.com/apache/incubator-mxnet/blob/master/docs/faq/env_var.md#set-the-number-of-threads

Or can mxnet figure out how many cores there are and give the first half to the first operator and the second half to the second operator?

For CPU we rely on OpenMP for parallelization. We may give a hint to OpenMP, but there's no guarantee how many threads actually execute a single operator. @cjolivier01 works on the CPU performance tuner and may have more comments on this.
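
For reference, both knobs are controlled through environment variables that must be set before mxnet is imported (the values below are examples only):

```python
import os

os.environ["MXNET_CPU_WORKER_NTHREADS"] = "2"  # engine worker threads; enables parallel ops
os.environ["OMP_NUM_THREADS"] = "20"           # OpenMP hint for threads within one operator

import mxnet as mx
```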

@cjolivier01
Member

Currently, it will run the ops in parallel, using (possibly) several OMP threads on each one independently. OMP threads for a given operator will tend to be on separate physical cores; however, there is currently no coordination of OMP thread/core allocation across operators that execute in parallel, so they may overlap for some period of their execution.

It's actually kind of a tricky thing, because you want to allocate threads across operators in an ideal way (like you mentioned), but the operators usually aren't going to run perfectly in parallel, so there will be windows where the CPU isn't fully utilized (when they aren't overlapping). This would be especially apparent when the operators aren't the same.

How best to handle this is currently under discussion, and input is welcome.

@TaoLv
Member

TaoLv commented Jan 20, 2018

@eric-haibin-lin @cjolivier01 Thanks for the information. I am just curious how mxnet deals with model parallelism, op parallelism, and the parallelism inside an op. If I have 40 cores and 2 independent ops, I could create 40 threads, give the first 20 to the first op and the other 20 to the second op, and execute the two ops concurrently on the CPU. But maybe that would not be as efficient as executing the two ops sequentially, where each op can leverage all 40 threads.

@yanhn

yanhn commented Mar 2, 2018

@gold-mango Have you found a solution?
I am working on a Jetson TX2 and need to do head detection on 3 input video streams.
My experiments show that:

  1. for a single video stream with 1 executor, detection costs 25 ms per frame;
  2. for 2 threads with 2 executors, detection costs 45 ms per frame, almost double.

So is there any optimization I can apply?

@eric-haibin-lin
Member

eric-haibin-lin commented May 15, 2018

Confirmed with @piiswrong offline that the dependency engine in C++ is actually thread safe.

@eric-haibin-lin
Member

@yanhn are you using Python for inference? The MXNet engine has a limited number of worker threads: https://github.com/apache/incubator-mxnet/blob/master/docs/faq/env_var.md#set-the-number-of-threads
Did you check the GPU utilization? If it's already at 100%, you are pretty much compute-bound.

@hqucms
Contributor

hqucms commented Jul 15, 2018

Confirmed with @piiswrong offline that the dependency engine in C++ is actually thread safe.

Does this mean that it is safe to have multiple threads calling the engine concurrently? Specifically, is it OK to create one executor per thread and run them simultaneously? (I am asking this in the context of multi-threaded inference on CPU with the C/C++ API.)

And this seems in contradiction with:

Push APIs are not thread-safe. To be specific, only one thread should make engine API calls at a time.

as stated in https://mxnet.incubator.apache.org/architecture/overview.html, and also the discussion at https://discuss.mxnet.io/t/fixing-thread-safety-issues-in-scala-library/236. Has something changed since then?

Some clarification on this would be highly appreciated!
@piiswrong @eric-haibin-lin

@yanhn

yanhn commented Jul 16, 2018

@eric-haibin-lin
I used the C++ API (PredictorHandle).
I didn't set any MXNet engine environment variables; I used the default values.
As for the TX2's GPU utilization: the board shares a total of 8 GB of memory between the CPU and GPU, and I checked that it wasn't full when there were 2 executors.
Maybe it is a hardware issue?

@eric-haibin-lin
Member

@junrushao1994 could you check @hqucms's comment on the thread safety of the engine's Push API? Is this true?

@junrushao
Member

junrushao commented Jul 17, 2018

@hqucms @eric-haibin-lin
For ThreadedEngine, I read the source code just now and believe it is thread-safe. The Push API consists of 3 parts: bulk handling, dependency tracking, and PushToExecute. The bulk-handling part is thread-local, so it is safe; the dependency and PushToExecute parts are designed to be thread-safe. In conclusion, I believe the ThreadedEngine is thread-safe.

I am not sure why our document says "Push APIs are not thread-safe" (https://mxnet.incubator.apache.org/architecture/overview.html). @tqchen Could you help confirm this?

@gzpyy

gzpyy commented Aug 6, 2018

Any progress here on the "Push APIs are not thread-safe" question?

@loadwiki

Can I just create multiple inference handles in different threads? I've tried that and it doesn't work.

@apeforest
Contributor

@loadwiki Can you provide an example to reproduce? Thx

@YutingZhang
Contributor

Here is a very reasonable proposal:
https://cwiki.apache.org/confluence/display/MXNET/Parallel+Inference+in+MXNet
Would it be possible to just make asnumpy thread-safe, or to add a state indicating whether an NDArray is ready to read?

@eric-haibin-lin
Member

The parallel utility in gluonnlp may be useful for some use cases: https://github.com/dmlc/gluon-nlp/blob/master/src/gluonnlp/utils/parallel.py#L66-L77
