diff --git a/CMakeLists.txt b/CMakeLists.txt index c57ae3598..6af2fa9a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,7 @@ option(SHERPA_ONNX_ENABLE_GPU "Enable ONNX Runtime GPU support" OFF) option(SHERPA_ONNX_ENABLE_WASM "Whether to enable WASM" OFF) option(SHERPA_ONNX_ENABLE_WASM_TTS "Whether to enable WASM for TTS" OFF) option(SHERPA_ONNX_ENABLE_WASM_ASR "Whether to enable WASM for ASR" OFF) +option(SHERPA_ONNX_ENABLE_WASM_KWS "Whether to enable WASM for KWS" OFF) option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF) option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON) option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) @@ -135,6 +136,10 @@ if(SHERPA_ONNX_ENABLE_WASM) add_definitions(-DSHERPA_ONNX_ENABLE_WASM=1) endif() +if(SHERPA_ONNX_ENABLE_WASM_KWS) + add_definitions(-DSHERPA_ONNX_ENABLE_WASM_KWS=1) +endif() + if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") endif() diff --git a/build-wasm-simd-kws.sh b/build-wasm-simd-kws.sh new file mode 100644 index 000000000..8310c2098 --- /dev/null +++ b/build-wasm-simd-kws.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +if [ x"$EMSCRIPTEN" == x"" ]; then + if ! command -v emcc &> /dev/null; then + echo "Please install emscripten first" + echo "" + echo "You can use the following commands to install it:" + echo "" + echo "git clone https://github.com/emscripten-core/emsdk.git" + echo "cd emsdk" + echo "git pull" + echo "./emsdk install latest" + echo "./emsdk activate latest" + echo "source ./emsdk_env.sh" + exit 1 + else + EMSCRIPTEN=$(dirname $(realpath $(which emcc))) + fi +fi + +export EMSCRIPTEN=$EMSCRIPTEN +echo "EMSCRIPTEN: $EMSCRIPTEN" +if [ ! 
-f $EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake ]; then + echo "Cannot find $EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake" + echo "Please make sure you have installed emsdk correctly" + exit 1 +fi + +mkdir -p build-wasm-simd-kws +pushd build-wasm-simd-kws + +export SHERPA_ONNX_IS_USING_BUILD_WASM_SH=ON + +cmake \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$EMSCRIPTEN/cmake/Modules/Platform/Emscripten.cmake \ + \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DBUILD_SHARED_LIBS=OFF \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=OFF \ + -DSHERPA_ONNX_ENABLE_C_API=ON \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DSHERPA_ONNX_ENABLE_GPU=OFF \ + -DSHERPA_ONNX_ENABLE_WASM=ON \ + -DSHERPA_ONNX_ENABLE_WASM_KWS=ON \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + -DSHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY=OFF \ + .. +make -j8 +make install + +ls -lh install/bin/wasm diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 9ec76e22e..d9886c64b 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -481,7 +481,7 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; spotter->impl = - std::make_unique(spotter_config); + std::make_unique(spotter_config); return spotter; } @@ -493,7 +493,7 @@ void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { SherpaOnnxOnlineStream* CreateKeywordStream( const SherpaOnnxKeywordSpotter* spotter) { SherpaOnnxOnlineStream* stream = - new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); + new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); return stream; } @@ -512,7 +512,7 @@ void DecodeMultipleKeywordStreams( int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { - ss[i] = streams[i]->impl.get(); + ss[i] = streams[i]->impl.get(); } spotter->impl->DecodeStreams(ss.data(), n); } @@ 
-593,7 +593,6 @@ void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { } } - // ============================================================ // For VAD // ============================================================ diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 50a3e252b..ef22a9984 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -266,8 +266,14 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { } void InitKeywords() { +#ifdef SHERPA_ONNX_ENABLE_WASM_KWS + // WASM KWS builds do not read the keywords from a file: + // config_.keywords_file holds the keywords text itself, + // so it is parsed in memory through an istringstream. + std::istringstream is(config_.keywords_file); + InitKeywords(is); +#else // each line in keywords_file contains space-separated words - std::ifstream is(config_.keywords_file); if (!is) { SHERPA_ONNX_LOGE("Open keywords file failed: %s", @@ -275,6 +281,7 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { exit(-1); } InitKeywords(is); +#endif } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index 342b8308f..274a7fddf 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -94,10 +94,17 @@ bool KeywordSpotterConfig::Validate() const { SHERPA_ONNX_LOGE("Please provide --keywords-file."); return false; } + +#ifndef SHERPA_ONNX_ENABLE_WASM_KWS + // This existence check is compiled only for non-WASM-KWS builds. + // In WASM KWS builds the model assets are packaged into + // sherpa-onnx-wasm-kws-main.data and keywords_file is parsed directly + // as a string of keywords, so there is no path to check. if (!std::ifstream(keywords_file.c_str()).good()) { SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str()); return false; } +#endif return model_config.Validate(); } diff --git 
a/sherpa-onnx/csrc/transducer-keyword-decoder.cc b/sherpa-onnx/csrc/transducer-keyword-decoder.cc index f31348ea9..ef8314ed8 100644 --- a/sherpa-onnx/csrc/transducer-keyword-decoder.cc +++ b/sherpa-onnx/csrc/transducer-keyword-decoder.cc @@ -2,16 +2,14 @@ // // Copyright (c) 2023-2024 Xiaomi Corporation -#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" - #include #include -#include #include #include #include "sherpa-onnx/csrc/log.h" #include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/transducer-keyword-decoder.h" namespace sherpa_onnx { diff --git a/wasm/CMakeLists.txt b/wasm/CMakeLists.txt index c5d283f19..d7c7a1a17 100644 --- a/wasm/CMakeLists.txt +++ b/wasm/CMakeLists.txt @@ -6,6 +6,10 @@ if(SHERPA_ONNX_ENABLE_WASM_ASR) add_subdirectory(asr) endif() +if(SHERPA_ONNX_ENABLE_WASM_KWS) + add_subdirectory(kws) +endif() + if(SHERPA_ONNX_ENABLE_WASM_NODEJS) add_subdirectory(nodejs) endif() diff --git a/wasm/kws/CMakeLists.txt b/wasm/kws/CMakeLists.txt new file mode 100644 index 000000000..f083892cc --- /dev/null +++ b/wasm/kws/CMakeLists.txt @@ -0,0 +1,56 @@ +if(NOT $ENV{SHERPA_ONNX_IS_USING_BUILD_WASM_SH}) + message(FATAL_ERROR "Please use ./build-wasm-simd-kws.sh to build for wasm KWS") +endif() + +if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/decoder-epoch-12-avg-2-chunk-16-left-64.onnx") + message(WARNING "${CMAKE_CURRENT_SOURCE_DIR}/assets/decoder-epoch-12-avg-2-chunk-16-left-64.onnx does not exist") + message(FATAL_ERROR "Please read ${CMAKE_CURRENT_SOURCE_DIR}/assets/README.md before you continue") +endif() + +set(exported_functions + AcceptWaveform + CreateKeywordSpotter + DestroyKeywordSpotter + CreateKeywordStream + DecodeKeywordStream + GetKeywordResult + DestroyKeywordResult + IsKeywordStreamReady + InputFinished + # destroys the SherpaOnnxOnlineStream objects returned by CreateKeywordStream + DestroyOnlineStream +) +set(mangled_exported_functions) +foreach(x IN LISTS exported_functions) + list(APPEND mangled_exported_functions "_${x}") +endforeach() + +list(JOIN mangled_exported_functions "," all_exported_functions) + 
+include_directories(${CMAKE_SOURCE_DIR}) +set(MY_FLAGS "-s FORCE_FILESYSTEM=1 -s INITIAL_MEMORY=512MB -s ALLOW_MEMORY_GROWTH=1") +string(APPEND MY_FLAGS " -sSTACK_SIZE=10485760 ") +string(APPEND MY_FLAGS " -sEXPORTED_FUNCTIONS=[_CopyHeap,_malloc,_free,${all_exported_functions}] ") +string(APPEND MY_FLAGS "--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/assets@. ") +string(APPEND MY_FLAGS " -sEXPORTED_RUNTIME_METHODS=['ccall','stringToUTF8','setValue','getValue','lengthBytesUTF8','UTF8ToString'] ") +message(STATUS "MY_FLAGS: ${MY_FLAGS}") + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${MY_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MY_FLAGS}") +set(CMAKE_EXECUTBLE_LINKER_FLAGS "${CMAKE_EXECUTBLE_LINKER_FLAGS} ${MY_FLAGS}") + +add_executable(sherpa-onnx-wasm-kws-main sherpa-onnx-wasm-main-kws.cc) +target_link_libraries(sherpa-onnx-wasm-kws-main sherpa-onnx-c-api) +install(TARGETS sherpa-onnx-wasm-kws-main DESTINATION bin/wasm) + +install( + FILES + "sherpa-onnx-kws.js" + "app.js" + "index.html" + "$/sherpa-onnx-wasm-kws-main.js" + "$/sherpa-onnx-wasm-kws-main.wasm" + "$/sherpa-onnx-wasm-kws-main.data" + DESTINATION + bin/wasm +) \ No newline at end of file diff --git a/wasm/kws/app.js b/wasm/kws/app.js new file mode 100644 index 000000000..e823f0494 --- /dev/null +++ b/wasm/kws/app.js @@ -0,0 +1,290 @@ +// This file copies and modifies code +// from https://mdn.github.io/web-dictaphone/scripts/app.js +// and https://gist.github.com/meziantou/edb7217fddfbb70e899e + +const startBtn = document.getElementById('startBtn'); +const stopBtn = document.getElementById('stopBtn'); +const clearBtn = document.getElementById('clearBtn'); +const hint = document.getElementById('hint'); +const soundClips = document.getElementById('sound-clips'); + +let textArea = document.getElementById('results'); + +let lastResult = ''; +let resultList = []; + +clearBtn.onclick = function() { + resultList = []; + textArea.value = getDisplayResult(); + textArea.scrollTop = textArea.scrollHeight; // 
auto scroll +}; + +function getDisplayResult() { + let i = 0; + let ans = ''; + for (let s in resultList) { + if (resultList[s] == '') { + continue; + } + + ans += '' + i + ': ' + resultList[s] + '\n'; + i += 1; + } + + return ans; +} + + +Module = {}; +Module.onRuntimeInitialized = function() { + console.log('inited!'); + hint.innerText = 'Model loaded! Please click start'; + + startBtn.disabled = false; + + recognizer = createKws(Module); + console.log('recognizer is created!', recognizer); +}; + +let audioCtx; +let mediaStream; + +let expectedSampleRate = 16000; +let recordSampleRate; // the sampleRate of the microphone +let recorder = null; // the microphone +let leftchannel = []; // TODO: Use a single channel + +let recordingLength = 0; // number of samples so far + +let recognizer = null; +let recognizer_stream = null; + +if (navigator.mediaDevices.getUserMedia) { + console.log('getUserMedia supported.'); + + // see https://w3c.github.io/mediacapture-main/#dom-mediadevices-getusermedia + const constraints = {audio: true}; + + let onSuccess = function(stream) { + if (!audioCtx) { + audioCtx = new AudioContext({sampleRate: 16000}); + } + console.log(audioCtx); + recordSampleRate = audioCtx.sampleRate; + console.log('sample rate ' + recordSampleRate); + + // creates an audio node from the microphone incoming stream + mediaStream = audioCtx.createMediaStreamSource(stream); + console.log('media stream', mediaStream); + + // https://developer.mozilla.org/en-US/docs/Web/API/AudioContext/createScriptProcessor + // bufferSize: the onaudioprocess event is called when the buffer is full + var bufferSize = 4096; + var numberOfInputChannels = 1; + var numberOfOutputChannels = 2; + if (audioCtx.createScriptProcessor) { + recorder = audioCtx.createScriptProcessor( + bufferSize, numberOfInputChannels, numberOfOutputChannels); + } else { + recorder = audioCtx.createJavaScriptNode( + bufferSize, numberOfInputChannels, numberOfOutputChannels); + } + console.log('recorder', 
recorder); + + recorder.onaudioprocess = function(e) { + let samples = new Float32Array(e.inputBuffer.getChannelData(0)) + samples = downsampleBuffer(samples, expectedSampleRate); + + if (recognizer_stream == null) { + recognizer_stream = recognizer.createStream(); + } + + recognizer_stream.acceptWaveform(expectedSampleRate, samples); + while (recognizer.isReady(recognizer_stream)) { + recognizer.decode(recognizer_stream); + } + + + let result = recognizer.getResult(recognizer_stream); + console.log(result) + + if (result.keyword.length > 0) { + lastResult = result; + resultList.push(JSON.stringify(result)); + } + + + textArea.value = getDisplayResult(); + textArea.scrollTop = textArea.scrollHeight; // auto scroll + + let buf = new Int16Array(samples.length); + for (var i = 0; i < samples.length; ++i) { + let s = samples[i]; + if (s >= 1) + s = 1; + else if (s <= -1) + s = -1; + + samples[i] = s; + buf[i] = s * 32767; + } + + leftchannel.push(buf); + recordingLength += bufferSize; + }; + + startBtn.onclick = function() { + mediaStream.connect(recorder); + recorder.connect(audioCtx.destination); + + console.log('recorder started'); + + stopBtn.disabled = false; + startBtn.disabled = true; + }; + + stopBtn.onclick = function() { + console.log('recorder stopped'); + + // stopBtn recording + recorder.disconnect(audioCtx.destination); + mediaStream.disconnect(recorder); + + startBtn.style.background = ''; + startBtn.style.color = ''; + // mediaRecorder.requestData(); + + stopBtn.disabled = true; + startBtn.disabled = false; + + var clipName = new Date().toISOString(); + + const clipContainer = document.createElement('article'); + const clipLabel = document.createElement('p'); + const audio = document.createElement('audio'); + const deleteButton = document.createElement('button'); + clipContainer.classList.add('clip'); + audio.setAttribute('controls', ''); + deleteButton.textContent = 'Delete'; + deleteButton.className = 'delete'; + + clipLabel.textContent = clipName; + 
+ clipContainer.appendChild(audio); + + clipContainer.appendChild(clipLabel); + clipContainer.appendChild(deleteButton); + soundClips.appendChild(clipContainer); + + audio.controls = true; + let samples = flatten(leftchannel); + const blob = toWav(samples); + + leftchannel = []; + const audioURL = window.URL.createObjectURL(blob); + audio.src = audioURL; + console.log('recorder stopped'); + + deleteButton.onclick = function(e) { + let evtTgt = e.target; + evtTgt.parentNode.parentNode.removeChild(evtTgt.parentNode); + }; + + clipLabel.onclick = function() { + const existingName = clipLabel.textContent; + const newClipName = prompt('Enter a new name for your sound clip?'); + if (newClipName === null) { + clipLabel.textContent = existingName; + } else { + clipLabel.textContent = newClipName; + } + }; + }; + }; + + let onError = function(err) { + console.log('The following error occured: ' + err); + }; + + navigator.mediaDevices.getUserMedia(constraints).then(onSuccess, onError); +} else { + console.log('getUserMedia not supported on your browser!'); + alert('getUserMedia not supported on your browser!'); +} + + +// this function is copied/modified from +// https://gist.github.com/meziantou/edb7217fddfbb70e899e +function flatten(listOfSamples) { + let n = 0; + for (let i = 0; i < listOfSamples.length; ++i) { + n += listOfSamples[i].length; + } + let ans = new Int16Array(n); + + let offset = 0; + for (let i = 0; i < listOfSamples.length; ++i) { + ans.set(listOfSamples[i], offset); + offset += listOfSamples[i].length; + } + return ans; +} + +// this function is copied/modified from +// https://gist.github.com/meziantou/edb7217fddfbb70e899e +function toWav(samples) { + let buf = new ArrayBuffer(44 + samples.length * 2); + var view = new DataView(buf); + + // http://soundfile.sapp.org/doc/WaveFormat/ + // F F I R + view.setUint32(0, 0x46464952, true); // chunkID + view.setUint32(4, 36 + samples.length * 2, true); // chunkSize + // E V A W + view.setUint32(8, 0x45564157, 
true); // format + // + // t m f + view.setUint32(12, 0x20746d66, true); // subchunk1ID + view.setUint32(16, 16, true); // subchunk1Size, 16 for PCM + view.setUint32(20, 1, true); // audioFormat, 1 for PCM + view.setUint16(22, 1, true); // numChannels: 1 channel + view.setUint32(24, expectedSampleRate, true); // sampleRate + view.setUint32(28, expectedSampleRate * 2, true); // byteRate + view.setUint16(32, 2, true); // blockAlign + view.setUint16(34, 16, true); // bitsPerSample + view.setUint32(36, 0x61746164, true); // Subchunk2ID + view.setUint32(40, samples.length * 2, true); // subchunk2Size + + let offset = 44; + for (let i = 0; i < samples.length; ++i) { + view.setInt16(offset, samples[i], true); + offset += 2; + } + + return new Blob([view], {type: 'audio/wav'}); +} + +// this function is copied from +// https://github.com/awslabs/aws-lex-browser-audio-capture/blob/master/lib/worker.js#L46 +function downsampleBuffer(buffer, exportSampleRate) { + if (exportSampleRate === recordSampleRate) { + return buffer; + } + var sampleRateRatio = recordSampleRate / exportSampleRate; + var newLength = Math.round(buffer.length / sampleRateRatio); + var result = new Float32Array(newLength); + var offsetResult = 0; + var offsetBuffer = 0; + while (offsetResult < result.length) { + var nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio); + var accum = 0, count = 0; + for (var i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) { + accum += buffer[i]; + count++; + } + result[offsetResult] = accum / count; + offsetResult++; + offsetBuffer = nextOffsetBuffer; + } + return result; +}; \ No newline at end of file diff --git a/wasm/kws/assets/README.md b/wasm/kws/assets/README.md new file mode 100644 index 000000000..ac67fb5a0 --- /dev/null +++ b/wasm/kws/assets/README.md @@ -0,0 +1,27 @@ +# Introduction + +Please refer to +https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary +to download a model. 
+ +# Kws + +The following is an example: +``` +cd sherpa-onnx/wasm/kws +git clone https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git assets +``` + +You should have the following files in `assets` before you can run +`build-wasm-simd-kws.sh` + +``` +├── decoder-epoch-12-avg-2-chunk-16-left-64.onnx +├── encoder-epoch-12-avg-2-chunk-16-left-64.onnx +├── joiner-epoch-12-avg-2-chunk-16-left-64.onnx +├── keywords_raw.txt +├── keywords.txt +├── README.md +└── tokens.txt + +``` diff --git a/wasm/kws/index.html b/wasm/kws/index.html new file mode 100644 index 000000000..cd0c43f54 --- /dev/null +++ b/wasm/kws/index.html @@ -0,0 +1,40 @@ + + + + + + Next-gen Kaldi WebAssembly with sherpa-onnx for kws + + + + +

