diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 3c586bf7577..93c432e78f6 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -186,7 +186,7 @@ public int generate(
   }
 
   /**
-   * Prefill an LLaVA Module with the given images input.
+   * Prefill a multimodal Module with the given images input.
    *
    * @param image Input image as a byte array
    * @param width Input image width
@@ -196,7 +196,7 @@ public int generate(
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
-  @Deprecated
+  @Experimental
   public long prefillImages(int[] image, int width, int height, int channels) {
     int nativeResult = appendImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
@@ -208,7 +208,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
   private native int appendImagesInput(int[] image, int width, int height, int channels);
 
   /**
-   * Prefill an LLaVA Module with the given images input.
+   * Prefill a multimodal Module with the given images input.
    *
    * @param image Input normalized image as a float array
    * @param width Input image width
@@ -218,7 +218,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
-  @Deprecated
+  @Experimental
   public long prefillImages(float[] image, int width, int height, int channels) {
     int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
@@ -231,14 +231,59 @@ private native int appendNormalizedImagesInput(
       float[] image, int width, int height, int channels);
 
   /**
-   * Prefill an LLaVA Module with the given text input.
+   * Prefill a multimodal Module with the given audio input.
    *
-   * @param prompt The text prompt to LLaVA.
+   * @param audio Input preprocessed audio as a byte array
+   * @param batch_size Input batch size
+   * @param n_bins Input number of bins
+   * @param n_frames Input number of frames
    * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
-  @Deprecated
+  @Experimental
+  public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
+    int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
+    if (nativeResult != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult);
+    }
+    return 0;
+  }
+
+  private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
+
+  /**
+   * Prefill a multimodal Module with the given raw audio input.
+   *
+   * @param audio Input raw audio as a byte array
+   * @param batch_size Input batch size
+   * @param n_channels Input number of channels
+   * @param n_samples Input number of samples
+   * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
+   *     exposed to user.
+   * @throws RuntimeException if the prefill failed
+   */
+  @Experimental
+  public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
+    int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
+    if (nativeResult != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult);
+    }
+    return 0;
+  }
+
+  private native int appendRawAudioInput(
+      byte[] audio, int batch_size, int n_channels, int n_samples);
+
+  /**
+   * Prefill a multimodal Module with the given text input.
+   *
+   * @param prompt The text prompt to prefill.
+   * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
+   *     exposed to user.
+   * @throws RuntimeException if the prefill failed
+   */
+  @Experimental
   public long prefillPrompt(String prompt) {
     int nativeResult = appendTextInput(prompt);
     if (nativeResult != 0) {
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index d12783d4cf0..ae1d0a83201 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -316,6 +316,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
+  // Returns status_code
+  jint append_audio_input(
+      facebook::jni::alias_ref<jbyteArray> data,
+      jint batch_size,
+      jint n_bins,
+      jint n_frames) {
+    if (data == nullptr) {
+      return static_cast<jint>(Error::EndOfMethod);
+    }
+    auto data_size = data->size();
+    if (data_size != 0) {
+      std::vector<jbyte> data_jbyte(data_size);
+      std::vector<uint8_t> data_u8(data_size);
+      data->getRegion(0, data_size, data_jbyte.data());
+      for (int i = 0; i < data_size; i++) {
+        data_u8[i] = data_jbyte[i];
+      }
+      llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
+      prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+    }
+    return 0;
+  }
+
+  // Returns status_code
+  jint append_raw_audio_input(
+      facebook::jni::alias_ref<jbyteArray> data,
+      jint batch_size,
+      jint n_channels,
+      jint n_samples) {
+    if (data == nullptr) {
+      return static_cast<jint>(Error::EndOfMethod);
+    }
+    auto data_size = data->size();
+    if (data_size != 0) {
+      std::vector<jbyte> data_jbyte(data_size);
+      std::vector<uint8_t> data_u8(data_size);
+      data->getRegion(0, data_size, data_jbyte.data());
+      for (int i = 0; i < data_size; i++) {
+        data_u8[i] = data_jbyte[i];
+      }
+      llm::RawAudio audio{
+          std::move(data_u8), batch_size, n_channels, n_samples};
+      prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+    }
+    return 0;
+  }
+
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -353,6 +400,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         makeNativeMethod(
             "appendNormalizedImagesInput",
             ExecuTorchLlmJni::append_normalized_images_input),
+        makeNativeMethod(
+            "appendAudioInput", ExecuTorchLlmJni::append_audio_input),
+        makeNativeMethod(
+            "appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
         makeNativeMethod(
             "appendTextInput", ExecuTorchLlmJni::append_text_input),
         makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),