Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI support for spoken language identification #782

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/test-go-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ jobs:
./run-vits-vctk.sh
rm -rf vits-vctk

echo "Test vits-zh-aishell3"
git clone https://huggingface.co/csukuangfj/vits-zh-aishell3
echo "Test vits-icefall-zh-aishell3"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2
tar xvf vits-icefall-zh-aishell3.tar.bz2
rm vits-icefall-zh-aishell3.tar.bz2
./run-vits-zh-aishell3.sh
rm -rf vits-zh-aishell3
rm -rf vits-icefall-zh-aishell3*

echo "Test vits-piper-en_US-lessac-medium"
git clone https://huggingface.co/csukuangfj/vits-piper-en_US-lessac-medium
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,4 @@ sr-data

vits-icefall-*
sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
spoken-language-identification-test-wavs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import android.util.Log
private val TAG = "sherpa-onnx"

data class OfflineZipformerAudioTaggingModelConfig(
val model: String,
var model: String,
)

data class AudioTaggingModelConfig(
Expand Down Expand Up @@ -134,4 +134,4 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
}

return null
}
}
36 changes: 36 additions & 0 deletions kotlin-api-examples/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,49 @@ fun callback(samples: FloatArray): Unit {
}

// Entry point for the kotlin-api-examples demo: runs each feature test in turn.
fun main() {
// NOTE(review): "Identifcation" is a typo, but it matches the definition of the
// function below — renaming requires changing call site and definition together.
testSpokenLanguageIdentifcation()
testAudioTagging()
testSpeakerRecognition()
testTts()
// ASR is exercised with two different model types.
testAsr("transducer")
testAsr("zipformer2-ctc")
}

// Demonstrates spoken language identification with a Whisper tiny model:
// for each test wave file, feed the samples to a fresh stream, compute the
// language, and print the file name followed by the detected language.
// NOTE(review): the function name has a typo ("Identifcation"); it is kept
// as-is because main() calls it under exactly this name.
fun testSpokenLanguageIdentifcation() {
    val config = SpokenLanguageIdentificationConfig(
        whisper = SpokenLanguageIdentificationWhisperConfig(
            encoder = "./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx",
            decoder = "./sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx",
            tailPaddings = 33,
        ),
        numThreads = 1,
        debug = true,
        provider = "cpu",
    )
    val slid = SpokenLanguageIdentification(assetManager = null, config = config)

    val waveFiles = listOf(
        "./spoken-language-identification-test-wavs/ar-arabic.wav",
        "./spoken-language-identification-test-wavs/bg-bulgarian.wav",
        "./spoken-language-identification-test-wavs/de-german.wav",
    )

    for (waveFilename in waveFiles) {
        // WaveReader returns [FloatArray samples, Int sampleRate].
        val waveData = WaveReader.readWaveFromFile(filename = waveFilename)
        val samples = waveData[0] as FloatArray
        val sampleRate = waveData[1] as Int

        val stream = slid.createStream()
        stream.acceptWaveform(samples, sampleRate = sampleRate)
        val lang = slid.compute(stream)
        stream.release()

        println(waveFilename)
        println(lang)
    }
}

