Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ public int generate(
}

/**
* Prefill an LLaVA Module with the given images input.
* Prefill a multimodal Module with the given images input.
*
* @param image Input image as a byte array
* @param width Input image width
Expand All @@ -196,7 +196,7 @@ public int generate(
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
@Experimental
public long prefillImages(int[] image, int width, int height, int channels) {
int nativeResult = appendImagesInput(image, width, height, channels);
if (nativeResult != 0) {
Expand All @@ -208,7 +208,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
private native int appendImagesInput(int[] image, int width, int height, int channels);

/**
* Prefill an LLaVA Module with the given images input.
* Prefill a multimodal Module with the given images input.
*
* @param image Input normalized image as a float array
* @param width Input image width
Expand All @@ -218,7 +218,7 @@ public long prefillImages(int[] image, int width, int height, int channels) {
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
@Experimental
public long prefillImages(float[] image, int width, int height, int channels) {
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
if (nativeResult != 0) {
Expand All @@ -231,14 +231,59 @@ private native int appendNormalizedImagesInput(
float[] image, int width, int height, int channels);

/**
* Prefill an LLaVA Module with the given text input.
* Prefill a multimodal Module with the given audio input.
*
* @param prompt The text prompt to LLaVA.
* @param audio Input preprocessed audio as a byte array
* @param batch_size Input batch size
* @param n_bins Input number of bins
* @param n_frames Input number of frames
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
@Experimental
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}

private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);

/**
* Prefill a multimodal Module with the given raw audio input.
*
* @param audio Input raw audio as a byte array
* @param batch_size Input batch size
* @param n_channels Input number of channels
* @param n_samples Input number of samples
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Experimental
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}

private native int appendRawAudioInput(
byte[] audio, int batch_size, int n_channels, int n_samples);

/**
* Prefill a multimodal Module with the given text input.
*
* @param prompt The text prompt to prefill.
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Experimental
public long prefillPrompt(String prompt) {
int nativeResult = appendTextInput(prompt);
if (nativeResult != 0) {
Expand Down
51 changes: 51 additions & 0 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,53 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns status_code
jint append_audio_input(
facebook::jni::alias_ref<jbyteArray> data,
jint batch_size,
jint n_bins,
jint n_frames) {
if (data == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
}
auto data_size = data->size();
if (data_size != 0) {
std::vector<jbyte> data_jbyte(data_size);
std::vector<uint8_t> data_u8(data_size);
data->getRegion(0, data_size, data_jbyte.data());
for (int i = 0; i < data_size; i++) {
data_u8[i] = data_jbyte[i];
}
llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
}
return 0;
}

// Returns status_code
jint append_raw_audio_input(
facebook::jni::alias_ref<jbyteArray> data,
jint batch_size,
jint n_channels,
jint n_samples) {
if (data == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
}
auto data_size = data->size();
if (data_size != 0) {
std::vector<jbyte> data_jbyte(data_size);
std::vector<uint8_t> data_u8(data_size);
data->getRegion(0, data_size, data_jbyte.data());
for (int i = 0; i < data_size; i++) {
data_u8[i] = data_jbyte[i];
}
llm::RawAudio audio{
std::move(data_u8), batch_size, n_channels, n_samples};
prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
}
return 0;
}

void stop() {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_->stop();
Expand Down Expand Up @@ -353,6 +400,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod(
"appendNormalizedImagesInput",
ExecuTorchLlmJni::append_normalized_images_input),
makeNativeMethod(
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
makeNativeMethod(
"appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
makeNativeMethod(
"appendTextInput", ExecuTorchLlmJni::append_text_input),
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
Expand Down
Loading