Skip to content

Commit 1076686

Browse files
authored
Android audio input API (#15166)
Expose all llm runner API to java Mark some API as experimental, not deprecated
1 parent bffb7f3 commit 1076686

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ public int generate(
186186
}
187187

188188
/**
189-
* Prefill an LLaVA Module with the given images input.
189+
* Prefill a multimodal Module with the given images input.
190190
*
191191
* @param image Input image as a byte array
192192
* @param width Input image width
@@ -196,7 +196,7 @@ public int generate(
196196
* exposed to user.
197197
* @throws RuntimeException if the prefill failed
198198
*/
199-
@Deprecated
199+
@Experimental
200200
public long prefillImages(int[] image, int width, int height, int channels) {
201201
int nativeResult = appendImagesInput(image, width, height, channels);
202202
if (nativeResult != 0) {
@@ -208,7 +208,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
208208
private native int appendImagesInput(int[] image, int width, int height, int channels);
209209

210210
/**
211-
* Prefill an LLaVA Module with the given images input.
211+
* Prefill a multimodal Module with the given images input.
212212
*
213213
* @param image Input normalized image as a float array
214214
* @param width Input image width
@@ -218,7 +218,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
218218
* exposed to user.
219219
* @throws RuntimeException if the prefill failed
220220
*/
221-
@Deprecated
221+
@Experimental
222222
public long prefillImages(float[] image, int width, int height, int channels) {
223223
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
224224
if (nativeResult != 0) {
@@ -231,14 +231,59 @@ private native int appendNormalizedImagesInput(
231231
float[] image, int width, int height, int channels);
232232

233233
/**
234-
* Prefill an LLaVA Module with the given text input.
234+
* Prefill a multimodal Module with the given audio input.
235235
*
236-
* @param prompt The text prompt to LLaVA.
236+
* @param audio Input preprocessed audio as a byte array
237+
* @param batch_size Input batch size
238+
* @param n_bins Input number of bins
239+
* @param n_frames Input number of frames
237240
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
238241
* exposed to user.
239242
* @throws RuntimeException if the prefill failed
240243
*/
241-
@Deprecated
244+
@Experimental
245+
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
246+
int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
247+
if (nativeResult != 0) {
248+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
249+
}
250+
return 0;
251+
}
252+
253+
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
254+
255+
/**
256+
* Prefill a multimodal Module with the given raw audio input.
257+
*
258+
* @param audio Input raw audio as a byte array
259+
* @param batch_size Input batch size
260+
* @param n_channels Input number of channels
261+
* @param n_samples Input number of samples
262+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
263+
* exposed to user.
264+
* @throws RuntimeException if the prefill failed
265+
*/
266+
@Experimental
267+
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
268+
int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
269+
if (nativeResult != 0) {
270+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
271+
}
272+
return 0;
273+
}
274+
275+
private native int appendRawAudioInput(
276+
byte[] audio, int batch_size, int n_channels, int n_samples);
277+
278+
/**
279+
* Prefill a multimodal Module with the given text input.
280+
*
281+
* @param prompt The text prompt to prefill.
282+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
283+
* exposed to user.
284+
* @throws RuntimeException if the prefill failed
285+
*/
286+
@Experimental
242287
public long prefillPrompt(String prompt) {
243288
int nativeResult = appendTextInput(prompt);
244289
if (nativeResult != 0) {

extension/android/jni/jni_layer_llama.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
316316
return 0;
317317
}
318318

319+
// Returns status_code
320+
jint append_audio_input(
321+
facebook::jni::alias_ref<jbyteArray> data,
322+
jint batch_size,
323+
jint n_bins,
324+
jint n_frames) {
325+
if (data == nullptr) {
326+
return static_cast<jint>(Error::EndOfMethod);
327+
}
328+
auto data_size = data->size();
329+
if (data_size != 0) {
330+
std::vector<jbyte> data_jbyte(data_size);
331+
std::vector<uint8_t> data_u8(data_size);
332+
data->getRegion(0, data_size, data_jbyte.data());
333+
for (int i = 0; i < data_size; i++) {
334+
data_u8[i] = data_jbyte[i];
335+
}
336+
llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
337+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
338+
}
339+
return 0;
340+
}
341+
342+
// Returns status_code
343+
jint append_raw_audio_input(
344+
facebook::jni::alias_ref<jbyteArray> data,
345+
jint batch_size,
346+
jint n_channels,
347+
jint n_samples) {
348+
if (data == nullptr) {
349+
return static_cast<jint>(Error::EndOfMethod);
350+
}
351+
auto data_size = data->size();
352+
if (data_size != 0) {
353+
std::vector<jbyte> data_jbyte(data_size);
354+
std::vector<uint8_t> data_u8(data_size);
355+
data->getRegion(0, data_size, data_jbyte.data());
356+
for (int i = 0; i < data_size; i++) {
357+
data_u8[i] = data_jbyte[i];
358+
}
359+
llm::RawAudio audio{
360+
std::move(data_u8), batch_size, n_channels, n_samples};
361+
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
362+
}
363+
return 0;
364+
}
365+
319366
void stop() {
320367
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
321368
multi_modal_runner_->stop();
@@ -353,6 +400,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
353400
makeNativeMethod(
354401
"appendNormalizedImagesInput",
355402
ExecuTorchLlmJni::append_normalized_images_input),
403+
makeNativeMethod(
404+
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
405+
makeNativeMethod(
406+
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
356407
makeNativeMethod(
357408
"appendTextInput", ExecuTorchLlmJni::append_text_input),
358409
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),

0 commit comments

Comments
 (0)