Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,28 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)

private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);

/**
* Prefill a multimodal Module with the given audio input.
*
* @param audio Input preprocessed audio as a float array
* @param batch_size Input batch size
* @param n_bins Input number of bins
* @param n_frames Input number of frames
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Experimental
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}

private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);

/**
* Prefill a multimodal Module with the given raw audio input.
*
Expand Down
26 changes: 26 additions & 0 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,29 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns status_code
jint append_audio_input_float(
facebook::jni::alias_ref<jfloatArray> data,
jint batch_size,
jint n_bins,
jint n_frames) {
if (data == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
}
auto data_size = data->size();
if (data_size != 0) {
std::vector<jfloat> data_jfloat(data_size);
std::vector<float> data_f(data_size);
data->getRegion(0, data_size, data_jfloat.data());
for (int i = 0; i < data_size; i++) {
data_f[i] = data_jfloat[i];
}
llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
}
return 0;
}

// Returns status_code
jint append_raw_audio_input(
facebook::jni::alias_ref<jbyteArray> data,
Expand Down Expand Up @@ -388,6 +411,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
ExecuTorchLlmJni::append_normalized_images_input),
makeNativeMethod(
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
makeNativeMethod(
"appendAudioInputFloat",
ExecuTorchLlmJni::append_audio_input_float),
makeNativeMethod(
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
makeNativeMethod(
Expand Down
Loading