fun testAudioTagging() {
val config = AudioTaggingConfig(
model=AudioTaggingModelConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,22 @@ import android.util.Log

private val TAG = "sherpa-onnx"

data class OfflineZipformerAudioTaggingModelConfig (
val model: String,
// Whisper model files used for spoken language identification.
//
// Field names and types are looked up reflectively by the JNI layer
// (GetFieldID on "encoder"/"decoder"/"tailPaddings"); keep them in sync with
// sherpa-onnx/jni/spoken-language-identification.cc.
data class SpokenLanguageIdentificationWhisperConfig (
// Path to the Whisper encoder ONNX model file.
var encoder: String,
// Path to the Whisper decoder ONNX model file.
var decoder: String,
// Forwarded to the native tail_paddings option; -1 keeps the native
// default — TODO confirm exact semantics against the C++ implementation.
var tailPaddings: Int = -1,
)

data class AudioTaggingModelConfig (
var zipformer: OfflineZipformerAudioTaggingModelConfig,
// Top-level configuration for SpokenLanguageIdentification.
//
// Field names and types are looked up reflectively by the JNI layer
// (GetFieldID on "whisper"/"numThreads"/"debug"/"provider"); keep them in
// sync with sherpa-onnx/jni/spoken-language-identification.cc.
data class SpokenLanguageIdentificationConfig (
// Whisper model files used for language identification.
var whisper: SpokenLanguageIdentificationWhisperConfig,
// Forwarded to the native config's num_threads.
var numThreads: Int = 1,
// Forwarded to the native config's debug flag.
var debug: Boolean = false,
// Execution provider name, e.g. "cpu".
var provider: String = "cpu",
)

data class AudioTaggingConfig (
var model: AudioTaggingModelConfig,
var labels: String,
var topK: Int = 5,
)

data class AudioEvent (
val name: String,
val index: Int,
val prob: Float,
)

class AudioTagging(
class SpokenLanguageIdentification (
assetManager: AssetManager? = null,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
) {
private var ptr: Long

Expand All @@ -43,10 +33,10 @@ class AudioTagging(
}

protected fun finalize() {
if(ptr != 0) {
delete(ptr)
ptr = 0
}
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}

fun release() = finalize()
Expand All @@ -56,25 +46,22 @@ class AudioTagging(
return OfflineStream(p)
}

// fun compute(stream: OfflineStream, topK: Int=-1): Array<AudioEvent> {
fun compute(stream: OfflineStream, topK: Int=-1): Array<Any> {
var events :Array<Any> = compute(ptr, stream.ptr, topK)
}
fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)

private external fun newFromAsset(
assetManager: AssetManager,
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long

private external fun newFromFile(
config: AudioTaggingConfig,
config: SpokenLanguageIdentificationConfig,
): Long

private external fun delete(ptr: Long)

private external fun createStream(ptr: Long): Long

private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>
private external fun compute(ptr: Long, streamPtr: Long): String

companion object {
init {
Expand Down
32 changes: 24 additions & 8 deletions kotlin-api-examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ cd ../kotlin-api-examples

function testSpeakerEmbeddingExtractor() {
if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx
fi

if [ ! -f ./speaker1_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
fi

if [ ! -f ./speaker1_b_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
fi

if [ ! -f ./speaker2_a_cn_16k.wav ]; then
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
curl -SL -O https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav
fi
}

Expand All @@ -53,15 +53,15 @@ function testAsr() {
fi

if [ ! -d ./sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13 ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-multi-zh-hans-2023-12-13.tar.bz2
fi
}

function testTts() {
if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then
wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2
tar xf vits-piper-en_US-amy-low.tar.bz2
rm vits-piper-en_US-amy-low.tar.bz2
fi
Expand All @@ -75,7 +75,22 @@ function testAudioTagging() {
fi
}

# Download the Whisper tiny model and the test wave files needed by the
# spoken-language-identification example; each download is skipped when the
# files are already unpacked.
function testSpokenLanguageIdentification() {
  if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then
    # NOTE(review): the captured source had the host mangled to
    # "github.com" by a content proxy; restored to github.com,
    # which is where the k2-fsa/sherpa-onnx release assets live.
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
    tar xvf sherpa-onnx-whisper-tiny.tar.bz2
    rm sherpa-onnx-whisper-tiny.tar.bz2
  fi

  if [ ! -f ./spoken-language-identification-test-wavs/ar-arabic.wav ]; then
    curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/spoken-language-identification-test-wavs.tar.bz2
    tar xvf spoken-language-identification-test-wavs.tar.bz2
    rm spoken-language-identification-test-wavs.tar.bz2
  fi
}

function test() {
testSpokenLanguageIdentification
testAudioTagging
testSpeakerEmbeddingExtractor
testAsr
Expand All @@ -90,6 +105,7 @@ kotlinc-jvm -include-runtime -d main.jar \
OfflineStream.kt \
SherpaOnnx.kt \
Speaker.kt \
SpokenLanguageIdentification.kt \
Tts.kt \
WaveReader.kt \
faked-asset-manager.kt \
Expand All @@ -101,13 +117,13 @@ java -Djava.library.path=../build/lib -jar main.jar

function testTwoPass() {
if [ ! -f ./sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/encoder-epoch-99-avg-1.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2
fi

if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2
fi
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_library(sherpa-onnx-jni
audio-tagging.cc
jni.cc
offline-stream.cc
spoken-language-identification.cc
)
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
install(TARGETS sherpa-onnx-jni DESTINATION lib)
104 changes: 104 additions & 0 deletions sherpa-onnx/jni/spoken-language-identification.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// sherpa-onnx/jni/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/spoken-language-identification.h"

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"

namespace sherpa_onnx {

// Converts the Kotlin/Java SpokenLanguageIdentificationConfig object into the
// corresponding C++ config struct by reading its fields via JNI reflection.
static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
    JNIEnv *env, jobject config) {
  SpokenLanguageIdentificationConfig ans;

  jclass config_cls = env->GetObjectClass(config);

  // Reads a java.lang.String field of `obj` into a std::string, releasing the
  // UTF chars before returning.
  auto read_string_field = [env](jobject obj, jclass cls, const char *name) {
    jfieldID f = env->GetFieldID(cls, name, "Ljava/lang/String;");
    jstring js = (jstring)env->GetObjectField(obj, f);
    const char *chars = env->GetStringUTFChars(js, nullptr);
    std::string value = chars;
    env->ReleaseStringUTFChars(js, chars);
    return value;
  };

  jfieldID whisper_fid = env->GetFieldID(
      config_cls, "whisper",
      "Lcom/k2fsa/sherpa/onnx/SpokenLanguageIdentificationWhisperConfig;");
  jobject whisper = env->GetObjectField(config, whisper_fid);
  jclass whisper_cls = env->GetObjectClass(whisper);

  ans.whisper.encoder = read_string_field(whisper, whisper_cls, "encoder");
  ans.whisper.decoder = read_string_field(whisper, whisper_cls, "decoder");

  jfieldID fid = env->GetFieldID(whisper_cls, "tailPaddings", "I");
  ans.whisper.tail_paddings = env->GetIntField(whisper, fid);

  fid = env->GetFieldID(config_cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(config, fid);

  fid = env->GetFieldID(config_cls, "debug", "Z");
  ans.debug = env->GetBooleanField(config, fid);

  ans.provider = read_string_field(config, config_cls, "provider");

  return ans;
}

} // namespace sherpa_onnx

// Creates a SpokenLanguageIdentification instance from model files on disk.
// Returns the instance as an opaque handle, or 0 when the config is invalid.
// The Kotlin side owns the returned pointer and must free it later.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  auto config =
      sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
  SHERPA_ONNX_LOGE("SpokenLanguageIdentification newFromFile config:\n%s",
                   config.ToString().c_str());

  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  auto slid = new sherpa_onnx::SpokenLanguageIdentification(config);
  return reinterpret_cast<jlong>(slid);
}

// Creates a new OfflineStream for feeding audio to the identifier.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *slid =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);

  // Ownership of the stream is transferred to the caller, who is responsible
  // for freeing it.
  //
  // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  // ./offline-stream.cc
  std::unique_ptr<sherpa_onnx::OfflineStream> stream = slid->CreateStream();
  return reinterpret_cast<jlong>(stream.release());
}

// Runs language identification on the given stream and returns the result of
// SpokenLanguageIdentification::Compute() as a Java string.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_compute(JNIEnv *env,
                                                                jobject /*obj*/,
                                                                jlong ptr,
                                                                jlong s_ptr) {
  sherpa_onnx::SpokenLanguageIdentification *slid =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
  sherpa_onnx::OfflineStream *s =
      reinterpret_cast<sherpa_onnx::OfflineStream *>(s_ptr);
  std::string lang = slid->Compute(s);
  return env->NewStringUTF(lang.c_str());
}

// Frees the object created by newFromFile().
//
// Fix: the Kotlin wrapper declares `private external fun delete(ptr: Long)`
// and calls it from finalize()/release(); without this export, every
// release() would throw UnsatisfiedLinkError at runtime.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  delete reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
}

// NOTE(review): the Kotlin class also declares `newFromAsset(...)`, but no
// corresponding JNI export appears in this file. The desktop example passes
// assetManager = null (so only newFromFile is exercised); calling
// newFromAsset would fail to link — TODO confirm whether an Android-only
// build provides it elsewhere.