Skip to content

Commit

Permalink
Support audio tagging using zipformer (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Apr 10, 2024
1 parent c9ae759 commit f20291c
Show file tree
Hide file tree
Showing 30 changed files with 927 additions and 11 deletions.
32 changes: 32 additions & 0 deletions .github/scripts/test-audio-tagging.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
#
# Smoke test for the offline audio-tagging executable.
#
# Downloads the zipformer audio-tagging model archive, runs the executable
# named by $EXE on each bundled test wave, and then removes the downloaded
# files again.
#
# Environment:
#   EXE   Name of the audio-tagging executable; it is looked up via $PATH.

set -ex

# Print a timestamped message tagged with the caller's file, line and function.
log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

echo "EXE is $EXE"
echo "PATH: $PATH"

# With `set -e`, this aborts the script early if the executable is not found.
which "$EXE"

log "------------------------------------------------------------"
log "Run zipformer for audio tagging                             "
log "------------------------------------------------------------"

repo=sherpa-onnx-zipformer-audio-tagging-2024-04-09
archive=$repo.tar.bz2

# -f makes curl exit with an error on an HTTP failure instead of saving the
# server's error page under the archive name.
curl -fSL -O "https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/$archive"
tar xvf "$archive"
rm "$archive"
ls -lh "$repo"

for w in 1.wav 2.wav 3.wav 4.wav; do
  "$EXE" \
    --zipformer-model="$repo/model.onnx" \
    --labels="$repo/class_labels_indices.csv" \
    "$repo/test_wavs/$w"
done

rm -rf "$repo"
10 changes: 10 additions & 0 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -32,6 +33,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -124,6 +126,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*

- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-audio-tagging
.github/scripts/test-audio-tagging.sh
- name: Test online CTC
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -31,6 +32,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -103,6 +105,14 @@ jobs:
otool -L build/bin/sherpa-onnx
otool -l build/bin/sherpa-onnx
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx-offline-audio-tagging
.github/scripts/test-audio-tagging.sh
- name: Test C API
shell: bash
run: |
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/windows-x64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -28,6 +29,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -70,6 +72,14 @@ jobs:
ls -lh ./bin/Release/sherpa-onnx.exe
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-audio-tagging.exe
.github/scripts/test-audio-tagging.sh
- name: Test C API
shell: bash
run: |
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/windows-x86.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand All @@ -28,6 +29,7 @@ on:
- '.github/scripts/test-offline-ctc.sh'
- '.github/scripts/test-offline-tts.sh'
- '.github/scripts/test-online-ctc.sh'
- '.github/scripts/test-audio-tagging.sh'
- 'CMakeLists.txt'
- 'cmake/**'
- 'sherpa-onnx/csrc/*'
Expand Down Expand Up @@ -85,6 +87,13 @@ jobs:
# export EXE=sherpa-onnx-offline-language-identification.exe
#
# .github/scripts/test-spoken-language-identification.sh
- name: Test Audio tagging
shell: bash
run: |
export PATH=$PWD/build/bin/Release:$PATH
export EXE=sherpa-onnx-offline-audio-tagging.exe
.github/scripts/test-audio-tagging.sh
- name: Test online CTC
shell: bash
Expand Down
1 change: 1 addition & 0 deletions cmake/cmake_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def enable_alsa():
def get_binaries():
binaries = [
"sherpa-onnx",
"sherpa-onnx-offline-audio-tagging",
"sherpa-onnx-keyword-spotter",
"sherpa-onnx-microphone",
"sherpa-onnx-microphone-offline",
Expand Down
2 changes: 2 additions & 0 deletions go-api-examples/vad-asr-paraformer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
go.sum
vad-asr-paraformer
2 changes: 1 addition & 1 deletion nodejs-examples/test-offline-tts-zh.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const sherpa_onnx = require('sherpa-onnx');

function createOfflineTts() {
let offlineTtsVitsModelConfig = {
model: './vits-icefall-zh-aishell3/vits-aishell3.onnx',
model: './vits-icefall-zh-aishell3/model.onnx',
lexicon: './vits-icefall-zh-aishell3/lexicon.txt',
tokens: './vits-icefall-zh-aishell3/tokens.txt',
dataDir: '',
Expand Down
12 changes: 12 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ list(APPEND sources
speaker-embedding-manager.cc
)

# audio tagging
list(APPEND sources
audio-tagging-impl.cc
audio-tagging-label-file.cc
audio-tagging-model-config.cc
audio-tagging.cc
offline-zipformer-audio-tagging-model-config.cc
offline-zipformer-audio-tagging-model.cc
)

if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
lexicon.cc
Expand Down Expand Up @@ -193,6 +203,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc)

if(SHERPA_ONNX_ENABLE_TTS)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
Expand All @@ -204,6 +215,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-language-identification
sherpa-onnx-offline-audio-tagging
)
if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND main_exes
Expand Down
23 changes: 23 additions & 0 deletions sherpa-onnx/csrc/audio-tagging-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// sherpa-onnx/csrc/audio-tagging-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/audio-tagging-impl.h"

#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

// Factory for audio tagging implementations.
//
// Returns a zipformer-based implementation when config.model.zipformer.model
// is non-empty. Otherwise a message is logged and nullptr is returned, so
// callers must check the result before using it.
std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
    const AudioTaggingConfig &config) {
  // Guard clause: without a zipformer model there is nothing we can build.
  if (config.model.zipformer.model.empty()) {
    // NOTE(review): the label-file code in this change logs via
    // SHERPA_ONNX_LOGE; confirm that SHERPA_ONNX_LOG is the intended macro.
    SHERPA_ONNX_LOG(
        "Please specify an audio tagging model! Return a null pointer");
    return nullptr;
  }

  return std::make_unique<AudioTaggingZipformerImpl>(config);
}

} // namespace sherpa_onnx
29 changes: 29 additions & 0 deletions sherpa-onnx/csrc/audio-tagging-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// sherpa-onnx/csrc/audio-tagging-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_

