Add CUDA TLG nbest and mbr decoding #1804

Merged · 10 commits · Apr 14, 2023
16 changes: 9 additions & 7 deletions runtime/gpu/client/stats_summary.py
@@ -57,19 +57,21 @@ def get_args():
         stats = json.load(stats_f)
         model_stats = stats["model_stats"]
         for model_state in model_stats:
+            if "last_inference" not in model_state:
+                continue
             summary_f.write(f"model name is {model_state['name']} \n")
             model_inference_stats = model_state["inference_stats"]
             total_queue_time_s = (
-                int(model_inference_stats["queue"]["ns"]) / 10e9
+                int(model_inference_stats["queue"]["ns"]) / 1e9
             )
             total_infer_time_s = (
-                int(model_inference_stats["compute_infer"]["ns"]) / 10e9
+                int(model_inference_stats["compute_infer"]["ns"]) / 1e9
             )
             total_input_time_s = (
-                int(model_inference_stats["compute_input"]["ns"]) / 10e9
+                int(model_inference_stats["compute_input"]["ns"]) / 1e9
             )
             total_output_time_s = (
-                int(model_inference_stats["compute_output"]["ns"]) / 10e9
+                int(model_inference_stats["compute_output"]["ns"]) / 1e9
             )
             summary_f.write(
                 f"queue {total_queue_time_s:<5.2f} s, infer {total_infer_time_s:<5.2f} s, input {total_input_time_s:<5.2f} s, output {total_output_time_s:<5.2f} s \n" # noqa
@@ -86,9 +88,9 @@ def get_args():
                 == compute_output["count"]
                 == compute_input["count"]
             )
-            compute_infer_time_ms = int(compute_infer["ns"]) / 10e6
-            compute_input_time_ms = int(compute_input["ns"]) / 10e6
-            compute_output_time_ms = int(compute_output["ns"]) / 10e6
+            compute_infer_time_ms = int(compute_infer["ns"]) / 1e6
+            compute_input_time_ms = int(compute_input["ns"]) / 1e6
+            compute_output_time_ms = int(compute_output["ns"]) / 1e6
             summary_f.write(
                 f"Batch_size {batch_size:<2}, {batch_count:<5} times, infer {compute_infer_time_ms:<9.2f} ms, avg {compute_infer_time_ms/batch_count:.2f} ms, {compute_infer_time_ms/batch_count/batch_size:.2f} ms " # noqa
             )
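Triton reports these counters in nanoseconds, so dividing by 1e9 gives seconds and dividing by 1e6 gives milliseconds; the old divisors 10e9 and 10e6 are actually 1e10 and 1e7, which shrank every reported time by a factor of ten. A minimal sketch of the corrected conversion (example value only, not from the PR):

```python
# Hypothetical nanosecond counter, as found under model_inference_stats[...]["ns"].
ns = 36_500_000_000

seconds = ns / 1e9         # 36.5     (1 s  == 1e9 ns)
milliseconds = ns / 1e6    # 36500.0  (1 ms == 1e6 ns)
wrong_seconds = ns / 10e9  # 3.65     -- the pre-fix value, off by 10x

print(seconds, milliseconds, wrong_seconds)
```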
14 changes: 7 additions & 7 deletions runtime/gpu/cuda_decoders/README.md
@@ -6,7 +6,7 @@ The triton model repository `model_repo_cuda_decoder` here, integrates the [CUDA
 ```sh
 # using docker image runtime/gpu/Dockerfile/Dockerfile.server
 docker pull soar97/triton-wenet:22.12
-docker run -it --rm --name "wenet_trt_test" --gpus all --shm-size 1g --net host soar97/triton-wenet:22.12
+docker run -it --rm --name "wenet_tlg_test" --gpus all --shm-size 1g --net host soar97/triton-wenet:22.12
 # inside the docker container
 git clone https://github.com/wenet-e2e/wenet.git
 cd wenet/runtime/gpu/cuda_wfst_decoder
@@ -18,10 +18,10 @@ bash run.sh

 ### TODO: Performance of Small Offline ASR Model using Different Decoders
 
-Benchmark(offline conformer model trained on Aishell1) based on Aishell1 test set with V100, the total audio duration is 36108.919 seconds.
+Benchmark (small offline conformer onnx fp16 model trained on Aishell1) on the Aishell1 test set with a V100; the total audio duration is 36108.919 seconds.
 
-<!-- (Note: decoding time is the time spent by the decoding process)
-|Decoding Method | decoding time(s) | WER (%) |
-|----------|--------------------|----------------|
-| CTC Greedy Search |  | 4.97 |
-| CUDA WFST Decoding (3-gram LM) |  |  | -->
+(Note: 80 concurrent tasks, the service has been fully warmed up.)
+|Decoding Method | decoding time(s) | WER (%) |
+|----------|--------------------|-------------|
+| CTC Greedy Search | 23s | 4.97 |
+| CUDA TLG 1-best (3-gram LM) | 31s | 4.58 |
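A rough back-of-the-envelope reading of the table above (derived from the numbers shown, not stated in the PR): both decoders run far faster than real time at 80 concurrent tasks, and the TLG 1-best pass trades roughly 35% more decoding time for a 0.39-point absolute WER reduction.

```python
# Speed figures derived from the benchmark table above.
total_audio_s = 36108.919  # Aishell1 test set duration reported in the README

for name, decode_s, wer in [("CTC Greedy Search", 23, 4.97),
                            ("CUDA TLG 1-best (3-gram LM)", 31, 4.58)]:
    rtf = decode_s / total_audio_s  # real-time factor of the whole service
    print(f"{name}: RTF {rtf:.5f}, {total_audio_s / decode_s:.0f}x real time, WER {wer}%")
```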
@@ -3,6 +3,7 @@
 from typing import List
 from riva.asrlib.decoder.python_decoder import (BatchedMappedDecoderCuda,
                                                 BatchedMappedDecoderCudaConfig)
+from frame_reducer import FrameReducer
 
 def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
@@ -55,11 +56,18 @@ def ctc_greedy_search(ctc_probs, encoder_out_lens, vocabulary, blank_id, eos):
total_hyps.append("".join([vocabulary[i] for i in hyp]))
return total_hyps

def load_word_symbols(path):
word_id_to_word_str = {}
with open(path, "rt", encoding="utf-8") as fh:
for line in fh:
word_str, word_id = line.rstrip().split()
word_id_to_word_str[int(word_id)] = word_str
return word_id_to_word_str

class RivaWFSTDecoder:
def __init__(self, vocab_size, tlg_dir, config_dict, beam_size=8.0):
def __init__(self, vocab_size, tlg_dir, config_dict, nbest=10):
config = BatchedMappedDecoderCudaConfig()
config.online_opts.decoder_opts.lattice_beam = beam_size

config.online_opts.decoder_opts.lattice_beam = config_dict['lattice_beam']
config.online_opts.lattice_postprocessor_opts.acoustic_scale = config_dict['acoustic_scale'] # noqa
config.n_input_per_chunk = config_dict['n_input_per_chunk']
config.online_opts.decoder_opts.default_beam = config_dict['default_beam']
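The `words.txt` consumed by `load_word_symbols` is a standard OpenFst symbol table: one word and its integer id per line, which is exactly what the two-column split above expects, with id 0 conventionally reserved for `<eps>` (and skipped later when joining words). A small hypothetical excerpt:

```python
# Hypothetical three-line symbol table in the OpenFst words.txt format;
# assumes load_word_symbols (defined above) is in scope.
sample = "<eps> 0\n你好 1\n世界 2\n"
with open("words.txt", "w", encoding="utf-8") as fh:
    fh.write(sample)

print(load_word_symbols("words.txt"))  # {0: '<eps>', 1: '你好', 2: '世界'}
```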
@@ -71,23 +79,60 @@ def __init__(self, vocab_size, tlg_dir, config_dict, beam_size=8.0):
         config.online_opts.lattice_postprocessor_opts.lm_scale = config_dict['lm_scale']
         config.online_opts.lattice_postprocessor_opts.word_ins_penalty = config_dict['word_ins_penalty'] # noqa
 
         config.online_opts.num_decoder_copy_threads = 2
         config.online_opts.num_post_processing_worker_threads = 4
+        config.online_opts.lattice_postprocessor_opts.nbest = nbest
+
+        # config.online_opts.decoder_opts.blank_penalty = -5.0
+
         self.decoder = BatchedMappedDecoderCuda(
             config, os.path.join(tlg_dir, "TLG.fst"),
             os.path.join(tlg_dir, "words.txt"), vocab_size
         )
+        self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
+        self.nbest = nbest
         self.vocab_size = vocab_size
+        self.frame_reducer = FrameReducer(0.98)
 
-    def decode(self, logits, length):
+    def decode_nbest(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
         logits = logits.to(torch.float32).contiguous()
         sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
-        results = self.decoder.decode(logits, sequence_lengths_tensor)
-        return results
+        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
+        total_hyps, total_hyps_id, total_scores = [], [], []
+        max_hyp_len = 3 # [sos, 0, eos]
+        for nbest_sentences in results:
+            nbest_list, nbest_id_list, nbest_scores = [], [], []
+            for sent in nbest_sentences:
+                # subtract 1 to get the label id,
+                # since fst decoder adds 1 to the label id
+                hyp_ids = [label - 1 for label in sent.ilabels]
+                # padding for hyps_pad_sos_eos
+                new_hyp = [self.vocab_size - 1] + remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size - 1, blank_id=0) + [self.vocab_size - 1] # noqa
+                max_hyp_len = max(max_hyp_len, len(new_hyp))
+                nbest_id_list.append(new_hyp)
 
-    def get_nbest_list(self, results, nbest=1):
-        assert nbest == 1, "Only support nbest=1 for now"
+                hyp = "".join(self.word_id_to_word_str[word]
+                              for word in sent.words if word != 0)
+                nbest_list.append(hyp)
+                nbest_scores.append(sent.score)
+            nbest_list += [""] * (self.nbest - len(nbest_list))
+            total_hyps.append(nbest_list)
+            nbest_id_list += [[self.vocab_size - 1, 0, self.vocab_size - 1]] * (self.nbest - len(nbest_id_list)) # noqa
+            total_hyps_id.append(nbest_id_list)
+            nbest_scores += [0.0] * (self.nbest - len(nbest_scores))
+            total_scores.append(nbest_scores)
+        return total_hyps, total_hyps_id, total_scores, max_hyp_len
 
+    def decode_mbr(self, logits, length):
+        logits, length = self.frame_reducer(logits, length.cuda(), logits)
+        # logits[:,:,0] -= 2.0
+        logits = logits.to(torch.float32).contiguous()
+        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
+        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
         total_hyps = []
         for sent in results:
             hyp = [word[0] for word in sent]
             hyp_zh = "".join(hyp)
-            nbest_list = [hyp_zh] # TODO: add real nbest
-            total_hyps.append(nbest_list)
+            total_hyps.append(hyp_zh)
         return total_hyps
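A sketch of how the two new entry points might be called from serving code (hypothetical caller, not part of this diff): `decode_mbr` returns one transcript string per utterance, while `decode_nbest` returns word-level n-best strings, sos/eos-wrapped label-id n-best lists (intended for attention rescoring via `hyps_pad_sos_eos`), per-hypothesis lattice scores, and the longest wrapped hypothesis length.

```python
import torch

def tlg_decode(decoder, ctc_log_probs: torch.Tensor, lengths: torch.Tensor, use_mbr: bool = False):
    """Hypothetical helper; assumes `decoder` is a RivaWFSTDecoder and
    `ctc_log_probs` is a [N, T, vocab_size] CUDA tensor of CTC log-probs."""
    if use_mbr:
        # MBR path: a single best transcript (string) per utterance.
        return decoder.decode_mbr(ctc_log_probs, lengths)

    hyps, hyps_id, scores, max_hyp_len = decoder.decode_nbest(ctc_log_probs, lengths)
    # hyps:        [N][nbest] word-level strings, padded with "" when fewer than nbest
    # hyps_id:     [N][nbest] label-id lists wrapped in sos/eos, padded with [sos, 0, eos]
    # scores:      [N][nbest] lattice scores, padded with 0.0
    # max_hyp_len: longest wrapped hypothesis, useful when batching rescoring inputs
    return [nbest[0] for nbest in hyps]  # 1-best transcript per utterance
```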
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """
    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks.
    Returns:
      Return a 2-D bool tensor, where masked positions
      are filled with `True` and non-masked positions are
      filled with `False`.
    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    max_len = max(max_len, lengths.max())
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    return expaned_lengths >= lengths.unsqueeze(-1)



class FrameReducer(nn.Module):
    """The encoder output is first used to calculate
    the CTC posterior probability; then for each output frame,
    if its blank posterior is bigger than some thresholds,
    it will be simply discarded from the encoder output.
    """

    def __init__(
        self,
        blank_threshlod: float = 0.95,
    ):
        super().__init__()
        self.blank_threshlod = blank_threshlod

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        ctc_output: torch.Tensor,
        y_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The shared encoder output with shape [N, T, C].
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          ctc_output:
            The CTC output with shape [N, T, vocab_size].
          y_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `y` before padding.
          blank_id:
            The blank id of ctc_output.
        Returns:
          out:
            The frame reduced encoder output with shape [N, T', C].
          out_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `out` before padding.
        """
        N, T, C = x.size()

        padding_mask = make_pad_mask(x_lens, x.size(1))
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(self.blank_threshlod)) * (~padding_mask)  # noqa

        if y_lens is not None:
            # Limit the maximum number of reduced frames
            limit_lens = T - y_lens
            max_limit_len = limit_lens.max().int()
            fake_limit_indexes = torch.topk(
                ctc_output[:, :, blank_id], max_limit_len
            ).indices
            T = (
                torch.arange(max_limit_len)
                .expand_as(
                    fake_limit_indexes,
                )
                .to(device=x.device)
            )
            T = torch.remainder(T, limit_lens.unsqueeze(1))
            limit_indexes = torch.gather(fake_limit_indexes, 1, T)
            limit_mask = torch.full_like(
                non_blank_mask,
                False,
                device=x.device,
            ).scatter_(1, limit_indexes, True)

            non_blank_mask = non_blank_mask | ~limit_mask

        out_lens = non_blank_mask.sum(dim=1)
        max_len = out_lens.max()
        pad_lens_list = (
            torch.full_like(
                out_lens,
                max_len.item(),
                device=x.device,
            )
            - out_lens
        )
        max_pad_len = pad_lens_list.max()

        out = F.pad(x, (0, 0, 0, max_pad_len))

        valid_pad_mask = ~make_pad_mask(pad_lens_list)
        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)

        out = out[total_valid_mask].reshape(N, -1, C)

        return out, out_lens


if __name__ == "__main__":
    import time

    test_times = 10000
    device = "cuda:0"
    frame_reducer = FrameReducer()

    # non zero case
    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.log(
        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
    )

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)

    # all zero case
    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)
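A toy sanity check of the frame-reduction behaviour (illustrative only; it assumes this module is importable as `frame_reducer`): frames whose blank log-probability is at least log(0.95) are dropped, so an input where every other frame is confidently blank should come back with roughly half of its frames.

```python
# Toy CPU example: even-numbered frames are confidently blank and get dropped.
import torch

from frame_reducer import FrameReducer  # assumes this file is on PYTHONPATH

reducer = FrameReducer(0.95)
N, T, C, V = 1, 10, 4, 6
x = torch.randn(N, T, C)                   # stand-in encoder output
x_lens = torch.tensor([T])
ctc_output = torch.full((N, T, V), -10.0)  # log-probabilities
ctc_output[:, 0::2, 0] = 0.0               # blank log-prob ~0 (prob ~1) on even frames

out, out_lens = reducer(x, x_lens, ctc_output)
print(out.shape, out_lens)                 # expected: torch.Size([1, 5, 4]) tensor([5])
```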