zipformer transducer becomes too slow when run with multiple threads #704
The above files are examples of running zipformer in a single process with multiple threads.
Hello. I checked the code below: online-websocket-server.cc. In online-websocket-server-impl.cc:L55, it seems each thread has its own 'recognizer_'. I'm trying to implement a single shared 'recognizer_' where each thread has its own stream created via the recognizer's 'CreateStream'. If my approach is wrong, I would really appreciate knowing why it is wrong, or why it becomes so slow as the number of threads increases. Thank you 🙂
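To make the intended structure concrete, here is a minimal sketch of what I mean (the sherpa cpp_api calls are the ones used elsewhere in this thread; 'DecodeOneUtterance' is just an illustrative name, and config setup / wave reading are omitted):

```c++
#include "sherpa/cpp_api/online-recognizer.h"
#include "sherpa/cpp_api/online-stream.h"

// Sketch only: one recognizer shared by every thread, one stream per thread.
// `wave` and `sample_rate` are assumed to be prepared by the caller.
void DecodeOneUtterance(sherpa::OnlineRecognizer &recognizer,  // shared model
                        torch::Tensor wave, float sample_rate) {
  auto s = recognizer.CreateStream();   // stream is private to this thread
  s->AcceptWaveform(sample_rate, wave);
  s->InputFinished();
  while (recognizer.IsReady(s.get())) {
    recognizer.DecodeStream(s.get());   // decoding goes through the shared model
  }
  auto result = recognizer.GetResult(s.get());
}
```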
The recognizer is shared among threads in our code. Please recheck.
Hello. I ran https://github.com/k2-fsa/sherpa/blob/master/sherpa/cpp_api/websocket/online-websocket-server.cc, but I can't tell whether it shows the same problem because it only uses a small number of threads. So, to debug this, I modified the file-test sample code from https://github.com/k2-fsa/sherpa/blob/d3c953b68b6797a21c35fe1809d7eed57889ed58/sherpa/cpp_api/bin/online-recognizer.cc to use multiple threads. Here is my code:

```c++
#include "sherpa/cpp_api/online-recognizer.h"
#include <algorithm>
#include <thread>
#include <vector>
#include "kaldi_native_io/csrc/kaldi-table.h"
#include "kaldi_native_io/csrc/text-utils.h"
#include "kaldi_native_io/csrc/wave-reader.h"
#include "sherpa/cpp_api/online-stream.h"
#include "sherpa/cpp_api/parse-options.h"
#include "sherpa/csrc/fbank-features.h"
#include "sherpa/csrc/log.h"
static constexpr const char *kUsageMessage = R"(
hi
)";
int32_t main(int32_t argc, char *argv[]) {
torch::set_num_threads(1);
torch::set_num_interop_threads(1);
sherpa::InferenceMode no_grad;
torch::jit::getExecutorMode() = false;
torch::jit::getProfilingMode() = false;
torch::jit::setGraphExecutorOptimize(false);
float expected_sample_rate = 8000;
float padding_seconds = 0.8;
sherpa::ParseOptions po(kUsageMessage);
po.Register("padding-seconds", &padding_seconds,
"Number of seconds for tail padding.");
sherpa::OnlineRecognizerConfig config;
config.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() < 1) {
po.PrintUsage();
exit(EXIT_FAILURE);
}
config.Validate();
SHERPA_CHECK_EQ(config.feat_config.fbank_opts.frame_opts.samp_freq,
expected_sample_rate)
<< "The model was trained using training data with sample rate 8000. "
<< "We don't support resample yet";
SHERPA_CHECK_GE(padding_seconds, 0);
SHERPA_LOG(INFO) << "decoding method: " << config.decoding_method;
torch::Tensor tail_padding = torch::zeros(
{static_cast<int32_t>(padding_seconds * expected_sample_rate)},
torch::kFloat);
sherpa::OnlineRecognizer recognizer(config);
// Define thread worker function
auto thread_function = [&](int32_t thread_id) {
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
std::string wave_file = po.GetArg(i);
torch::Tensor wave =
sherpa::ReadWave(wave_file, expected_sample_rate).first;
auto s = recognizer.CreateStream();
int32_t chunk = 0.1 * expected_sample_rate;
int32_t num_samples = wave.numel();
std::string last;
for (int32_t start = 0; start < num_samples;) {
int32_t end = std::min(start + chunk, num_samples);
torch::Tensor samples =
wave.index({torch::indexing::Slice(start, end)});
start = end;
s->AcceptWaveform(expected_sample_rate, samples);
while (recognizer.IsReady(s.get())) {
recognizer.DecodeStream(s.get());
}
}
s->AcceptWaveform(expected_sample_rate, tail_padding);
s->InputFinished();
while (recognizer.IsReady(s.get())) {
recognizer.DecodeStream(s.get());
}
auto r = recognizer.GetResult(s.get());
}
};
// Define the number of threads
int32_t num_threads = 80; // Number of threads to execute
std::vector<std::thread> threads;
// Create threads
for (int32_t t = 0; t < num_threads; ++t) {
threads.emplace_back(thread_function, t);
}
// Join threads
for (auto &t : threads) {
if (t.joinable()) {
t.join();
}
}
return 0;
}
```

The behavior I'm observing is as described in the issue: decoding slows down sharply as the number of threads increases, and the CPU is not utilized well. I suspect the issue might be related to contention or inefficient use of resources when scaling with multiple threads, but I am unsure how to address this. Could you please advise? Thank you for your time and help!
It wasn't a code issue; it was the result of not doing NUMA-aware programming in my environment. My server has two CPU sockets (2 NUMA nodes). When the program is multi-threaded, most of the recognizer lives in the heap memory of only one node, threads migrate between CPUs during online speech recognition decoding, and memory I/O time becomes abnormally long because of the extra time needed to access remote memory. I confirmed that `numactl --cpunodebind=0 ./my_program` runs without the abnormal memory access times. However, since cpunodebind uses only one CPU, I will modify the code to place a recognizer on each node through NUMA-aware programming and let each thread run on the CPU cores of that node. If you have any advice, feel free to share it. Thank you.
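For anyone interested, here is a rough, untested sketch of the per-node layout I'm planning, assuming libnuma is available (compile with -lnuma). 'NodeWorker' and 'RunOnAllNodes' are made-up names; the sherpa calls are the same ones used in the code above.

```c++
#include <numa.h>    // libnuma: numa_available, numa_run_on_node, ...
#include <thread>
#include <vector>

#include "sherpa/cpp_api/online-recognizer.h"

// One worker per NUMA node: bind the thread (and its future children) to that
// node's CPUs and prefer node-local allocations *before* loading the model,
// so the recognizer's weights end up in local memory.
static void NodeWorker(int node, const sherpa::OnlineRecognizerConfig &config) {
  if (numa_available() >= 0) {
    numa_run_on_node(node);    // restrict scheduling to this node's CPUs
    numa_set_preferred(node);  // prefer allocations from this node's memory
  }

  sherpa::OnlineRecognizer recognizer(config);  // node-local copy of the model

  // ... spawn the per-stream decoding threads from here; they inherit the
  // CPU binding and call recognizer.CreateStream() / DecodeStream() as before.
}

static void RunOnAllNodes(const sherpa::OnlineRecognizerConfig &config) {
  int num_nodes = numa_available() < 0 ? 1 : numa_num_configured_nodes();
  std::vector<std::thread> workers;
  for (int node = 0; node < num_nodes; ++node) {
    workers.emplace_back(NodeWorker, node, std::cref(config));
  }
  for (auto &w : workers) {
    w.join();
  }
}
```

The obvious cost is that the model weights are duplicated once per node, but the decoding threads spawned inside each worker inherit that worker's CPU binding, so all of their memory accesses stay node-local.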
Hi. Thanks for the nice project.
I'm trying to deploy a zipformer transducer using sherpa and libtorch in a CPU environment.
I implemented it so that the TorchScript model is shared across all threads and a decoder stream is created for each thread.
The system becomes very slow as the number of threads increases.
Watching htop, the CPU is used well when single-threaded, but as the number of threads increases, kernel time grows and the CPU is not utilized well.
However, when I run the system in a multi-process setup (1 model, 1 decoder per process), each process uses the CPU well.
How can I make better use of the CPU in a multi-threaded environment?
I'm using libtorch 2.0.1+cpu.
Thank you.