#include <memory>
#include <vector>

#include "sherpa-onnx/csrc/audio-tagging.h"

namespace sherpa_onnx {

// Abstract interface implemented by every audio tagging backend.
//
// Instances are obtained through Create(); the concrete type is selected from
// the given configuration.
class AudioTaggingImpl {
 public:
  virtual ~AudioTaggingImpl() = default;

  // Builds the implementation that matches `config`.
  //
  // May return nullptr (when `config` does not specify a model), so the
  // result must be checked before use.
  static std::unique_ptr<AudioTaggingImpl> Create(
      const AudioTaggingConfig &config);

  // Creates a new stream. Ownership is transferred to the caller.
  virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;

  // Computes the audio events for the stream `s`.
  //
  // NOTE(review): `top_k` defaults to -1; presumably a non-positive value
  // means "use the value from the config" -- confirm against the concrete
  // implementation.
  virtual std::vector<AudioEvent> Compute(OfflineStream *s,
                                          int32_t top_k = -1) const = 0;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_IMPL_H_
70 changes: 70 additions & 0 deletions sherpa-onnx/csrc/audio-tagging-label-file.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// sherpa-onnx/csrc/audio-tagging-label-file.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/audio-tagging-label-file.h"

#include <exception>
#include <fstream>
#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

// Opens `filename` and hands the resulting stream to Init() for parsing.
//
// NOTE(review): a failed open is not reported here; in that case Init() reads
// nothing and the label table stays empty.
AudioTaggingLabels::AudioTaggingLabels(const std::string &filename) {
  std::ifstream label_stream(filename);
  Init(label_stream);
}

// Parses a label file from `is` and appends the display name of every data
// line to names_, in file order.
//
// Format of a label file (the first line is a header and is skipped):
/*
index,mid,display_name
0,/m/09x0r,"Speech"
1,/m/05zppz,"Male speech, man speaking"
*/
// The display name must be enclosed in double quotes and may itself contain
// commas. Lines may end in either LF or CRLF.
//
// A line whose index is not an integer, or whose display name is not quoted,
// is logged and terminates the process with exit(-1). An index that does not
// match its position in the file is only logged.
void AudioTaggingLabels::Init(std::istream &is) {
  std::string line;
  std::getline(is, line);  // skip the header

  std::string index;
  std::string tmp;
  std::string name;

  while (std::getline(is, line)) {
    // Tolerate CRLF line endings; otherwise the trailing '\r' would follow
    // the closing quote and the line would be rejected below.
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }

    index.clear();
    tmp.clear();
    name.clear();
    std::istringstream input2(line);

    std::getline(input2, index, ',');
    std::getline(input2, tmp, ',');  // the mid column is not used
    std::getline(input2, name);      // rest of the line; may contain commas

    // std::stoi throws on empty or non-numeric input and on overflow. Convert
    // inside a try block so that such lines are reported like any other
    // malformed line instead of ending in an uncaught exception.
    std::size_t pos = 0;
    int32_t i = -1;
    bool index_ok = !index.empty();
    if (index_ok) {
      try {
        i = std::stoi(index, &pos);
      } catch (const std::exception &) {
        index_ok = false;
      }
    }

    // pos != index.size() rejects trailing garbage such as "12abc".
    if (!index_ok || pos != index.size()) {
      SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
      exit(-1);
    }

    // Deliberately non-fatal: the mismatch is reported and parsing continues.
    const int32_t expected_index = static_cast<int32_t>(names_.size());
    if (i != expected_index) {
      SHERPA_ONNX_LOGE(
          "Index should be sorted and contiguous. Expected index: %d, given: "
          "%d.",
          expected_index, i);
    }

    // At least two characters are required so that the opening and the
    // closing quote are distinct; a lone '"' would otherwise produce an
    // invalid iterator range below.
    if (name.size() < 2 || name.front() != '"' || name.back() != '"') {
      SHERPA_ONNX_LOGE("Invalid line: %s", line.c_str());
      exit(-1);
    }

    // Store the name without the surrounding quotes.
    names_.emplace_back(name.begin() + 1, name.end() - 1);
  }
}

// Looks up the display name stored for the event class `index`.
//
// The lookup is bounds-checked via std::vector::at, which throws
// std::out_of_range for an index outside the label table.
const std::string &AudioTaggingLabels::GetEventName(int32_t index) const {
  const std::string &event_name = names_.at(index);
  return event_name;
}

} // namespace sherpa_onnx
31 changes: 31 additions & 0 deletions sherpa-onnx/csrc/audio-tagging-label-file.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/audio-tagging-label-file.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_

#include <istream>
#include <string>
#include <vector>

namespace sherpa_onnx {

// Table of audio event display names loaded from a label file.
class AudioTaggingLabels {
 public:
  // Reads the labels from the file `filename`.
  explicit AudioTaggingLabels(const std::string &filename);

  // Return the event name for the given index.
  // The returned reference is valid as long as this object is alive
  const std::string &GetEventName(int32_t index) const;

  // Number of event names currently stored.
  int32_t NumEventClasses() const {
    return static_cast<int32_t>(names_.size());
  }

 private:
  // Parses the label file contents from `is` into names_.
  void Init(std::istream &is);

  // Event display names, in the order they were read.
  std::vector<std::string> names_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_LABEL_FILE_H_
Loading

0 comments on commit f20291c

Please sign in to comment.