+ WebAssembly
+ Kws Demo with sherpa-onnx +

+
+ Loading model ... ... +
+
+ + + +
+
+ +
+ +
+
+ + + + + \ No newline at end of file diff --git a/wasm/kws/sherpa-onnx-kws.js b/wasm/kws/sherpa-onnx-kws.js new file mode 100644 index 000000000..7d91745e3 --- /dev/null +++ b/wasm/kws/sherpa-onnx-kws.js @@ -0,0 +1,270 @@ + + +function freeConfig(config, Module) { + if ('buffer' in config) { + Module._free(config.buffer); + } + Module._free(config.ptr); +} + + +function initSherpaOnnxOnlineTransducerModelConfig(config, Module) { + const encoderLen = Module.lengthBytesUTF8(config.encoder) + 1; + const decoderLen = Module.lengthBytesUTF8(config.decoder) + 1; + const joinerLen = Module.lengthBytesUTF8(config.joiner) + 1; + + const n = encoderLen + decoderLen + joinerLen; + + const buffer = Module._malloc(n); + + const len = 3 * 4; // 3 pointers + const ptr = Module._malloc(len); + + let offset = 0; + Module.stringToUTF8(config.encoder, buffer + offset, encoderLen); + offset += encoderLen; + + Module.stringToUTF8(config.decoder, buffer + offset, decoderLen); + offset += decoderLen; + + Module.stringToUTF8(config.joiner, buffer + offset, joinerLen); + + offset = 0; + Module.setValue(ptr, buffer + offset, 'i8*'); + offset += encoderLen; + + Module.setValue(ptr + 4, buffer + offset, 'i8*'); + offset += decoderLen; + + Module.setValue(ptr + 8, buffer + offset, 'i8*'); + + return { + buffer: buffer, ptr: ptr, len: len, + } +} + +// The user should free the returned pointers +function initModelConfig(config, Module) { + const transducer = + initSherpaOnnxOnlineTransducerModelConfig(config.transducer, Module); + const paraformer_len = 2 * 4 + const ctc_len = 1 * 4 + + const len = transducer.len + paraformer_len + ctc_len + 5 * 4; + const ptr = Module._malloc(len); + + let offset = 0; + Module._CopyHeap(transducer.ptr, transducer.len, ptr + offset); + + const tokensLen = Module.lengthBytesUTF8(config.tokens) + 1; + const providerLen = Module.lengthBytesUTF8(config.provider) + 1; + const modelTypeLen = Module.lengthBytesUTF8(config.modelType) + 1; + const bufferLen = tokensLen 
+ providerLen + modelTypeLen; + const buffer = Module._malloc(bufferLen); + + offset = 0; + Module.stringToUTF8(config.tokens, buffer, tokensLen); + offset += tokensLen; + + Module.stringToUTF8(config.provider, buffer + offset, providerLen); + offset += providerLen; + + Module.stringToUTF8(config.modelType, buffer + offset, modelTypeLen); + + offset = transducer.len + paraformer_len + ctc_len; + Module.setValue(ptr + offset, buffer, 'i8*'); // tokens + offset += 4; + + Module.setValue(ptr + offset, config.numThreads, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, buffer + tokensLen, 'i8*'); // provider + offset += 4; + + Module.setValue(ptr + offset, config.debug, 'i32'); + offset += 4; + + Module.setValue( + ptr + offset, buffer + tokensLen + providerLen, 'i8*'); // modelType + offset += 4; + + return { + buffer: buffer, ptr: ptr, len: len, + } +} + +function initFeatureExtractorConfig(config, Module) { + let ptr = Module._malloc(4 * 2); + Module.setValue(ptr, config.samplingRate, 'i32'); + Module.setValue(ptr + 4, config.featureDim, 'i32'); + return { + ptr: ptr, len: 8, + } +} + +function initKwsConfig(config, Module) { + let featConfig = + initFeatureExtractorConfig(config.featConfig, Module); + + let modelConfig = initModelConfig(config.modelConfig, Module); + let numBytes = + featConfig.len + modelConfig.len + 4 * 5; + + let ptr = Module._malloc(numBytes); + let offset = 0; + Module._CopyHeap(featConfig.ptr, featConfig.len, ptr + offset); + offset += featConfig.len; + + Module._CopyHeap(modelConfig.ptr, modelConfig.len, ptr + offset) + offset += modelConfig.len; + + + Module.setValue(ptr + offset, config.maxActivePaths, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, config.numTrailingBlanks, 'i32'); + offset += 4; + + Module.setValue(ptr + offset, config.keywordsScore, 'float'); + offset += 4; + + Module.setValue(ptr + offset, config.keywordsThreshold, 'float'); + offset += 4; + + let keywordsLen = Module.lengthBytesUTF8(config.keywords) + 
1; + let keywordsBuffer = Module._malloc(keywordsLen); + Module.stringToUTF8(config.keywords, keywordsBuffer, keywordsLen); + Module.setValue(ptr + offset, keywordsBuffer, 'i8*'); + offset += 4; + + return { + ptr: ptr, len: numBytes, featConfig: featConfig, modelConfig: modelConfig + } +} + +class Stream { + constructor(handle, Module) { + this.handle = handle; + this.pointer = null; + this.n = 0; + this.Module = Module; + } + + free() { + if (this.handle) { + this.Module._DestroyOnlineKwsStream(this.handle); + this.handle = null; + this.Module._free(this.pointer); + this.pointer = null; + this.n = 0; + } + } + + /** + * @param sampleRate {Number} + * @param samples {Float32Array} Containing samples in the range [-1, 1] + */ + acceptWaveform(sampleRate, samples) { + if (this.n < samples.length) { + this.Module._free(this.pointer); + this.pointer = + this.Module._malloc(samples.length * samples.BYTES_PER_ELEMENT); + this.n = samples.length + } + + this.Module.HEAPF32.set(samples, this.pointer / samples.BYTES_PER_ELEMENT); + this.Module._AcceptWaveform( + this.handle, sampleRate, this.pointer, samples.length); + } + + inputFinished() { + _InputFinished(this.handle); + } +}; + +class Kws { + constructor(configObj, Module) { + this.config = configObj; + let config = initKwsConfig(configObj, Module) + let handle = Module._CreateKeywordSpotter(config.ptr); + + + freeConfig(config.featConfig, Module); + freeConfig(config.modelConfig, Module); + freeConfig(config, Module); + + this.handle = handle; + this.Module = Module; + } + + free() { + this.Module._DestroyKeywordSpotter(this.handle); + this.handle = 0 + } + + createStream() { + let handle = this.Module._CreateKeywordStream(this.handle); + return new Stream(handle, this.Module); + } + + isReady(stream) { + return this.Module._IsKeywordStreamReady(this.handle, stream.handle) === 1; + } + + + decode(stream) { + return this.Module._DecodeKeywordStream(this.handle, stream.handle); + } + + getResult(stream) { + let r = 
this.Module._GetKeywordResult(this.handle, stream.handle); + let jsonPtr = this.Module.getValue(r + 24, 'i8*'); + let json = this.Module.UTF8ToString(jsonPtr); + this.Module._DestroyKeywordResult(r); + return JSON.parse(json); + } +} + +function createKws(Module, myConfig) { + let transducerConfig = { + encoder: './encoder-epoch-12-avg-2-chunk-16-left-64.onnx', + decoder: './decoder-epoch-12-avg-2-chunk-16-left-64.onnx', + joiner: './joiner-epoch-12-avg-2-chunk-16-left-64.onnx', + } + let modelConfig = { + transducer: transducerConfig, + tokens: './tokens.txt', + provider: 'cpu', + modelType: "", + numThreads: 1, + debug: 1 + }; + + let featConfig = { + samplingRate: 16000, + featureDim: 80, + }; + + let configObj = { + featConfig: featConfig, + modelConfig: modelConfig, + maxActivePaths: 4, + numTrailingBlanks: 1, + keywordsScore: 1.0, + keywordsThreshold: 0.25, + keywords: "x iǎo ài t óng x ué @小爱同学\n" + + "j ūn g ē n iú b ī @军哥牛逼" + }; + + if (myConfig) { + configObj = myConfig; + } + return new Kws(configObj, Module); +} + +if (typeof process == 'object' && typeof process.versions == 'object' && + typeof process.versions.node == 'string') { + module.exports = { + createKws, + }; +} \ No newline at end of file diff --git a/wasm/kws/sherpa-onnx-wasm-main-kws.cc b/wasm/kws/sherpa-onnx-wasm-main-kws.cc new file mode 100644 index 000000000..832e525d9 --- /dev/null +++ b/wasm/kws/sherpa-onnx-wasm-main-kws.cc @@ -0,0 +1,33 @@ +// wasm/sherpa-onnx-wasm-main-kws.cc +// +// Copyright (c) 2024 lovemefan +#include + +#include +#include + +#include "sherpa-onnx/c-api/c-api.h" + +// see also +// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/Interacting-with-code.html + +extern "C" { + +static_assert(sizeof(SherpaOnnxOnlineTransducerModelConfig) == 3 * 4, ""); +static_assert(sizeof(SherpaOnnxOnlineParaformerModelConfig) == 2 * 4, ""); +static_assert(sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) == 1 * 4, ""); 
+static_assert(sizeof(SherpaOnnxOnlineModelConfig) == + sizeof(SherpaOnnxOnlineTransducerModelConfig) + + sizeof(SherpaOnnxOnlineParaformerModelConfig) + + sizeof(SherpaOnnxOnlineZipformer2CtcModelConfig) + 5 * 4, + ""); +static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); +static_assert(sizeof(SherpaOnnxKeywordSpotterConfig) == + sizeof(SherpaOnnxFeatureConfig) + + sizeof(SherpaOnnxOnlineModelConfig) + 5 * 4, + ""); + +void CopyHeap(const char *src, int32_t num_bytes, char *dst) { + std::copy(src, src + num_bytes, dst); +} +} diff --git a/wasm/nodejs/CMakeLists.txt b/wasm/nodejs/CMakeLists.txt index faff50ea6..f90387e9b 100644 --- a/wasm/nodejs/CMakeLists.txt +++ b/wasm/nodejs/CMakeLists.txt @@ -37,6 +37,14 @@ set(exported_functions DecodeMultipleOfflineStreams GetOfflineStreamResult DestroyOfflineRecognizerResult + # online kws + CreateKeywordSpotter + DestroyKeywordSpotter + CreateKeywordStream + DecodeKeywordStream + GetKeywordResult + DestroyKeywordResult + IsKeywordStreamReady ) @@ -69,6 +77,7 @@ install( FILES ${CMAKE_SOURCE_DIR}/wasm/asr/sherpa-onnx-asr.js ${CMAKE_SOURCE_DIR}/wasm/tts/sherpa-onnx-tts.js + ${CMAKE_SOURCE_DIR}/wasm/kws/sherpa-onnx-kws.js "$/sherpa-onnx-wasm-nodejs.js" "$/sherpa-onnx-wasm-nodejs.wasm" DESTINATION